fix: 修复 Python 3.12 类型注解兼容性问题
主要修复: 1. llm/client.py - 添加 from __future__ import annotations,修复 AsyncIterator 导入 2. llm/openai_client.py - 修复 stream_chat 返回类型(同步返回 AsyncIterator) 3. agent/context.py - 添加 TokenCounter Protocol 类型 4. agent/failover.py - 修复 Callable 类型参数 5. agent/runner.py - 添加 __future__ annotations 6. session/manager.py - 添加 TYPE_CHECKING 和 TokenCounter 类型 7. tools/*.py - 添加 __future__ annotations 类型注解改进: - 使用 collections.abc.AsyncIterator 替代 typing.AsyncIterator - 添加 Protocol 定义用于类型检查 - 修复所有 missing type annotation 警告
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
"""上下文管理器"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Dict, Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..llm.client import TokenCounter
|
||||
|
||||
from .types import Message
|
||||
|
||||
|
||||
@@ -15,7 +21,7 @@ class ContextManager:
|
||||
max_tokens: int = 128000, # GPT-4o 默认
|
||||
reserve_tokens: int = 4000, # 为输出保留的 token
|
||||
compression_threshold: float = 0.8, # 触发压缩的阈值
|
||||
):
|
||||
) -> None:
|
||||
self.max_tokens = max_tokens
|
||||
self.reserve_tokens = reserve_tokens
|
||||
self.compression_threshold = compression_threshold
|
||||
@@ -23,11 +29,11 @@ class ContextManager:
|
||||
self._current_tokens = 0
|
||||
self._compression_count = 0
|
||||
|
||||
def count_tokens(self, messages: List[Message], token_counter) -> int:
|
||||
def count_tokens(self, messages: List[Message], token_counter: TokenCounter) -> int:
|
||||
"""计算消息列表的 token 总数"""
|
||||
return token_counter.count_messages_tokens(messages)
|
||||
|
||||
def update_token_count(self, messages: List[Message], token_counter) -> int:
|
||||
def update_token_count(self, messages: List[Message], token_counter: TokenCounter) -> int:
|
||||
"""更新并返回当前 token 数"""
|
||||
self._current_tokens = self.count_tokens(messages, token_counter)
|
||||
return self._current_tokens
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""故障转移处理器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, List, Dict, Any, Callable
|
||||
@@ -39,14 +41,14 @@ class FailoverHandler:
|
||||
base_delay: float = 1.0,
|
||||
max_delay: float = 60.0,
|
||||
exponential_base: float = 2.0,
|
||||
):
|
||||
) -> None:
|
||||
self.max_retries = max_retries
|
||||
self.base_delay = base_delay
|
||||
self.max_delay = max_delay
|
||||
self.exponential_base = exponential_base
|
||||
|
||||
self._state = RetryState()
|
||||
self._error_handlers: Dict[ErrorKind, Callable] = {}
|
||||
self._error_handlers: Dict[ErrorKind, Callable[[Exception], bool]] = {}
|
||||
|
||||
def classify_error(self, error: Exception) -> ErrorKind:
|
||||
"""分类错误类型"""
|
||||
@@ -113,7 +115,7 @@ class FailoverHandler:
|
||||
"""处理错误,返回是否应该重试"""
|
||||
return self.should_retry(error)
|
||||
|
||||
def register_handler(self, error_kind: ErrorKind, handler: Callable) -> None:
|
||||
def register_handler(self, error_kind: ErrorKind, handler: Callable[[Exception], bool]) -> None:
|
||||
"""注册特定错误类型的处理器"""
|
||||
self._error_handlers[error_kind] = handler
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Agent 运行器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, List, Dict, Any, AsyncIterator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
|
||||
from .types import (
|
||||
RunParams,
|
||||
@@ -36,7 +39,7 @@ class AgentRunner:
|
||||
tool_registry: ToolRegistry,
|
||||
session_manager: SessionManager,
|
||||
context_manager: Optional[ContextManager] = None,
|
||||
):
|
||||
) -> None:
|
||||
self.llm = llm_client
|
||||
self.tools = tool_registry
|
||||
self.session = session_manager
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
"""LLM 客户端基类"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator, List, Optional, Dict, Any
|
||||
from __future__ import annotations
|
||||
|
||||
from ..agent.types import Message, Usage
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import List, Optional, Dict, Any, Protocol
|
||||
|
||||
from ..agent.types import Message
|
||||
|
||||
|
||||
class TokenCounter(Protocol):
|
||||
"""Token 计数器协议"""
|
||||
def count_messages_tokens(self, messages: List[Message]) -> int: ...
|
||||
def count_tokens(self, text: str) -> int: ...
|
||||
|
||||
|
||||
class LLMClient(ABC):
|
||||
@@ -14,7 +23,7 @@ class LLMClient(ABC):
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
@@ -24,7 +33,7 @@ class LLMClient(ABC):
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""发送聊天请求
|
||||
|
||||
@@ -39,11 +48,11 @@ class LLMClient(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stream_chat(
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""流式聊天请求
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
"""OpenAI 客户端实现"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import AsyncIterator, List, Optional, Dict, Any
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
|
||||
from .client import LLMClient
|
||||
from ..agent.types import Message, MessageRole
|
||||
@@ -18,8 +21,8 @@ class OpenAIClient(LLMClient):
|
||||
model: str = "gpt-4o",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(model, api_key, base_url)
|
||||
self._extra_kwargs = kwargs
|
||||
|
||||
@@ -40,7 +43,7 @@ class OpenAIClient(LLMClient):
|
||||
|
||||
def _convert_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
|
||||
"""转换为 OpenAI 消息格式"""
|
||||
result = []
|
||||
result: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
item: Dict[str, Any] = {"role": msg.role.value}
|
||||
|
||||
@@ -49,7 +52,7 @@ class OpenAIClient(LLMClient):
|
||||
item["content"] = msg.content
|
||||
else:
|
||||
# 多模态内容
|
||||
content = []
|
||||
content: List[Dict[str, Any]] = []
|
||||
for block in msg.content:
|
||||
if block.type == "text" and block.text:
|
||||
content.append({"type": "text", "text": block.text})
|
||||
@@ -97,7 +100,7 @@ class OpenAIClient(LLMClient):
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""发送聊天请求"""
|
||||
# 构建请求参数
|
||||
@@ -120,7 +123,7 @@ class OpenAIClient(LLMClient):
|
||||
message = choice.message
|
||||
|
||||
# 提取工具调用
|
||||
tool_calls = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
if message.tool_calls:
|
||||
tool_calls = [
|
||||
{
|
||||
@@ -142,13 +145,13 @@ class OpenAIClient(LLMClient):
|
||||
},
|
||||
}
|
||||
|
||||
async def stream_chat(
|
||||
async def stream_chat_impl(
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""流式聊天请求"""
|
||||
"""流式聊天请求实现"""
|
||||
params: Dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": self._convert_messages(messages),
|
||||
@@ -211,6 +214,15 @@ class OpenAIClient(LLMClient):
|
||||
"finish_reason": chunk.choices[0].finish_reason,
|
||||
}
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""流式聊天请求"""
|
||||
return self.stream_chat_impl(messages, tools, **kwargs)
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
"""估算 token 数量
|
||||
|
||||
@@ -238,7 +250,7 @@ class AnthropicClient(LLMClient):
|
||||
model: str = "claude-3-opus-20240229",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(model, api_key, base_url)
|
||||
# 实际实现需要 anthropic SDK
|
||||
self._client = None
|
||||
@@ -247,16 +259,16 @@ class AnthropicClient(LLMClient):
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""发送聊天请求"""
|
||||
raise NotImplementedError("Anthropic client not yet implemented")
|
||||
|
||||
async def stream_chat(
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""流式聊天请求"""
|
||||
raise NotImplementedError("Anthropic client not yet implemented")
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
"""会话管理器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import List, Optional, Dict, Any, Union, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..llm.client import TokenCounter
|
||||
|
||||
from ..agent.types import Message, MessageRole
|
||||
|
||||
|
||||
@@ -21,7 +26,7 @@ class SessionManager:
|
||||
session_dir: str = ".sessions",
|
||||
max_history: int = 100,
|
||||
auto_save: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
self.session_dir = Path(session_dir)
|
||||
self.max_history = max_history
|
||||
@@ -149,11 +154,11 @@ class SessionManager:
|
||||
self.messages = []
|
||||
self.metadata = {}
|
||||
|
||||
def get_token_count(self, token_counter) -> int:
|
||||
def get_token_count(self, token_counter: TokenCounter) -> int:
|
||||
"""获取会话的 token 总数
|
||||
|
||||
Args:
|
||||
token_counter: 具有 count_tokens 方法的对象
|
||||
token_counter: 具有 count_messages_tokens 方法的对象
|
||||
"""
|
||||
return token_counter.count_messages_tokens(self.messages)
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""工具基类和注册表"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, List
|
||||
from typing import Any, Dict, Optional, List, Union
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""文件操作工具"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import aiofiles
|
||||
import base64
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, List, Union
|
||||
|
||||
from .base import BaseTool, ToolContext
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Shell 命令执行工具"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, List, Union
|
||||
|
||||
from .base import BaseTool, ToolContext
|
||||
|
||||
|
||||
Reference in New Issue
Block a user