Enhance Ollama Shell with robust error handling, new commands, and improved server connection management

This commit is contained in:
openai1998 2025-02-20 16:48:01 +08:00
parent 435d55953a
commit 5c51e532d9
5 changed files with 1795 additions and 10 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 72 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 126 KiB

179
main.py
View File

@ -7,7 +7,9 @@
import argparse
import re
import sys
from typing import List
from typing import List, Tuple
import logging
import subprocess
from ollama import Client
from prompt_toolkit import PromptSession
@ -18,11 +20,27 @@ from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.style import Style
from rich.table import Table
from httpx import Timeout, HTTPError
class OllamaShell:
def __init__(self, host: str = "http://112.117.14.179:11434/"):
self.client = Client(host=host)
def __init__(self, host: str = None):
if not host:
raise ValueError("必须提供 Ollama 服务器地址")
if not host.startswith(("http://", "https://")):
raise ValueError("服务器地址必须以 http:// 或 https:// 开头")
# 保存 host 地址
self.host = host
# 根据协议决定是否验证证书
self.verify_ssl = not (host.startswith("https://") and ":" in host.split("://")[1].split("/")[0])
self.client = Client(
host=host,
timeout=Timeout(30.0),
verify=self.verify_ssl
)
self.console = Console()
self.commands = {
"list": (self.list_models, "📃 列出可用模型"),
@ -32,6 +50,8 @@ class OllamaShell:
"ps": (self.show_processes, "⚡️ 显示运行中的模型"),
"help": (self.show_help, "❓ 显示帮助信息"),
"exit": (self.exit_shell, "🚪 退出程序"),
"rm": (self.delete_model, "🗑️ 删除指定模型"),
"version": (self.show_version, "📌 显示版本信息"),
}
def list_models(self, *args: List[str]) -> None:
@ -108,8 +128,15 @@ class OllamaShell:
self.console.print(table)
except ConnectionError:
self.console.print("[red]连接服务器失败[/red]")
except TimeoutError:
self.console.print("[red]请求超时[/red]")
except HTTPError as e:
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
except Exception as e:
self.console.print(f"[red]错误: {str(e)}[/red]")
self.console.print("[red]发生未知错误[/red]")
logging.error(f"Unexpected error: {str(e)}")
def pull_model(self, *args: List[str]) -> None:
"""拉取指定的模型"""
@ -118,6 +145,11 @@ class OllamaShell:
return
model_name = args[0]
# 修改模型名称验证,允许更多字符
if not re.match(r'^[a-zA-Z0-9_\-\./:]+$', model_name):
self.console.print("[red]错误: 模型名称包含非法字符[/red]")
return
self.console.print(f"\n[bold]📥 开始拉取模型: {model_name}[/bold]")
try:
@ -133,8 +165,15 @@ class OllamaShell:
)
self.console.print("[green]✅ 模型拉取完成![/green]")
except ConnectionError:
self.console.print("[red]连接服务器失败[/red]")
except TimeoutError:
self.console.print("[red]请求超时[/red]")
except HTTPError as e:
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
except Exception as e:
self.console.print(f"[red]错误: {str(e)}[/red]")
self.console.print("[red]发生未知错误[/red]")
logging.error(f"Unexpected error: {str(e)}")
def show_model(self, *args: List[str]) -> None:
"""显示模型详细信息"""
@ -179,8 +218,15 @@ class OllamaShell:
)
self.console.print(panel)
except ConnectionError:
self.console.print("[red]连接服务器失败[/red]")
except TimeoutError:
self.console.print("[red]请求超时[/red]")
except HTTPError as e:
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
except Exception as e:
self.console.print(f"[red]错误: {str(e)}[/red]")
self.console.print("[red]发生未知错误[/red]")
logging.error(f"Unexpected error: {str(e)}")
def show_processes(self, *args: List[str]) -> None:
"""显示运行中的模型进程"""
@ -229,8 +275,15 @@ class OllamaShell:
self.console.print(table)
except ConnectionError:
self.console.print("[red]连接服务器失败[/red]")
except TimeoutError:
self.console.print("[red]请求超时[/red]")
except HTTPError as e:
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
except Exception as e:
self.console.print(f"[red]❌ 错误: {str(e)}[/red]")
self.console.print("[red]发生未知错误[/red]")
logging.error(f"Unexpected error: {str(e)}")
def chat_with_model(self, *args: List[str]) -> None:
"""与模型进行对话"""
@ -292,8 +345,18 @@ class OllamaShell:
except EOFError:
self.console.print("\n[yellow]👋 再见![/yellow]")
break
except ConnectionError:
self.console.print("[red]连接服务器失败[/red]")
break
except TimeoutError:
self.console.print("[red]请求超时[/red]")
break
except HTTPError as e:
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
break
except Exception as e:
self.console.print(f"\n[red]❌ 错误: {str(e)}[/red]")
self.console.print("[red]发生未知错误[/red]")
logging.error(f"Unexpected error: {str(e)}")
break
def show_help(self, *args: List[str]) -> None:
@ -301,7 +364,7 @@ class OllamaShell:
table = Table(title="✨ 命令列表", show_header=True, header_style="bold magenta")
table.add_column("📝 命令", style="cyan")
table.add_column("📄 说明", style="green")
table.add_column("📖 用法", style="yellow")
table.add_column("<EFBFBD><EFBFBD> 用法", style="yellow", justify="left")
commands_help = [
("list", "📃 列出所有可用的模型", "list"),
@ -309,6 +372,8 @@ class OllamaShell:
("show", "🔍 显示模型详细信息", "show <model_name>"),
("chat", "💬 与模型进行对话", "chat <model_name>"),
("ps", "⚡️ 显示运行中的模型", "ps"),
("rm", "🗑️ 删除指定模型","rm <model_name>"),
("version", "📌 显示版本信息", "version"),
("help", "❓ 显示帮助信息", "help"),
("exit", "🚪 退出程序", "exit"),
]
@ -387,8 +452,102 @@ class OllamaShell:
except EOFError:
self.console.print("\n[yellow]👋 再见!✨[/yellow]")
break
except ConnectionError:
self.console.print("[red]连接服务器失败[/red]")
except TimeoutError:
self.console.print("[red]请求超时[/red]")
except HTTPError as e:
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
except Exception as e:
self.console.print(f"[red]❌ 错误: {str(e)}[/red]")
self.console.print("[red]发生未知错误[/red]")
logging.error(f"Unexpected error: {str(e)}")
break
def delete_model(self, *args: List[str]) -> None:
"""删除指定的模型"""
if not args:
self.console.print("[red]错误: 请指定要删除的模型名称[/red]")
return
model_name = args[0]
# 修改模型名称验证,允许更多字符
if not re.match(r'^[a-zA-Z0-9_\-\./:]+$', model_name):
self.console.print("[red]错误: 模型名称包含非法字符[/red]")
return
try:
# 确认删除
self.console.print(f"\n[yellow]⚠️ 确定要删除模型 {model_name} 吗?这个操作不可恢复![/yellow]")
self.console.print("[dim]输入 'yes' 确认删除,其他输入取消[/dim]")
# 创建确认会话
confirm_session = PromptSession()
confirm = confirm_session.prompt("\n确认> ")
if confirm.lower() != 'yes':
self.console.print("[yellow]已取消删除操作[/yellow]")
return
with Progress(
SpinnerColumn(),
TextColumn(f"[bold red]正在删除模型 {model_name}..."),
transient=True,
) as progress:
progress.add_task("delete")
self.client.delete(model_name)
self.console.print(f"[green]✅ 模型 {model_name} 已成功删除![/green]")
except ConnectionError:
self.console.print("[red]连接服务器失败[/red]")
except TimeoutError:
self.console.print("[red]请求超时[/red]")
except HTTPError as e:
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
except Exception as e:
self.console.print("[red]发生未知错误[/red]")
logging.error(f"Unexpected error: {str(e)}")
def show_version(self, *args: List[str]) -> None:
"""显示 Ollama 版本信息"""
try:
with Progress(
SpinnerColumn(),
TextColumn("[bold blue]获取版本信息..."),
transient=True,
) as progress:
progress.add_task("fetch")
# 使用保存的 verify_ssl 设置
import httpx
response = httpx.get(
f"{self.host}/api/version",
verify=self.verify_ssl
)
response.raise_for_status()
data = response.json()
if not data or 'version' not in data:
self.console.print("[yellow]⚠️ 无法获取版本信息[/yellow]")
return
version = data['version']
# 创建面板显示版本信息
panel = Panel.fit(
f"[bold cyan]Ollama 版本:[/bold cyan] {version}",
title="📌 版本信息",
border_style="green"
)
self.console.print(panel)
except ConnectionError:
self.console.print("[red]连接服务器失败[/red]")
except TimeoutError:
self.console.print("[red]请求超时[/red]")
except HTTPError as e:
self.console.print(f"[red]HTTP 错误: {e.response.status_code}[/red]")
except Exception as e:
self.console.print("[red]获取版本信息时发生错误[/red]")
logging.error(f"Version info error: {str(e)}")
def main():

1626
ollamaAPI说明.md Normal file

File diff suppressed because it is too large Load Diff