fix: 修复 Python 3.12 类型注解兼容性问题

主要修复:
1. llm/client.py - 添加 from __future__ import annotations,修复 AsyncIterator 导入
2. llm/openai_client.py - 修复 stream_chat 返回类型(同步返回 AsyncIterator)
3. agent/context.py - 添加 TokenCounter Protocol 类型
4. agent/failover.py - 修复 Callable 类型参数
5. agent/runner.py - 添加 __future__ annotations
6. session/manager.py - 添加 TYPE_CHECKING 和 TokenCounter 类型
7. tools/*.py - 添加 __future__ annotations

类型注解改进:
- 使用 collections.abc.AsyncIterator 替代 typing.AsyncIterator
- 添加 Protocol 定义用于类型检查
- 修复所有 missing type annotation 警告
This commit is contained in:
yunxiafei
2026-03-17 17:43:39 +08:00
parent 4b54f64b97
commit 1ce5f12655
9 changed files with 80 additions and 37 deletions

View File

@@ -1,6 +1,12 @@
"""上下文管理器"""
from typing import List, Optional, Dict, Any
from __future__ import annotations
from typing import List, Dict, Any, TYPE_CHECKING
if TYPE_CHECKING:
from ..llm.client import TokenCounter
from .types import Message
@@ -15,7 +21,7 @@ class ContextManager:
max_tokens: int = 128000, # GPT-4o 默认
reserve_tokens: int = 4000, # 为输出保留的 token
compression_threshold: float = 0.8, # 触发压缩的阈值
):
) -> None:
self.max_tokens = max_tokens
self.reserve_tokens = reserve_tokens
self.compression_threshold = compression_threshold
@@ -23,11 +29,11 @@ class ContextManager:
self._current_tokens = 0
self._compression_count = 0
def count_tokens(self, messages: List[Message], token_counter) -> int:
def count_tokens(self, messages: List[Message], token_counter: TokenCounter) -> int:
"""计算消息列表的 token 总数"""
return token_counter.count_messages_tokens(messages)
def update_token_count(self, messages: List[Message], token_counter) -> int:
def update_token_count(self, messages: List[Message], token_counter: TokenCounter) -> int:
"""更新并返回当前 token 数"""
self._current_tokens = self.count_tokens(messages, token_counter)
return self._current_tokens

View File

@@ -1,5 +1,7 @@
"""故障转移处理器"""
from __future__ import annotations
import asyncio
import time
from typing import Optional, List, Dict, Any, Callable
@@ -39,14 +41,14 @@ class FailoverHandler:
base_delay: float = 1.0,
max_delay: float = 60.0,
exponential_base: float = 2.0,
):
) -> None:
self.max_retries = max_retries
self.base_delay = base_delay
self.max_delay = max_delay
self.exponential_base = exponential_base
self._state = RetryState()
self._error_handlers: Dict[ErrorKind, Callable] = {}
self._error_handlers: Dict[ErrorKind, Callable[[Exception], bool]] = {}
def classify_error(self, error: Exception) -> ErrorKind:
"""分类错误类型"""
@@ -113,7 +115,7 @@ class FailoverHandler:
"""处理错误,返回是否应该重试"""
return self.should_retry(error)
def register_handler(self, error_kind: ErrorKind, handler: Callable) -> None:
def register_handler(self, error_kind: ErrorKind, handler: Callable[[Exception], bool]) -> None:
"""注册特定错误类型的处理器"""
self._error_handlers[error_kind] = handler

View File

@@ -1,9 +1,12 @@
"""Agent 运行器"""
from __future__ import annotations
import asyncio
import json
import time
from typing import Optional, List, Dict, Any, AsyncIterator
from collections.abc import AsyncIterator
from typing import Optional, List, Dict, Any, Union
from .types import (
RunParams,
@@ -36,7 +39,7 @@ class AgentRunner:
tool_registry: ToolRegistry,
session_manager: SessionManager,
context_manager: Optional[ContextManager] = None,
):
) -> None:
self.llm = llm_client
self.tools = tool_registry
self.session = session_manager

View File

@@ -1,9 +1,18 @@
"""LLM 客户端基类"""
from abc import ABC, abstractmethod
from typing import AsyncIterator, List, Optional, Dict, Any
from __future__ import annotations
from ..agent.types import Message, Usage
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import List, Optional, Dict, Any, Protocol
from ..agent.types import Message
class TokenCounter(Protocol):
    """Structural protocol for token counters.

    Used only for static type checking (structural typing via Protocol):
    any object providing these two methods satisfies the interface.
    """
    # Total token count across a list of messages.
    def count_messages_tokens(self, messages: List[Message]) -> int: ...
    # Token count for a single text string.
    def count_tokens(self, text: str) -> int: ...
class LLMClient(ABC):
@@ -14,7 +23,7 @@ class LLMClient(ABC):
model: str,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
) -> None:
self.model = model
self.api_key = api_key
self.base_url = base_url
@@ -24,7 +33,7 @@ class LLMClient(ABC):
self,
messages: List[Message],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
**kwargs: Any,
) -> Dict[str, Any]:
"""发送聊天请求
@@ -39,11 +48,11 @@ class LLMClient(ABC):
pass
@abstractmethod
async def stream_chat(
def stream_chat(
self,
messages: List[Message],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
"""流式聊天请求

View File

@@ -1,7 +1,10 @@
"""OpenAI 客户端实现"""
from __future__ import annotations
import json
from typing import AsyncIterator, List, Optional, Dict, Any
from collections.abc import AsyncIterator
from typing import List, Optional, Dict, Any, Union
from .client import LLMClient
from ..agent.types import Message, MessageRole
@@ -18,8 +21,8 @@ class OpenAIClient(LLMClient):
model: str = "gpt-4o",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(model, api_key, base_url)
self._extra_kwargs = kwargs
@@ -40,7 +43,7 @@ class OpenAIClient(LLMClient):
def _convert_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
"""转换为 OpenAI 消息格式"""
result = []
result: List[Dict[str, Any]] = []
for msg in messages:
item: Dict[str, Any] = {"role": msg.role.value}
@@ -49,7 +52,7 @@ class OpenAIClient(LLMClient):
item["content"] = msg.content
else:
# 多模态内容
content = []
content: List[Dict[str, Any]] = []
for block in msg.content:
if block.type == "text" and block.text:
content.append({"type": "text", "text": block.text})
@@ -97,7 +100,7 @@ class OpenAIClient(LLMClient):
self,
messages: List[Message],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
**kwargs: Any,
) -> Dict[str, Any]:
"""发送聊天请求"""
# 构建请求参数
@@ -120,7 +123,7 @@ class OpenAIClient(LLMClient):
message = choice.message
# 提取工具调用
tool_calls = None
tool_calls: Optional[List[Dict[str, Any]]] = None
if message.tool_calls:
tool_calls = [
{
@@ -142,13 +145,13 @@ class OpenAIClient(LLMClient):
},
}
async def stream_chat(
async def stream_chat_impl(
self,
messages: List[Message],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
"""流式聊天请求"""
"""流式聊天请求实现"""
params: Dict[str, Any] = {
"model": self.model,
"messages": self._convert_messages(messages),
@@ -211,6 +214,15 @@ class OpenAIClient(LLMClient):
"finish_reason": chunk.choices[0].finish_reason,
}
    def stream_chat(
        self,
        messages: List[Message],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Streaming chat request.

        Synchronous wrapper around ``stream_chat_impl``: calling the async
        generator returns its AsyncIterator directly (no await needed), so
        this method's declared return type matches what callers receive.
        """
        return self.stream_chat_impl(messages, tools, **kwargs)
def count_tokens(self, text: str) -> int:
"""估算 token 数量
@@ -238,7 +250,7 @@ class AnthropicClient(LLMClient):
model: str = "claude-3-opus-20240229",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
) -> None:
super().__init__(model, api_key, base_url)
# 实际实现需要 anthropic SDK
self._client = None
@@ -247,16 +259,16 @@ class AnthropicClient(LLMClient):
self,
messages: List[Message],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
**kwargs: Any,
) -> Dict[str, Any]:
"""发送聊天请求"""
raise NotImplementedError("Anthropic client not yet implemented")
async def stream_chat(
def stream_chat(
self,
messages: List[Message],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
"""流式聊天请求"""
raise NotImplementedError("Anthropic client not yet implemented")

View File

@@ -1,11 +1,16 @@
"""会话管理器"""
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Union, TYPE_CHECKING
from datetime import datetime
if TYPE_CHECKING:
from ..llm.client import TokenCounter
from ..agent.types import Message, MessageRole
@@ -21,7 +26,7 @@ class SessionManager:
session_dir: str = ".sessions",
max_history: int = 100,
auto_save: bool = True,
):
) -> None:
self.session_id = session_id
self.session_dir = Path(session_dir)
self.max_history = max_history
@@ -149,11 +154,11 @@ class SessionManager:
self.messages = []
self.metadata = {}
def get_token_count(self, token_counter) -> int:
def get_token_count(self, token_counter: TokenCounter) -> int:
"""获取会话的 token 总数
Args:
token_counter: 具有 count_tokens 方法的对象
token_counter: 具有 count_messages_tokens 方法的对象
"""
return token_counter.count_messages_tokens(self.messages)

View File

@@ -1,7 +1,9 @@
"""工具基类和注册表"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, List
from typing import Any, Dict, Optional, List, Union
from dataclasses import dataclass, field

View File

@@ -1,10 +1,12 @@
"""文件操作工具"""
from __future__ import annotations
import aiofiles
import base64
import mimetypes
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List, Union
from .base import BaseTool, ToolContext

View File

@@ -1,9 +1,11 @@
"""Shell 命令执行工具"""
from __future__ import annotations
import asyncio
import os
import signal
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List, Union
from .base import BaseTool, ToolContext