feat: Phase 2 - 服务层实现 + Bug修复

This commit is contained in:
2026-03-14 20:08:20 +08:00
parent 0a117a444c
commit 1836d118fe
9 changed files with 958 additions and 5 deletions

View File

@@ -11,6 +11,55 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
---
## [0.4.0] - 2026-03-14
### Added
#### 🔧 服务层实现 (Phase 2)
- **AgentScheduler** - 5 种调度策略
- RoundRobinScheduler - 轮询调度
- WeightedRoundRobinScheduler - 加权轮询调度(默认)
- LeastConnectionsScheduler - 最少连接调度
- LeastResponseTimeScheduler - 最快响应调度
- CapabilityMatchScheduler - 能力匹配调度
- **MessageQueue** - 消息队列管理
- 消息入队/出队
- ACK 确认机制
- 重试队列管理
- Redis 存储支持
- **SessionService** - 会话服务
- 创建/关闭会话
- Agent 分配
- 多 Agent 协作支持
- 会话活跃度追踪
- **MessageService** - 消息服务
- 创建/获取消息
- 消息确认
- 已读标记
- 重试机制
- 消息统计
- **AgentService** - Agent 服务
- Agent 注册/注销
- 状态更新
- 心跳管理
- 配置更新
- 离线检测
### Fixed
- **gateways.py** - 修复蓝图名称错误 (`gateway_bp``gateways_bp`)
### Changed
- **Dockerfile** - 添加生产环境 Docker 镜像配置
---
## [0.3.0] - 2026-03-14
### Added

29
Dockerfile Normal file
View File

@@ -0,0 +1,29 @@
# PIT Router Dockerfile
FROM python:3.12-slim
# 设置工作目录
WORKDIR /app
# 安装系统依赖
RUN apt-get update && apt-get install -y \
gcc \
postgresql-client \
curl \
&& rm -rf /var/lib/apt/lists/*
# 复制依赖文件
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 复制应用代码
COPY . .
# 暴露端口
EXPOSE 9000
# 健康检查
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
CMD curl -f http://localhost:9000/health || exit 1
# 启动命令
CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:9000", "-k", "gevent", "--worker-connections", "1000", "run:app"]

View File

@@ -10,7 +10,7 @@ from app.models import db, Gateway
gateways_bp = Blueprint('gateways', __name__)
@gateway_bp.route('/', methods=['GET'])
@gateways_bp.route('/', methods=['GET'])
@jwt_required()
def get_gateways():
"""获取 Gateway 列表"""
@@ -18,7 +18,7 @@ def get_gateways():
return jsonify({'gateways': [g.to_dict() for g in gateways]}), 200
@gateway_bp.route('/', methods=['POST'])
@gateways_bp.route('/', methods=['POST'])
@jwt_required()
def register_gateway():
"""注册 Gateway"""
@@ -55,7 +55,7 @@ def register_gateway():
return jsonify({'gateway': gateway.to_dict()}), 201
@gateway_bp.route('/<gateway_id>', methods=['DELETE'])
@gateways_bp.route('/<gateway_id>', methods=['DELETE'])
@jwt_required()
def delete_gateway(gateway_id):
"""注销 Gateway"""
@@ -69,7 +69,7 @@ def delete_gateway(gateway_id):
return jsonify({'message': 'Gateway deleted'}), 200
@gateway_bp.route('/<gateway_id>/status', methods=['GET'])
@gateways_bp.route('/<gateway_id>/status', methods=['GET'])
@jwt_required()
def get_gateway_status(gateway_id):
"""获取 Gateway 状态"""
@@ -85,7 +85,7 @@ def get_gateway_status(gateway_id):
}), 200
@gateway_bp.route('/<gateway_id>/heartbeat', methods=['POST'])
@gateways_bp.route('/<gateway_id>/heartbeat', methods=['POST'])
def gateway_heartbeat(gateway_id):
"""Gateway 心跳上报"""
gateway = Gateway.query.get(gateway_id)

View File

@@ -1,3 +1,16 @@
"""
服务模块
"""
from .scheduler import AgentScheduler
from .message_queue import MessageQueue
from .session_service import SessionService
from .message_service import MessageService
from .agent_service import AgentService
__all__ = [
'AgentScheduler',
'MessageQueue',
'SessionService',
'MessageService',
'AgentService',
]

View File

@@ -0,0 +1,198 @@
"""
Agent 服务
处理 Agent 相关的业务逻辑
"""
from typing import Optional, List
from datetime import datetime
from app.models import db, Agent, Gateway
class AgentService:
"""Agent 服务"""
def register_agent(
self,
name: str,
gateway_id: str = None,
model: str = None,
capabilities: list = None,
priority: int = 5,
weight: int = 10,
connection_limit: int = 5
) -> Agent:
"""注册新 Agent"""
agent = Agent(
name=name,
display_name=name,
gateway_id=gateway_id,
model=model,
capabilities=capabilities or [],
priority=priority,
weight=weight,
connection_limit=connection_limit,
status='offline'
)
db.session.add(agent)
# 更新 Gateway 的 Agent 计数
if gateway_id:
gateway = Gateway.query.get(gateway_id)
if gateway:
gateway.agent_count += 1
db.session.commit()
return agent
def get_agent(self, agent_id: str) -> Optional[Agent]:
"""获取 Agent"""
return Agent.query.get(agent_id)
def get_agents(
self,
status: str = None,
gateway_id: str = None
) -> List[Agent]:
"""获取 Agent 列表"""
query = Agent.query
if status:
query = query.filter_by(status=status)
if gateway_id:
query = query.filter_by(gateway_id=gateway_id)
return query.all()
def get_available_agents(self) -> List[Agent]:
"""获取可用 Agent 列表"""
return Agent.query.filter_by(status='online').filter(
Agent.current_sessions < Agent.connection_limit
).all()
def update_agent_status(
self,
agent_id: str,
status: str,
socket_id: str = None
) -> Optional[Agent]:
"""更新 Agent 状态"""
agent = Agent.query.get(agent_id)
if not agent:
return None
agent.status = status
if socket_id:
agent.socket_id = socket_id
if status == 'online':
agent.last_heartbeat = datetime.utcnow()
db.session.commit()
return agent
def heartbeat(self, agent_id: str) -> Optional[Agent]:
"""Agent 心跳"""
agent = Agent.query.get(agent_id)
if not agent:
return None
agent.last_heartbeat = datetime.utcnow()
agent.status = 'online'
db.session.commit()
return agent
def update_agent_config(
self,
agent_id: str,
name: str = None,
display_name: str = None,
model: str = None,
capabilities: list = None,
priority: int = None,
weight: int = None,
connection_limit: int = None
) -> Optional[Agent]:
"""更新 Agent 配置"""
agent = Agent.query.get(agent_id)
if not agent:
return None
if name is not None:
agent.name = name
if display_name is not None:
agent.display_name = display_name
if model is not None:
agent.model = model
if capabilities is not None:
agent.capabilities = capabilities
if priority is not None:
agent.priority = priority
if weight is not None:
agent.weight = weight
if connection_limit is not None:
agent.connection_limit = connection_limit
db.session.commit()
return agent
def delete_agent(self, agent_id: str) -> bool:
"""删除 Agent"""
agent = Agent.query.get(agent_id)
if not agent:
return False
# 更新 Gateway 的 Agent 计数
if agent.gateway_id:
gateway = Gateway.query.get(agent.gateway_id)
if gateway and gateway.agent_count > 0:
gateway.agent_count -= 1
db.session.delete(agent)
db.session.commit()
return True
def get_agent_stats(self) -> dict:
"""获取 Agent 统计"""
total = Agent.query.count()
online = Agent.query.filter_by(status='online').count()
offline = Agent.query.filter_by(status='offline').count()
busy = Agent.query.filter_by(status='busy').count()
# 总连接数
total_sessions = db.session.query(
db.func.sum(Agent.current_sessions)
).scalar() or 0
return {
'total': total,
'online': online,
'offline': offline,
'busy': busy,
'total_sessions': total_sessions,
}
def check_offline_agents(self, timeout: int = 120) -> List[Agent]:
"""检查超时下线的 Agent"""
from datetime import timedelta
threshold = datetime.utcnow() - timedelta(seconds=timeout)
offline_agents = Agent.query.filter(
Agent.status == 'online',
Agent.last_heartbeat < threshold
).all()
for agent in offline_agents:
agent.status = 'offline'
db.session.commit()
return offline_agents

View File

@@ -0,0 +1,155 @@
"""
消息队列服务
实现消息缓存、重试机制
"""
from datetime import datetime
from typing import Optional, List
from app.models import db, Message
from app.extensions import redis_client
import json
class MessageQueue:
"""消息队列管理"""
# Redis 键前缀
PENDING_QUEUE = "pit:messages:pending"
RETRY_QUEUE = "pit:messages:retry"
MESSAGE_PREFIX = "pit:message:"
# 配置
MAX_RETRY = 3
RETRY_DELAY = 5 # 秒
@classmethod
def enqueue(cls, message: Message) -> bool:
"""消息入队"""
try:
# 存储消息详情
message_key = f"{cls.MESSAGE_PREFIX}{message.id}"
redis_client.hset(message_key, mapping={
'id': message.id,
'session_id': message.session_id,
'sender_type': message.sender_type,
'sender_id': message.sender_id,
'content': message.content or '',
'status': message.status,
'retry_count': str(message.retry_count),
'created_at': message.created_at.isoformat() if message.created_at else '',
})
# 加入待处理队列
redis_client.rpush(cls.PENDING_QUEUE, message.id)
return True
except Exception as e:
print(f"Failed to enqueue message: {e}")
return False
@classmethod
def dequeue(cls) -> Optional[dict]:
"""消息出队"""
try:
message_id = redis_client.lpop(cls.PENDING_QUEUE)
if not message_id:
return None
message_key = f"{cls.MESSAGE_PREFIX}{message_id}"
message_data = redis_client.hgetall(message_key)
if not message_data:
return None
return {
'id': message_data.get('id'),
'session_id': message_data.get('session_id'),
'sender_type': message_data.get('sender_type'),
'sender_id': message_data.get('sender_id'),
'content': message_data.get('content'),
'status': message_data.get('status'),
'retry_count': int(message_data.get('retry_count', 0)),
}
except Exception as e:
print(f"Failed to dequeue message: {e}")
return None
@classmethod
def ack(cls, message_id: str) -> bool:
"""消息确认"""
try:
message_key = f"{cls.MESSAGE_PREFIX}{message_id}"
redis_client.delete(message_key)
return True
except Exception as e:
print(f"Failed to ack message: {e}")
return False
@classmethod
def retry(cls, message_id: str) -> bool:
"""消息重试"""
try:
message_key = f"{cls.MESSAGE_PREFIX}{message_id}"
retry_count = redis_client.hget(message_key, 'retry_count')
if retry_count is None:
return False
retry_count = int(retry_count)
if retry_count >= cls.MAX_RETRY:
# 超过最大重试次数,标记为失败
redis_client.hset(message_key, 'status', 'failed')
return False
# 增加重试计数
redis_client.hset(message_key, 'retry_count', str(retry_count + 1))
# 加入重试队列
redis_client.rpush(cls.RETRY_QUEUE, message_id)
return True
except Exception as e:
print(f"Failed to retry message: {e}")
return False
@classmethod
def get_pending_count(cls) -> int:
"""获取待处理消息数量"""
try:
return redis_client.llen(cls.PENDING_QUEUE)
except:
return 0
@classmethod
def get_retry_count(cls) -> int:
"""获取重试消息数量"""
try:
return redis_client.llen(cls.RETRY_QUEUE)
except:
return 0
@classmethod
def process_retry_queue(cls) -> List[dict]:
"""处理重试队列"""
messages = []
try:
while True:
message_id = redis_client.lpop(cls.RETRY_QUEUE)
if not message_id:
break
message_key = f"{cls.MESSAGE_PREFIX}{message_id}"
message_data = redis_client.hgetall(message_key)
if message_data:
messages.append({
'id': message_data.get('id'),
'session_id': message_data.get('session_id'),
'content': message_data.get('content'),
'retry_count': int(message_data.get('retry_count', 0)),
})
except Exception as e:
print(f"Failed to process retry queue: {e}")
return messages

View File

@@ -0,0 +1,167 @@
"""
消息服务
处理消息相关的业务逻辑
"""
from typing import Optional, List
from datetime import datetime
from app.models import db, Message, Session
from app.services.message_queue import MessageQueue
import uuid
class MessageService:
"""消息服务"""
def __init__(self):
self.queue = MessageQueue()
def create_message(
self,
session_id: str,
sender_type: str,
sender_id: str,
content: str,
message_type: str = 'text',
content_type: str = 'markdown',
reply_to: str = None
) -> Message:
"""创建消息"""
message = Message(
id=str(uuid.uuid4()),
session_id=session_id,
sender_type=sender_type,
sender_id=sender_id,
message_type=message_type,
content=content,
content_type=content_type,
reply_to=reply_to,
status='sent',
ack_status='pending'
)
db.session.add(message)
# 更新会话统计
session = Session.query.get(session_id)
if session:
session.message_count += 1
session.last_active_at = datetime.utcnow()
db.session.commit()
# 入队等待确认
self.queue.enqueue(message)
return message
def get_message(self, message_id: str) -> Optional[Message]:
"""获取消息"""
return Message.query.get(message_id)
def get_session_messages(
self,
session_id: str,
limit: int = 50,
offset: int = 0
) -> List[Message]:
"""获取会话消息列表"""
return Message.query.filter_by(session_id=session_id)\
.order_by(Message.created_at.asc())\
.offset(offset)\
.limit(limit)\
.all()
def acknowledge_message(self, message_id: str) -> Optional[Message]:
"""确认消息已送达"""
message = Message.query.get(message_id)
if not message:
return None
message.ack_status = 'acknowledged'
message.status = 'delivered'
message.delivered_at = datetime.utcnow()
db.session.commit()
# 从队列中移除
self.queue.ack(message_id)
return message
def mark_as_read(self, message_id: str) -> Optional[Message]:
"""标记消息已读"""
message = Message.query.get(message_id)
if not message:
return None
message.status = 'read'
db.session.commit()
return message
def mark_session_read(self, session_id: str, user_id: str) -> int:
"""标记会话所有消息已读"""
# 更新数据库
count = Message.query.filter_by(
session_id=session_id,
status='delivered'
).update({'status': 'read'})
# 更新会话未读数
session = Session.query.get(session_id)
if session:
session.unread_count = 0
db.session.commit()
return count
def retry_message(self, message_id: str) -> Optional[Message]:
"""重试发送消息"""
message = Message.query.get(message_id)
if not message:
return None
# 检查重试次数
if message.retry_count >= 3:
message.status = 'failed'
db.session.commit()
return None
# 增加重试计数
message.retry_count += 1
message.status = 'sent'
db.session.commit()
# 加入重试队列
self.queue.retry(message_id)
return message
def get_pending_messages(self) -> List[dict]:
"""获取待处理消息"""
return self.queue.process_retry_queue()
def get_message_stats(self, session_id: str = None) -> dict:
"""获取消息统计"""
query = Message.query
if session_id:
query = query.filter_by(session_id=session_id)
total = query.count()
sent = query.filter_by(status='sent').count()
delivered = query.filter_by(status='delivered').count()
read = query.filter_by(status='read').count()
failed = query.filter_by(status='failed').count()
return {
'total': total,
'sent': sent,
'delivered': delivered,
'read': read,
'failed': failed,
'pending': self.queue.get_pending_count(),
'retry': self.queue.get_retry_count(),
}

174
app/services/scheduler.py Normal file
View File

@@ -0,0 +1,174 @@
"""
Agent 调度器服务
实现多种调度策略
"""
from typing import List, Optional
from app.models import Agent
class BaseScheduler:
"""调度器基类"""
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
"""选择 Agent子类实现"""
raise NotImplementedError
class RoundRobinScheduler(BaseScheduler):
"""轮询调度器"""
def __init__(self):
self._index = 0
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
if not agents:
return None
# 只选择在线的 Agent
online_agents = [a for a in agents if a.status == 'online']
if not online_agents:
return None
# 轮询选择
agent = online_agents[self._index % len(online_agents)]
self._index += 1
return agent
class WeightedRoundRobinScheduler(BaseScheduler):
"""加权轮询调度器(默认)"""
def __init__(self):
self._index = 0
self._weights = {}
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
if not agents:
return None
# 只选择在线的 Agent
online_agents = [a for a in agents if a.status == 'online']
if not online_agents:
return None
# 计算总权重
total_weight = sum(a.weight for a in online_agents)
if total_weight == 0:
return online_agents[0] if online_agents else None
# 加权轮询
current_weight = self._index % total_weight
cumulative = 0
for agent in online_agents:
cumulative += agent.weight
if current_weight < cumulative:
self._index += 1
return agent
return online_agents[-1]
class LeastConnectionsScheduler(BaseScheduler):
"""最少连接调度器"""
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
if not agents:
return None
# 只选择在线且未达连接上限的 Agent
available_agents = [
a for a in agents
if a.status == 'online' and a.current_sessions < a.connection_limit
]
if not available_agents:
return None
# 选择连接数最少的
return min(available_agents, key=lambda a: a.current_sessions)
class LeastResponseTimeScheduler(BaseScheduler):
"""最快响应调度器(基于最后心跳时间)"""
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
if not agents:
return None
# 只选择在线的 Agent
online_agents = [a for a in agents if a.status == 'online']
if not online_agents:
return None
# 选择最近有心跳的(响应快的)
from datetime import datetime
now = datetime.utcnow()
def response_score(agent):
if not agent.last_heartbeat:
return float('inf')
# 最近心跳 = 分数低 = 优先选择
return (now - agent.last_heartbeat).total_seconds()
return min(online_agents, key=response_score)
class CapabilityMatchScheduler(BaseScheduler):
"""能力匹配调度器"""
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
if not agents or not context:
return None
required_capabilities = context.get('capabilities', [])
if not required_capabilities:
# 没有特殊要求,使用加权轮询
return WeightedRoundRobinScheduler().select_agent(agents, context)
# 只选择在线且有对应能力的 Agent
matching_agents = []
for agent in agents:
if agent.status != 'online':
continue
agent_caps = agent.capabilities or []
# 检查是否具备所有必需能力
if all(cap in agent_caps for cap in required_capabilities):
matching_agents.append(agent)
if not matching_agents:
# 没有匹配的,回退到加权轮询
return WeightedRoundRobinScheduler().select_agent(agents, context)
# 在匹配的 Agent 中使用加权轮询
return WeightedRoundRobinScheduler().select_agent(matching_agents, context)
class AgentScheduler:
"""Agent 调度器工厂"""
STRATEGIES = {
'round_robin': RoundRobinScheduler,
'weighted_round_robin': WeightedRoundRobinScheduler,
'least_connections': LeastConnectionsScheduler,
'least_response_time': LeastResponseTimeScheduler,
'capability_match': CapabilityMatchScheduler,
}
def __init__(self, strategy: str = 'weighted_round_robin'):
self._strategy_name = strategy
self._scheduler = self.STRATEGIES.get(strategy, WeightedRoundRobinScheduler)()
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
"""选择 Agent"""
return self._scheduler.select_agent(agents, context)
def get_strategy(self) -> str:
"""获取当前策略"""
return self._strategy_name
@classmethod
def get_available_strategies(cls) -> list:
"""获取可用策略列表"""
return list(cls.STRATEGIES.keys())

View File

@@ -0,0 +1,168 @@
"""
会话服务
处理会话相关的业务逻辑
"""
from typing import Optional, List
from datetime import datetime
from app.models import db, Session, Agent
from app.services.scheduler import AgentScheduler
class SessionService:
"""会话服务"""
def __init__(self, scheduler_strategy: str = 'weighted_round_robin'):
self.scheduler = AgentScheduler(scheduler_strategy)
def create_session(
self,
user_id: str,
title: str = None,
agent_id: str = None,
channel_type: str = 'web',
context: dict = None
) -> Session:
"""创建新会话"""
# 如果没有指定 Agent调度一个
if not agent_id:
from app.models import Agent
agents = Agent.query.all()
selected_agent = self.scheduler.select_agent(agents, context)
agent_id = selected_agent.id if selected_agent else None
# 创建会话
session = Session(
user_id=user_id,
primary_agent_id=agent_id,
title=title or f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
channel_type=channel_type,
status='active'
)
db.session.add(session)
# 更新 Agent 会话数
if agent_id:
agent = Agent.query.get(agent_id)
if agent:
agent.current_sessions += 1
db.session.commit()
return session
def close_session(self, session_id: str, user_id: str = None) -> Optional[Session]:
"""关闭会话"""
session = Session.query.get(session_id)
if not session:
return None
# 验证权限
if user_id and session.user_id != user_id:
return None
session.status = 'closed'
session.updated_at = datetime.utcnow()
# 更新 Agent 会话数
if session.primary_agent_id:
agent = Agent.query.get(session.primary_agent_id)
if agent and agent.current_sessions > 0:
agent.current_sessions -= 1
db.session.commit()
return session
def get_user_sessions(self, user_id: str, status: str = None) -> List[Session]:
"""获取用户会话列表"""
query = Session.query.filter_by(user_id=user_id)
if status:
query = query.filter_by(status=status)
return query.order_by(Session.created_at.desc()).all()
def get_session(self, session_id: str, user_id: str = None) -> Optional[Session]:
"""获取会话详情"""
session = Session.query.get(session_id)
if not session:
return None
if user_id and session.user_id != user_id:
return None
return session
def update_session_activity(self, session_id: str) -> bool:
"""更新会话活跃时间"""
session = Session.query.get(session_id)
if not session:
return False
session.last_active_at = datetime.utcnow()
session.message_count += 1
db.session.commit()
return True
def assign_agent(self, session_id: str, agent_id: str) -> Optional[Session]:
"""分配 Agent 到会话"""
session = Session.query.get(session_id)
if not session:
return None
# 验证 Agent 是否在线
agent = Agent.query.get(agent_id)
if not agent or agent.status != 'online':
return None
# 更新会话
old_agent_id = session.primary_agent_id
session.primary_agent_id = agent_id
# 更新新旧 Agent 的会话数
if old_agent_id:
old_agent = Agent.query.get(old_agent_id)
if old_agent and old_agent.current_sessions > 0:
old_agent.current_sessions -= 1
agent.current_sessions += 1
db.session.commit()
return session
def add_participant(self, session_id: str, agent_id: str) -> Optional[Session]:
"""添加参与 Agent"""
session = Session.query.get(session_id)
if not session:
return None
# 获取当前参与列表
participants = session.participating_agent_ids or []
if agent_id not in participants:
participants.append(agent_id)
session.participating_agent_ids = participants
db.session.commit()
return session
def remove_participant(self, session_id: str, agent_id: str) -> Optional[Session]:
"""移除参与 Agent"""
session = Session.query.get(session_id)
if not session:
return None
participants = session.participating_agent_ids or []
if agent_id in participants:
participants.remove(agent_id)
session.participating_agent_ids = participants
db.session.commit()
return session