Create sshconnection.py

This commit is contained in:
AiShell 2025-05-26 15:59:04 +08:00 committed by GitHub
parent 848e600a36
commit be0bd74e27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -0,0 +1,262 @@
import paramiko
import time
import logging
logger = logging.getLogger(__name__)
class SSHConnection:
"""SSH连接管理"""
def __init__(self, hostname: str = "bandit.labs.overthewire.org", port: int = 2220):
self.hostname = hostname
self.port = port
self.client = None
self.channel = None
self.is_connected = False
def connect(self, username: str, password: str, hostname: str = None, port: int = None, max_retries: int = 3) -> bool:
"""连接到SSH服务器带重试机制"""
# 如果已经连接,先关闭现有连接
if self.is_connected:
logger.info("检测到现有连接,先关闭...")
self.close()
# 使用传入的参数,如果没有则使用默认值
target_hostname = hostname if hostname is not None else self.hostname
target_port = port if port is not None else self.port
for attempt in range(max_retries):
try:
logger.info(f"尝试连接 {username}@{target_hostname}:{target_port} (第{attempt + 1}次)")
# 创建新的SSH客户端
self.client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# 设置socket选项提高连接稳定性
sock = self._create_socket(target_hostname, target_port)
if sock is None:
raise Exception("无法创建socket连接")
# 连接SSH服务器
self.client.connect(
hostname=target_hostname,
port=target_port,
username=username,
password=password,
timeout=15,
# 网络稳定性参数
look_for_keys=False, # 不查找密钥文件
allow_agent=False, # 不使用SSH代理
banner_timeout=60, # 增加banner读取超时时间
auth_timeout=60, # 增加认证超时时间
channel_timeout=30, # 增加channel超时时间
# 使用自定义socket
sock=sock
)
# 创建交互式shell
self.channel = self.client.invoke_shell()
# 设置channel参数
self.channel.settimeout(15)
# 等待shell准备就绪
time.sleep(3)
# 清空初始输出
self._clear_initial_output()
self.is_connected = True
logger.info("SSH连接成功建立")
return True
except Exception as e:
logger.warning(f"{attempt + 1}次连接失败: {e}")
self.close() # 清理资源
if attempt < max_retries - 1:
# 等待后重试,递增等待时间
wait_time = (attempt + 1) * 2
logger.info(f"等待{wait_time}秒后重试...")
time.sleep(wait_time)
else:
logger.error(f"所有连接尝试失败,最后错误: {e}")
return False
def _create_socket(self, hostname: str, port: int):
"""创建socket连接"""
import socket
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# 设置socket选项
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# 设置超时
sock.settimeout(15)
# 连接
sock.connect((hostname, port))
return sock
except Exception as e:
logger.error(f"创建socket失败: {e}")
if sock:
sock.close()
return None
def _clear_initial_output(self):
"""清空初始输出"""
max_attempts = 10
total_cleared = 0
for attempt in range(max_attempts):
try:
if self.channel.recv_ready():
data = self.channel.recv(4096)
total_cleared += len(data)
time.sleep(0.3)
else:
# 没有更多数据,等待一下确认
time.sleep(0.5)
if not self.channel.recv_ready():
break
except Exception as e:
logger.warning(f"清空初始输出时出错: {e}")
break
if total_cleared > 0:
logger.debug(f"清空了{total_cleared}字节的初始输出")
def execute_command(self, command: str, timeout: int = 15, max_retries: int = 2) -> str:
"""执行命令并返回输出,带重试机制"""
if not self.is_connected or not self.channel:
logger.error("SSH连接未建立或已断开")
return ""
for attempt in range(max_retries):
try:
# 检查channel是否仍然活跃
if self.channel.closed:
logger.error("SSH channel已关闭")
self.is_connected = False
return ""
# 发送命令
self.channel.send(command + '\n')
time.sleep(0.8) # 增加等待时间
# 读取输出
output = ""
start_time = time.time()
last_received_time = start_time
while time.time() - start_time < timeout:
try:
if self.channel.recv_ready():
chunk = self.channel.recv(4096).decode('utf-8', errors='ignore')
output += chunk
last_received_time = time.time()
# 检查是否命令执行完成
if self._is_command_finished(chunk):
return output
else:
# 如果超过3秒没有新数据且有输出可能命令已完成
if output and (time.time() - last_received_time) > 3:
break
time.sleep(0.1)
except paramiko.ssh_exception.SSHException as e:
logger.error(f"SSH通信异常: {e}")
self.is_connected = False
return output
except Exception as e:
logger.error(f"读取输出时出错: {e}")
if attempt < max_retries - 1:
logger.info("尝试重新执行命令...")
time.sleep(1)
break
return output
return output
except Exception as e:
logger.error(f"执行命令失败 (第{attempt + 1}次): {e}")
if attempt < max_retries - 1:
time.sleep(1)
continue
self.is_connected = False
return ""
return ""
def _is_command_finished(self, chunk: str) -> bool:
"""检查命令是否执行完成"""
# 检查常见的shell提示符
lines = chunk.split('\n')
for line in lines:
line = line.strip()
# 更准确的提示符检测
if (line.endswith('$ ') or line.endswith('# ') or
line.endswith('$') or line.endswith('#') or
line.endswith('> ') or line.endswith('~$ ') or
('@' in line and ('$' in line or '#' in line))):
return True
return False
def is_connection_alive(self) -> bool:
"""检查连接是否还活着"""
if not self.is_connected or not self.client or not self.channel:
return False
try:
# 发送一个简单的测试命令
transport = self.client.get_transport()
if transport is None or not transport.is_active():
return False
if self.channel.closed:
return False
return True
except Exception:
return False
def close(self):
"""关闭连接"""
try:
if self.channel:
self.channel.close()
self.channel = None
if self.client:
self.client.close()
self.client = None
self.is_connected = False
logger.info("SSH连接已关闭")
except Exception as e:
logger.warning(f"关闭连接时出错: {e}")
finally:
# 确保状态重置
self.channel = None
self.client = None
self.is_connected = False
def __enter__(self):
"""支持with语句的上下文管理器"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""支持with语句的上下文管理器"""
self.close()
def __del__(self):
"""析构函数,确保资源被释放"""
try:
self.close()
except:
pass