命令行传参

This commit is contained in:
b3nguang 2025-02-18 13:21:10 +08:00
parent 6ef14c4f44
commit 97f9cf3b13
2 changed files with 82 additions and 35 deletions

90
main.py
View File

@ -4,14 +4,16 @@
@ Date: 2025-02-18 12:04:37
"""
import argparse
import sys
from typing import List
from ollama import Client
from prompt_toolkit import PromptSession
from prompt_toolkit.completion import WordCompleter
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.prompt import Prompt
from rich.table import Table
@ -196,7 +198,7 @@ class OllamaShell:
progress.add_task("fetch")
response = self.client.ps()
if not response or not hasattr(response, 'models') or not response.models:
if not response or not hasattr(response, "models") or not response.models:
self.console.print("[yellow]⚠️ 没有正在运行的模型[/yellow]")
return
@ -219,7 +221,11 @@ class OllamaShell:
size_str = f"{size_gb:.1f}GB"
# 格式化过期时间
expires_str = model.expires_at.strftime("%Y-%m-%d %H:%M:%S") if model.expires_at else "Unknown"
expires_str = (
model.expires_at.strftime("%Y-%m-%d %H:%M:%S")
if model.expires_at
else "Unknown"
)
table.add_row(
model.name,
@ -227,7 +233,7 @@ class OllamaShell:
model.details.format if model.details else "Unknown",
model.details.parameter_size if model.details else "Unknown",
model.details.quantization_level if model.details else "Unknown",
expires_str
expires_str,
)
self.console.print(table)
@ -245,15 +251,21 @@ class OllamaShell:
self.console.print(f"\n[bold]💬 开始与 {model_name} 对话[/bold]")
self.console.print("[dim]🚪 输入 'exit' 结束对话[/dim]")
# 创建对话会话
chat_session = PromptSession()
while True:
try:
message = Prompt.ask("\n[bold green]你[/bold green]")
# 获取用户输入
message = chat_session.prompt("\n👤 你> ")
if message.lower() == "exit":
break
self.console.print("\n[bold blue]AI[/bold blue]")
self.console.print("\n[bold blue]🤖 AI[/bold blue]")
with Progress(
SpinnerColumn(), TextColumn("[bold blue]思考中..."), transient=True
SpinnerColumn(),
TextColumn("[bold blue]🤔 思考中..."),
transient=True,
) as progress:
progress.add_task("think")
stream = self.client.chat(
@ -269,10 +281,13 @@ class OllamaShell:
self.console.print(content, end="", highlight=False)
except KeyboardInterrupt:
self.console.print("\n[yellow]对话已取消[/yellow]")
self.console.print("\n[yellow]⛔️ 对话已取消[/yellow]")
break
except EOFError:
self.console.print("\n[yellow]👋 再见![/yellow]")
break
except Exception as e:
self.console.print(f"\n[red]错误: {str(e)}[/red]")
self.console.print(f"\n[red]错误: {str(e)}[/red]")
break
def show_help(self, *args: List[str]) -> None:
@ -304,6 +319,30 @@ class OllamaShell:
self.console.print("[yellow]👋 再见!✨[/yellow]")
sys.exit(0)
def get_model_list(self) -> List[str]:
"""获取模型列表"""
try:
models = self.client.list()
if hasattr(models, "models"):
return [model.model for model in models.models]
elif isinstance(models, list):
return [model.model for model in models]
return []
except Exception:
return []
def get_command_completer(self) -> WordCompleter:
"""创建命令补全器"""
# 获取所有命令
commands = list(self.commands.keys())
# 获取所有模型
models = self.get_model_list()
# 创建补全器
word_list = commands + [
f"{cmd} {model}" for cmd in ["chat", "show", "pull"] for model in models
]
return WordCompleter(word_list, ignore_case=True)
def run(self) -> None:
"""运行交互式shell"""
self.console.print(
@ -314,9 +353,20 @@ class OllamaShell:
)
)
# 创建命令行会话
session = PromptSession()
while True:
try:
command = Prompt.ask("\n[bold cyan]🤖 ollama[/bold cyan]")
# 获取最新的补全器
completer = self.get_command_completer()
# 显示提示符并等待输入
command = session.prompt(
"\n🤖 ollama> ",
completer=completer,
complete_while_typing=True,
)
args = command.strip().split()
if not args:
continue
@ -332,12 +382,30 @@ class OllamaShell:
except KeyboardInterrupt:
self.console.print("\n[yellow]⛔️ 操作已取消[/yellow]")
continue
except EOFError:
self.console.print("\n[yellow]👋 再见!✨[/yellow]")
break
except Exception as e:
self.console.print(f"[red]❌ 错误: {str(e)}[/red]")
def main():
shell = OllamaShell()
# 创建命令行解析器
parser = argparse.ArgumentParser(
description="Ollama Shell - 一个功能强大的 Ollama 命令行工具"
)
parser.add_argument(
"-H",
"--host",
default="http://localhost:11434",
help="Ollama 服务器地址,默认为 http://localhost:11434",
)
# 解析命令行参数
args = parser.parse_args()
# 创建 shell 实例
shell = OllamaShell(host=args.host)
shell.run()

View File

@ -1,24 +1,3 @@
annotated-types==0.7.0
anyio==4.8.0
certifi==2025.1.31
charset-normalizer==3.4.1
colorama==0.4.6
distro==1.9.0
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
idna==3.10
jiter==0.8.2
markdown-it-py==3.0.0
mdurl==0.1.2
openai==1.61.1
pydantic==2.10.6
pydantic-core==2.27.2
pygments==2.19.1
pytz==2025.1
requests==2.32.3
rich==13.9.4
sniffio==1.3.1
tqdm==4.67.1
typing-extensions==4.12.2
urllib3==2.3.0
ollama
prompt_toolkit
rich