mirror of
https://github.com/b3nguang/Ollama-Scan.git
synced 2025-05-05 10:06:54 +00:00
Enhance Ollama Shell with robust error handling, new commands, and improved server connection management
This commit is contained in:
parent
435d55953a
commit
5c51e532d9
Binary file not shown.
Before Width: | Height: | Size: 72 KiB |
Binary file not shown.
Before Width: | Height: | Size: 38 KiB |
Binary file not shown.
Before Width: | Height: | Size: 126 KiB |
179
main.py
179
main.py
@ -7,7 +7,9 @@
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
from ollama import Client
|
||||
from prompt_toolkit import PromptSession
|
||||
@ -18,11 +20,27 @@ from rich.panel import Panel
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
from rich.style import Style
|
||||
from rich.table import Table
|
||||
from httpx import Timeout, HTTPError
|
||||
|
||||
|
||||
class OllamaShell:
|
||||
def __init__(self, host: str = "http://112.117.14.179:11434/"):
|
||||
self.client = Client(host=host)
|
||||
def __init__(self, host: str = None):
|
||||
if not host:
|
||||
raise ValueError("必须提供 Ollama 服务器地址")
|
||||
if not host.startswith(("http://", "https://")):
|
||||
raise ValueError("服务器地址必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 保存 host 地址
|
||||
self.host = host
|
||||
|
||||
# 根据协议决定是否验证证书
|
||||
self.verify_ssl = not (host.startswith("https://") and ":" in host.split("://")[1].split("/")[0])
|
||||
|
||||
self.client = Client(
|
||||
host=host,
|
||||
timeout=Timeout(30.0),
|
||||
verify=self.verify_ssl
|
||||
)
|
||||
self.console = Console()
|
||||
self.commands = {
|
||||
"list": (self.list_models, "📃 列出可用模型"),
|
||||
@ -32,6 +50,8 @@ class OllamaShell:
|
||||
"ps": (self.show_processes, "⚡️ 显示运行中的模型"),
|
||||
"help": (self.show_help, "❓ 显示帮助信息"),
|
||||
"exit": (self.exit_shell, "🚪 退出程序"),
|
||||
"rm": (self.delete_model, "🗑️ 删除指定模型"),
|
||||
"version": (self.show_version, "📌 显示版本信息"),
|
||||
}
|
||||
|
||||
def list_models(self, *args: List[str]) -> None:
|
||||
@ -108,8 +128,15 @@ class OllamaShell:
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
except ConnectionError:
|
||||
self.console.print("[red]连接服务器失败[/red]")
|
||||
except TimeoutError:
|
||||
self.console.print("[red]请求超时[/red]")
|
||||
except HTTPError as e:
|
||||
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]错误: {str(e)}[/red]")
|
||||
self.console.print("[red]发生未知错误[/red]")
|
||||
logging.error(f"Unexpected error: {str(e)}")
|
||||
|
||||
def pull_model(self, *args: List[str]) -> None:
|
||||
"""拉取指定的模型"""
|
||||
@ -118,6 +145,11 @@ class OllamaShell:
|
||||
return
|
||||
|
||||
model_name = args[0]
|
||||
# 修改模型名称验证,允许更多字符
|
||||
if not re.match(r'^[a-zA-Z0-9_\-\./:]+$', model_name):
|
||||
self.console.print("[red]错误: 模型名称包含非法字符[/red]")
|
||||
return
|
||||
|
||||
self.console.print(f"\n[bold]📥 开始拉取模型: {model_name}[/bold]")
|
||||
|
||||
try:
|
||||
@ -133,8 +165,15 @@ class OllamaShell:
|
||||
)
|
||||
self.console.print("[green]✅ 模型拉取完成![/green]")
|
||||
|
||||
except ConnectionError:
|
||||
self.console.print("[red]连接服务器失败[/red]")
|
||||
except TimeoutError:
|
||||
self.console.print("[red]请求超时[/red]")
|
||||
except HTTPError as e:
|
||||
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]错误: {str(e)}[/red]")
|
||||
self.console.print("[red]发生未知错误[/red]")
|
||||
logging.error(f"Unexpected error: {str(e)}")
|
||||
|
||||
def show_model(self, *args: List[str]) -> None:
|
||||
"""显示模型详细信息"""
|
||||
@ -179,8 +218,15 @@ class OllamaShell:
|
||||
)
|
||||
self.console.print(panel)
|
||||
|
||||
except ConnectionError:
|
||||
self.console.print("[red]连接服务器失败[/red]")
|
||||
except TimeoutError:
|
||||
self.console.print("[red]请求超时[/red]")
|
||||
except HTTPError as e:
|
||||
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]错误: {str(e)}[/red]")
|
||||
self.console.print("[red]发生未知错误[/red]")
|
||||
logging.error(f"Unexpected error: {str(e)}")
|
||||
|
||||
def show_processes(self, *args: List[str]) -> None:
|
||||
"""显示运行中的模型进程"""
|
||||
@ -229,8 +275,15 @@ class OllamaShell:
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
except ConnectionError:
|
||||
self.console.print("[red]连接服务器失败[/red]")
|
||||
except TimeoutError:
|
||||
self.console.print("[red]请求超时[/red]")
|
||||
except HTTPError as e:
|
||||
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]❌ 错误: {str(e)}[/red]")
|
||||
self.console.print("[red]发生未知错误[/red]")
|
||||
logging.error(f"Unexpected error: {str(e)}")
|
||||
|
||||
def chat_with_model(self, *args: List[str]) -> None:
|
||||
"""与模型进行对话"""
|
||||
@ -292,8 +345,18 @@ class OllamaShell:
|
||||
except EOFError:
|
||||
self.console.print("\n[yellow]👋 再见![/yellow]")
|
||||
break
|
||||
except ConnectionError:
|
||||
self.console.print("[red]连接服务器失败[/red]")
|
||||
break
|
||||
except TimeoutError:
|
||||
self.console.print("[red]请求超时[/red]")
|
||||
break
|
||||
except HTTPError as e:
|
||||
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
|
||||
break
|
||||
except Exception as e:
|
||||
self.console.print(f"\n[red]❌ 错误: {str(e)}[/red]")
|
||||
self.console.print("[red]发生未知错误[/red]")
|
||||
logging.error(f"Unexpected error: {str(e)}")
|
||||
break
|
||||
|
||||
def show_help(self, *args: List[str]) -> None:
|
||||
@ -301,7 +364,7 @@ class OllamaShell:
|
||||
table = Table(title="✨ 命令列表", show_header=True, header_style="bold magenta")
|
||||
table.add_column("📝 命令", style="cyan")
|
||||
table.add_column("📄 说明", style="green")
|
||||
table.add_column("📖 用法", style="yellow")
|
||||
table.add_column("<EFBFBD><EFBFBD> 用法", style="yellow", justify="left")
|
||||
|
||||
commands_help = [
|
||||
("list", "📃 列出所有可用的模型", "list"),
|
||||
@ -309,6 +372,8 @@ class OllamaShell:
|
||||
("show", "🔍 显示模型详细信息", "show <model_name>"),
|
||||
("chat", "💬 与模型进行对话", "chat <model_name>"),
|
||||
("ps", "⚡️ 显示运行中的模型", "ps"),
|
||||
("rm", "🗑️ 删除指定模型","rm <model_name>"),
|
||||
("version", "📌 显示版本信息", "version"),
|
||||
("help", "❓ 显示帮助信息", "help"),
|
||||
("exit", "🚪 退出程序", "exit"),
|
||||
]
|
||||
@ -387,8 +452,102 @@ class OllamaShell:
|
||||
except EOFError:
|
||||
self.console.print("\n[yellow]👋 再见!✨[/yellow]")
|
||||
break
|
||||
except ConnectionError:
|
||||
self.console.print("[red]连接服务器失败[/red]")
|
||||
except TimeoutError:
|
||||
self.console.print("[red]请求超时[/red]")
|
||||
except HTTPError as e:
|
||||
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]❌ 错误: {str(e)}[/red]")
|
||||
self.console.print("[red]发生未知错误[/red]")
|
||||
logging.error(f"Unexpected error: {str(e)}")
|
||||
break
|
||||
|
||||
def delete_model(self, *args: List[str]) -> None:
|
||||
"""删除指定的模型"""
|
||||
if not args:
|
||||
self.console.print("[red]错误: 请指定要删除的模型名称[/red]")
|
||||
return
|
||||
|
||||
model_name = args[0]
|
||||
# 修改模型名称验证,允许更多字符
|
||||
if not re.match(r'^[a-zA-Z0-9_\-\./:]+$', model_name):
|
||||
self.console.print("[red]错误: 模型名称包含非法字符[/red]")
|
||||
return
|
||||
|
||||
try:
|
||||
# 确认删除
|
||||
self.console.print(f"\n[yellow]⚠️ 确定要删除模型 {model_name} 吗?这个操作不可恢复![/yellow]")
|
||||
self.console.print("[dim]输入 'yes' 确认删除,其他输入取消[/dim]")
|
||||
|
||||
# 创建确认会话
|
||||
confirm_session = PromptSession()
|
||||
confirm = confirm_session.prompt("\n确认> ")
|
||||
|
||||
if confirm.lower() != 'yes':
|
||||
self.console.print("[yellow]已取消删除操作[/yellow]")
|
||||
return
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn(f"[bold red]正在删除模型 {model_name}..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
progress.add_task("delete")
|
||||
self.client.delete(model_name)
|
||||
|
||||
self.console.print(f"[green]✅ 模型 {model_name} 已成功删除![/green]")
|
||||
|
||||
except ConnectionError:
|
||||
self.console.print("[red]连接服务器失败[/red]")
|
||||
except TimeoutError:
|
||||
self.console.print("[red]请求超时[/red]")
|
||||
except HTTPError as e:
|
||||
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
|
||||
except Exception as e:
|
||||
self.console.print("[red]发生未知错误[/red]")
|
||||
logging.error(f"Unexpected error: {str(e)}")
|
||||
|
||||
def show_version(self, *args: List[str]) -> None:
|
||||
"""显示 Ollama 版本信息"""
|
||||
try:
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold blue]获取版本信息..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
progress.add_task("fetch")
|
||||
# 使用保存的 verify_ssl 设置
|
||||
import httpx
|
||||
response = httpx.get(
|
||||
f"{self.host}/api/version",
|
||||
verify=self.verify_ssl
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data or 'version' not in data:
|
||||
self.console.print("[yellow]⚠️ 无法获取版本信息[/yellow]")
|
||||
return
|
||||
|
||||
version = data['version']
|
||||
# 创建面板显示版本信息
|
||||
panel = Panel.fit(
|
||||
f"[bold cyan]Ollama 版本:[/bold cyan] {version}",
|
||||
title="📌 版本信息",
|
||||
border_style="green"
|
||||
)
|
||||
self.console.print(panel)
|
||||
|
||||
except ConnectionError:
|
||||
self.console.print("[red]连接服务器失败[/red]")
|
||||
except TimeoutError:
|
||||
self.console.print("[red]请求超时[/red]")
|
||||
except HTTPError as e:
|
||||
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
|
||||
except Exception as e:
|
||||
self.console.print("[red]获取版本信息时发生错误[/red]")
|
||||
logging.error(f"Version info error: {str(e)}")
|
||||
|
||||
|
||||
def main():
|
||||
|
1626
ollamaAPI说明.md
Normal file
1626
ollamaAPI说明.md
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user