feat: Phase 2 - 服务层实现 + Bug修复
This commit is contained in:
49
CHANGELOG.md
49
CHANGELOG.md
@@ -11,6 +11,55 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## [0.4.0] - 2026-03-14
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
#### 🔧 服务层实现 (Phase 2)
|
||||||
|
|
||||||
|
- **AgentScheduler** - 5 种调度策略
|
||||||
|
- RoundRobinScheduler - 轮询调度
|
||||||
|
- WeightedRoundRobinScheduler - 加权轮询调度(默认)
|
||||||
|
- LeastConnectionsScheduler - 最少连接调度
|
||||||
|
- LeastResponseTimeScheduler - 最快响应调度
|
||||||
|
- CapabilityMatchScheduler - 能力匹配调度
|
||||||
|
|
||||||
|
- **MessageQueue** - 消息队列管理
|
||||||
|
- 消息入队/出队
|
||||||
|
- ACK 确认机制
|
||||||
|
- 重试队列管理
|
||||||
|
- Redis 存储支持
|
||||||
|
|
||||||
|
- **SessionService** - 会话服务
|
||||||
|
- 创建/关闭会话
|
||||||
|
- Agent 分配
|
||||||
|
- 多 Agent 协作支持
|
||||||
|
- 会话活跃度追踪
|
||||||
|
|
||||||
|
- **MessageService** - 消息服务
|
||||||
|
- 创建/获取消息
|
||||||
|
- 消息确认
|
||||||
|
- 已读标记
|
||||||
|
- 重试机制
|
||||||
|
- 消息统计
|
||||||
|
|
||||||
|
- **AgentService** - Agent 服务
|
||||||
|
- Agent 注册/注销
|
||||||
|
- 状态更新
|
||||||
|
- 心跳管理
|
||||||
|
- 配置更新
|
||||||
|
- 离线检测
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- **gateways.py** - 修复蓝图名称错误 (`gateway_bp` → `gateways_bp`)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- **Dockerfile** - 添加生产环境 Docker 镜像配置
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## [0.3.0] - 2026-03-14
|
## [0.3.0] - 2026-03-14
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
29
Dockerfile
Normal file
29
Dockerfile
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# PIT Router Dockerfile
|
||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
# 设置工作目录
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 安装系统依赖
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
gcc \
|
||||||
|
postgresql-client \
|
||||||
|
curl \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# 复制依赖文件
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# 复制应用代码
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# 暴露端口
|
||||||
|
EXPOSE 9000
|
||||||
|
|
||||||
|
# 健康检查
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||||
|
CMD curl -f http://localhost:9000/health || exit 1
|
||||||
|
|
||||||
|
# 启动命令
|
||||||
|
CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:9000", "-k", "gevent", "--worker-connections", "1000", "run:app"]
|
||||||
@@ -10,7 +10,7 @@ from app.models import db, Gateway
|
|||||||
gateways_bp = Blueprint('gateways', __name__)
|
gateways_bp = Blueprint('gateways', __name__)
|
||||||
|
|
||||||
|
|
||||||
@gateway_bp.route('/', methods=['GET'])
|
@gateways_bp.route('/', methods=['GET'])
|
||||||
@jwt_required()
|
@jwt_required()
|
||||||
def get_gateways():
|
def get_gateways():
|
||||||
"""获取 Gateway 列表"""
|
"""获取 Gateway 列表"""
|
||||||
@@ -18,7 +18,7 @@ def get_gateways():
|
|||||||
return jsonify({'gateways': [g.to_dict() for g in gateways]}), 200
|
return jsonify({'gateways': [g.to_dict() for g in gateways]}), 200
|
||||||
|
|
||||||
|
|
||||||
@gateway_bp.route('/', methods=['POST'])
|
@gateways_bp.route('/', methods=['POST'])
|
||||||
@jwt_required()
|
@jwt_required()
|
||||||
def register_gateway():
|
def register_gateway():
|
||||||
"""注册 Gateway"""
|
"""注册 Gateway"""
|
||||||
@@ -55,7 +55,7 @@ def register_gateway():
|
|||||||
return jsonify({'gateway': gateway.to_dict()}), 201
|
return jsonify({'gateway': gateway.to_dict()}), 201
|
||||||
|
|
||||||
|
|
||||||
@gateway_bp.route('/<gateway_id>', methods=['DELETE'])
|
@gateways_bp.route('/<gateway_id>', methods=['DELETE'])
|
||||||
@jwt_required()
|
@jwt_required()
|
||||||
def delete_gateway(gateway_id):
|
def delete_gateway(gateway_id):
|
||||||
"""注销 Gateway"""
|
"""注销 Gateway"""
|
||||||
@@ -69,7 +69,7 @@ def delete_gateway(gateway_id):
|
|||||||
return jsonify({'message': 'Gateway deleted'}), 200
|
return jsonify({'message': 'Gateway deleted'}), 200
|
||||||
|
|
||||||
|
|
||||||
@gateway_bp.route('/<gateway_id>/status', methods=['GET'])
|
@gateways_bp.route('/<gateway_id>/status', methods=['GET'])
|
||||||
@jwt_required()
|
@jwt_required()
|
||||||
def get_gateway_status(gateway_id):
|
def get_gateway_status(gateway_id):
|
||||||
"""获取 Gateway 状态"""
|
"""获取 Gateway 状态"""
|
||||||
@@ -85,7 +85,7 @@ def get_gateway_status(gateway_id):
|
|||||||
}), 200
|
}), 200
|
||||||
|
|
||||||
|
|
||||||
@gateway_bp.route('/<gateway_id>/heartbeat', methods=['POST'])
|
@gateways_bp.route('/<gateway_id>/heartbeat', methods=['POST'])
|
||||||
def gateway_heartbeat(gateway_id):
|
def gateway_heartbeat(gateway_id):
|
||||||
"""Gateway 心跳上报"""
|
"""Gateway 心跳上报"""
|
||||||
gateway = Gateway.query.get(gateway_id)
|
gateway = Gateway.query.get(gateway_id)
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
"""
|
"""
|
||||||
服务模块
|
服务模块
|
||||||
"""
|
"""
|
||||||
|
from .scheduler import AgentScheduler
|
||||||
|
from .message_queue import MessageQueue
|
||||||
|
from .session_service import SessionService
|
||||||
|
from .message_service import MessageService
|
||||||
|
from .agent_service import AgentService
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'AgentScheduler',
|
||||||
|
'MessageQueue',
|
||||||
|
'SessionService',
|
||||||
|
'MessageService',
|
||||||
|
'AgentService',
|
||||||
|
]
|
||||||
|
|||||||
198
app/services/agent_service.py
Normal file
198
app/services/agent_service.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""
|
||||||
|
Agent 服务
|
||||||
|
处理 Agent 相关的业务逻辑
|
||||||
|
"""
|
||||||
|
from typing import Optional, List
|
||||||
|
from datetime import datetime
|
||||||
|
from app.models import db, Agent, Gateway
|
||||||
|
|
||||||
|
|
||||||
|
class AgentService:
|
||||||
|
"""Agent 服务"""
|
||||||
|
|
||||||
|
def register_agent(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
gateway_id: str = None,
|
||||||
|
model: str = None,
|
||||||
|
capabilities: list = None,
|
||||||
|
priority: int = 5,
|
||||||
|
weight: int = 10,
|
||||||
|
connection_limit: int = 5
|
||||||
|
) -> Agent:
|
||||||
|
"""注册新 Agent"""
|
||||||
|
agent = Agent(
|
||||||
|
name=name,
|
||||||
|
display_name=name,
|
||||||
|
gateway_id=gateway_id,
|
||||||
|
model=model,
|
||||||
|
capabilities=capabilities or [],
|
||||||
|
priority=priority,
|
||||||
|
weight=weight,
|
||||||
|
connection_limit=connection_limit,
|
||||||
|
status='offline'
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(agent)
|
||||||
|
|
||||||
|
# 更新 Gateway 的 Agent 计数
|
||||||
|
if gateway_id:
|
||||||
|
gateway = Gateway.query.get(gateway_id)
|
||||||
|
if gateway:
|
||||||
|
gateway.agent_count += 1
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def get_agent(self, agent_id: str) -> Optional[Agent]:
|
||||||
|
"""获取 Agent"""
|
||||||
|
return Agent.query.get(agent_id)
|
||||||
|
|
||||||
|
def get_agents(
|
||||||
|
self,
|
||||||
|
status: str = None,
|
||||||
|
gateway_id: str = None
|
||||||
|
) -> List[Agent]:
|
||||||
|
"""获取 Agent 列表"""
|
||||||
|
query = Agent.query
|
||||||
|
|
||||||
|
if status:
|
||||||
|
query = query.filter_by(status=status)
|
||||||
|
|
||||||
|
if gateway_id:
|
||||||
|
query = query.filter_by(gateway_id=gateway_id)
|
||||||
|
|
||||||
|
return query.all()
|
||||||
|
|
||||||
|
def get_available_agents(self) -> List[Agent]:
|
||||||
|
"""获取可用 Agent 列表"""
|
||||||
|
return Agent.query.filter_by(status='online').filter(
|
||||||
|
Agent.current_sessions < Agent.connection_limit
|
||||||
|
).all()
|
||||||
|
|
||||||
|
def update_agent_status(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
status: str,
|
||||||
|
socket_id: str = None
|
||||||
|
) -> Optional[Agent]:
|
||||||
|
"""更新 Agent 状态"""
|
||||||
|
agent = Agent.query.get(agent_id)
|
||||||
|
if not agent:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent.status = status
|
||||||
|
|
||||||
|
if socket_id:
|
||||||
|
agent.socket_id = socket_id
|
||||||
|
|
||||||
|
if status == 'online':
|
||||||
|
agent.last_heartbeat = datetime.utcnow()
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def heartbeat(self, agent_id: str) -> Optional[Agent]:
|
||||||
|
"""Agent 心跳"""
|
||||||
|
agent = Agent.query.get(agent_id)
|
||||||
|
if not agent:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent.last_heartbeat = datetime.utcnow()
|
||||||
|
agent.status = 'online'
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def update_agent_config(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
name: str = None,
|
||||||
|
display_name: str = None,
|
||||||
|
model: str = None,
|
||||||
|
capabilities: list = None,
|
||||||
|
priority: int = None,
|
||||||
|
weight: int = None,
|
||||||
|
connection_limit: int = None
|
||||||
|
) -> Optional[Agent]:
|
||||||
|
"""更新 Agent 配置"""
|
||||||
|
agent = Agent.query.get(agent_id)
|
||||||
|
if not agent:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if name is not None:
|
||||||
|
agent.name = name
|
||||||
|
if display_name is not None:
|
||||||
|
agent.display_name = display_name
|
||||||
|
if model is not None:
|
||||||
|
agent.model = model
|
||||||
|
if capabilities is not None:
|
||||||
|
agent.capabilities = capabilities
|
||||||
|
if priority is not None:
|
||||||
|
agent.priority = priority
|
||||||
|
if weight is not None:
|
||||||
|
agent.weight = weight
|
||||||
|
if connection_limit is not None:
|
||||||
|
agent.connection_limit = connection_limit
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def delete_agent(self, agent_id: str) -> bool:
|
||||||
|
"""删除 Agent"""
|
||||||
|
agent = Agent.query.get(agent_id)
|
||||||
|
if not agent:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 更新 Gateway 的 Agent 计数
|
||||||
|
if agent.gateway_id:
|
||||||
|
gateway = Gateway.query.get(agent.gateway_id)
|
||||||
|
if gateway and gateway.agent_count > 0:
|
||||||
|
gateway.agent_count -= 1
|
||||||
|
|
||||||
|
db.session.delete(agent)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_agent_stats(self) -> dict:
|
||||||
|
"""获取 Agent 统计"""
|
||||||
|
total = Agent.query.count()
|
||||||
|
online = Agent.query.filter_by(status='online').count()
|
||||||
|
offline = Agent.query.filter_by(status='offline').count()
|
||||||
|
busy = Agent.query.filter_by(status='busy').count()
|
||||||
|
|
||||||
|
# 总连接数
|
||||||
|
total_sessions = db.session.query(
|
||||||
|
db.func.sum(Agent.current_sessions)
|
||||||
|
).scalar() or 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total': total,
|
||||||
|
'online': online,
|
||||||
|
'offline': offline,
|
||||||
|
'busy': busy,
|
||||||
|
'total_sessions': total_sessions,
|
||||||
|
}
|
||||||
|
|
||||||
|
def check_offline_agents(self, timeout: int = 120) -> List[Agent]:
|
||||||
|
"""检查超时下线的 Agent"""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
threshold = datetime.utcnow() - timedelta(seconds=timeout)
|
||||||
|
|
||||||
|
offline_agents = Agent.query.filter(
|
||||||
|
Agent.status == 'online',
|
||||||
|
Agent.last_heartbeat < threshold
|
||||||
|
).all()
|
||||||
|
|
||||||
|
for agent in offline_agents:
|
||||||
|
agent.status = 'offline'
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return offline_agents
|
||||||
155
app/services/message_queue.py
Normal file
155
app/services/message_queue.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""
|
||||||
|
消息队列服务
|
||||||
|
实现消息缓存、重试机制
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
from app.models import db, Message
|
||||||
|
from app.extensions import redis_client
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class MessageQueue:
|
||||||
|
"""消息队列管理"""
|
||||||
|
|
||||||
|
# Redis 键前缀
|
||||||
|
PENDING_QUEUE = "pit:messages:pending"
|
||||||
|
RETRY_QUEUE = "pit:messages:retry"
|
||||||
|
MESSAGE_PREFIX = "pit:message:"
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
MAX_RETRY = 3
|
||||||
|
RETRY_DELAY = 5 # 秒
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enqueue(cls, message: Message) -> bool:
|
||||||
|
"""消息入队"""
|
||||||
|
try:
|
||||||
|
# 存储消息详情
|
||||||
|
message_key = f"{cls.MESSAGE_PREFIX}{message.id}"
|
||||||
|
redis_client.hset(message_key, mapping={
|
||||||
|
'id': message.id,
|
||||||
|
'session_id': message.session_id,
|
||||||
|
'sender_type': message.sender_type,
|
||||||
|
'sender_id': message.sender_id,
|
||||||
|
'content': message.content or '',
|
||||||
|
'status': message.status,
|
||||||
|
'retry_count': str(message.retry_count),
|
||||||
|
'created_at': message.created_at.isoformat() if message.created_at else '',
|
||||||
|
})
|
||||||
|
|
||||||
|
# 加入待处理队列
|
||||||
|
redis_client.rpush(cls.PENDING_QUEUE, message.id)
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to enqueue message: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dequeue(cls) -> Optional[dict]:
|
||||||
|
"""消息出队"""
|
||||||
|
try:
|
||||||
|
message_id = redis_client.lpop(cls.PENDING_QUEUE)
|
||||||
|
if not message_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
message_key = f"{cls.MESSAGE_PREFIX}{message_id}"
|
||||||
|
message_data = redis_client.hgetall(message_key)
|
||||||
|
|
||||||
|
if not message_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
'id': message_data.get('id'),
|
||||||
|
'session_id': message_data.get('session_id'),
|
||||||
|
'sender_type': message_data.get('sender_type'),
|
||||||
|
'sender_id': message_data.get('sender_id'),
|
||||||
|
'content': message_data.get('content'),
|
||||||
|
'status': message_data.get('status'),
|
||||||
|
'retry_count': int(message_data.get('retry_count', 0)),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to dequeue message: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ack(cls, message_id: str) -> bool:
|
||||||
|
"""消息确认"""
|
||||||
|
try:
|
||||||
|
message_key = f"{cls.MESSAGE_PREFIX}{message_id}"
|
||||||
|
redis_client.delete(message_key)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to ack message: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def retry(cls, message_id: str) -> bool:
|
||||||
|
"""消息重试"""
|
||||||
|
try:
|
||||||
|
message_key = f"{cls.MESSAGE_PREFIX}{message_id}"
|
||||||
|
retry_count = redis_client.hget(message_key, 'retry_count')
|
||||||
|
|
||||||
|
if retry_count is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
retry_count = int(retry_count)
|
||||||
|
|
||||||
|
if retry_count >= cls.MAX_RETRY:
|
||||||
|
# 超过最大重试次数,标记为失败
|
||||||
|
redis_client.hset(message_key, 'status', 'failed')
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 增加重试计数
|
||||||
|
redis_client.hset(message_key, 'retry_count', str(retry_count + 1))
|
||||||
|
|
||||||
|
# 加入重试队列
|
||||||
|
redis_client.rpush(cls.RETRY_QUEUE, message_id)
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to retry message: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_pending_count(cls) -> int:
|
||||||
|
"""获取待处理消息数量"""
|
||||||
|
try:
|
||||||
|
return redis_client.llen(cls.PENDING_QUEUE)
|
||||||
|
except:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_retry_count(cls) -> int:
|
||||||
|
"""获取重试消息数量"""
|
||||||
|
try:
|
||||||
|
return redis_client.llen(cls.RETRY_QUEUE)
|
||||||
|
except:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def process_retry_queue(cls) -> List[dict]:
|
||||||
|
"""处理重试队列"""
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
message_id = redis_client.lpop(cls.RETRY_QUEUE)
|
||||||
|
if not message_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
message_key = f"{cls.MESSAGE_PREFIX}{message_id}"
|
||||||
|
message_data = redis_client.hgetall(message_key)
|
||||||
|
|
||||||
|
if message_data:
|
||||||
|
messages.append({
|
||||||
|
'id': message_data.get('id'),
|
||||||
|
'session_id': message_data.get('session_id'),
|
||||||
|
'content': message_data.get('content'),
|
||||||
|
'retry_count': int(message_data.get('retry_count', 0)),
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to process retry queue: {e}")
|
||||||
|
|
||||||
|
return messages
|
||||||
167
app/services/message_service.py
Normal file
167
app/services/message_service.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
消息服务
|
||||||
|
处理消息相关的业务逻辑
|
||||||
|
"""
|
||||||
|
from typing import Optional, List
|
||||||
|
from datetime import datetime
|
||||||
|
from app.models import db, Message, Session
|
||||||
|
from app.services.message_queue import MessageQueue
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class MessageService:
|
||||||
|
"""消息服务"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.queue = MessageQueue()
|
||||||
|
|
||||||
|
def create_message(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
sender_type: str,
|
||||||
|
sender_id: str,
|
||||||
|
content: str,
|
||||||
|
message_type: str = 'text',
|
||||||
|
content_type: str = 'markdown',
|
||||||
|
reply_to: str = None
|
||||||
|
) -> Message:
|
||||||
|
"""创建消息"""
|
||||||
|
message = Message(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
session_id=session_id,
|
||||||
|
sender_type=sender_type,
|
||||||
|
sender_id=sender_id,
|
||||||
|
message_type=message_type,
|
||||||
|
content=content,
|
||||||
|
content_type=content_type,
|
||||||
|
reply_to=reply_to,
|
||||||
|
status='sent',
|
||||||
|
ack_status='pending'
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(message)
|
||||||
|
|
||||||
|
# 更新会话统计
|
||||||
|
session = Session.query.get(session_id)
|
||||||
|
if session:
|
||||||
|
session.message_count += 1
|
||||||
|
session.last_active_at = datetime.utcnow()
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# 入队等待确认
|
||||||
|
self.queue.enqueue(message)
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
def get_message(self, message_id: str) -> Optional[Message]:
|
||||||
|
"""获取消息"""
|
||||||
|
return Message.query.get(message_id)
|
||||||
|
|
||||||
|
def get_session_messages(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0
|
||||||
|
) -> List[Message]:
|
||||||
|
"""获取会话消息列表"""
|
||||||
|
return Message.query.filter_by(session_id=session_id)\
|
||||||
|
.order_by(Message.created_at.asc())\
|
||||||
|
.offset(offset)\
|
||||||
|
.limit(limit)\
|
||||||
|
.all()
|
||||||
|
|
||||||
|
def acknowledge_message(self, message_id: str) -> Optional[Message]:
|
||||||
|
"""确认消息已送达"""
|
||||||
|
message = Message.query.get(message_id)
|
||||||
|
if not message:
|
||||||
|
return None
|
||||||
|
|
||||||
|
message.ack_status = 'acknowledged'
|
||||||
|
message.status = 'delivered'
|
||||||
|
message.delivered_at = datetime.utcnow()
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# 从队列中移除
|
||||||
|
self.queue.ack(message_id)
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
def mark_as_read(self, message_id: str) -> Optional[Message]:
|
||||||
|
"""标记消息已读"""
|
||||||
|
message = Message.query.get(message_id)
|
||||||
|
if not message:
|
||||||
|
return None
|
||||||
|
|
||||||
|
message.status = 'read'
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
def mark_session_read(self, session_id: str, user_id: str) -> int:
|
||||||
|
"""标记会话所有消息已读"""
|
||||||
|
# 更新数据库
|
||||||
|
count = Message.query.filter_by(
|
||||||
|
session_id=session_id,
|
||||||
|
status='delivered'
|
||||||
|
).update({'status': 'read'})
|
||||||
|
|
||||||
|
# 更新会话未读数
|
||||||
|
session = Session.query.get(session_id)
|
||||||
|
if session:
|
||||||
|
session.unread_count = 0
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def retry_message(self, message_id: str) -> Optional[Message]:
|
||||||
|
"""重试发送消息"""
|
||||||
|
message = Message.query.get(message_id)
|
||||||
|
if not message:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查重试次数
|
||||||
|
if message.retry_count >= 3:
|
||||||
|
message.status = 'failed'
|
||||||
|
db.session.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 增加重试计数
|
||||||
|
message.retry_count += 1
|
||||||
|
message.status = 'sent'
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# 加入重试队列
|
||||||
|
self.queue.retry(message_id)
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
def get_pending_messages(self) -> List[dict]:
|
||||||
|
"""获取待处理消息"""
|
||||||
|
return self.queue.process_retry_queue()
|
||||||
|
|
||||||
|
def get_message_stats(self, session_id: str = None) -> dict:
|
||||||
|
"""获取消息统计"""
|
||||||
|
query = Message.query
|
||||||
|
|
||||||
|
if session_id:
|
||||||
|
query = query.filter_by(session_id=session_id)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
sent = query.filter_by(status='sent').count()
|
||||||
|
delivered = query.filter_by(status='delivered').count()
|
||||||
|
read = query.filter_by(status='read').count()
|
||||||
|
failed = query.filter_by(status='failed').count()
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total': total,
|
||||||
|
'sent': sent,
|
||||||
|
'delivered': delivered,
|
||||||
|
'read': read,
|
||||||
|
'failed': failed,
|
||||||
|
'pending': self.queue.get_pending_count(),
|
||||||
|
'retry': self.queue.get_retry_count(),
|
||||||
|
}
|
||||||
174
app/services/scheduler.py
Normal file
174
app/services/scheduler.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""
|
||||||
|
Agent 调度器服务
|
||||||
|
实现多种调度策略
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
from app.models import Agent
|
||||||
|
|
||||||
|
|
||||||
|
class BaseScheduler:
|
||||||
|
"""调度器基类"""
|
||||||
|
|
||||||
|
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
|
||||||
|
"""选择 Agent,子类实现"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class RoundRobinScheduler(BaseScheduler):
|
||||||
|
"""轮询调度器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._index = 0
|
||||||
|
|
||||||
|
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
|
||||||
|
if not agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 只选择在线的 Agent
|
||||||
|
online_agents = [a for a in agents if a.status == 'online']
|
||||||
|
if not online_agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 轮询选择
|
||||||
|
agent = online_agents[self._index % len(online_agents)]
|
||||||
|
self._index += 1
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedRoundRobinScheduler(BaseScheduler):
|
||||||
|
"""加权轮询调度器(默认)"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._index = 0
|
||||||
|
self._weights = {}
|
||||||
|
|
||||||
|
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
|
||||||
|
if not agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 只选择在线的 Agent
|
||||||
|
online_agents = [a for a in agents if a.status == 'online']
|
||||||
|
if not online_agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算总权重
|
||||||
|
total_weight = sum(a.weight for a in online_agents)
|
||||||
|
if total_weight == 0:
|
||||||
|
return online_agents[0] if online_agents else None
|
||||||
|
|
||||||
|
# 加权轮询
|
||||||
|
current_weight = self._index % total_weight
|
||||||
|
cumulative = 0
|
||||||
|
|
||||||
|
for agent in online_agents:
|
||||||
|
cumulative += agent.weight
|
||||||
|
if current_weight < cumulative:
|
||||||
|
self._index += 1
|
||||||
|
return agent
|
||||||
|
|
||||||
|
return online_agents[-1]
|
||||||
|
|
||||||
|
|
||||||
|
class LeastConnectionsScheduler(BaseScheduler):
|
||||||
|
"""最少连接调度器"""
|
||||||
|
|
||||||
|
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
|
||||||
|
if not agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 只选择在线且未达连接上限的 Agent
|
||||||
|
available_agents = [
|
||||||
|
a for a in agents
|
||||||
|
if a.status == 'online' and a.current_sessions < a.connection_limit
|
||||||
|
]
|
||||||
|
|
||||||
|
if not available_agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 选择连接数最少的
|
||||||
|
return min(available_agents, key=lambda a: a.current_sessions)
|
||||||
|
|
||||||
|
|
||||||
|
class LeastResponseTimeScheduler(BaseScheduler):
|
||||||
|
"""最快响应调度器(基于最后心跳时间)"""
|
||||||
|
|
||||||
|
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
|
||||||
|
if not agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 只选择在线的 Agent
|
||||||
|
online_agents = [a for a in agents if a.status == 'online']
|
||||||
|
if not online_agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 选择最近有心跳的(响应快的)
|
||||||
|
from datetime import datetime
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
def response_score(agent):
|
||||||
|
if not agent.last_heartbeat:
|
||||||
|
return float('inf')
|
||||||
|
# 最近心跳 = 分数低 = 优先选择
|
||||||
|
return (now - agent.last_heartbeat).total_seconds()
|
||||||
|
|
||||||
|
return min(online_agents, key=response_score)
|
||||||
|
|
||||||
|
|
||||||
|
class CapabilityMatchScheduler(BaseScheduler):
|
||||||
|
"""能力匹配调度器"""
|
||||||
|
|
||||||
|
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
|
||||||
|
if not agents or not context:
|
||||||
|
return None
|
||||||
|
|
||||||
|
required_capabilities = context.get('capabilities', [])
|
||||||
|
if not required_capabilities:
|
||||||
|
# 没有特殊要求,使用加权轮询
|
||||||
|
return WeightedRoundRobinScheduler().select_agent(agents, context)
|
||||||
|
|
||||||
|
# 只选择在线且有对应能力的 Agent
|
||||||
|
matching_agents = []
|
||||||
|
for agent in agents:
|
||||||
|
if agent.status != 'online':
|
||||||
|
continue
|
||||||
|
|
||||||
|
agent_caps = agent.capabilities or []
|
||||||
|
# 检查是否具备所有必需能力
|
||||||
|
if all(cap in agent_caps for cap in required_capabilities):
|
||||||
|
matching_agents.append(agent)
|
||||||
|
|
||||||
|
if not matching_agents:
|
||||||
|
# 没有匹配的,回退到加权轮询
|
||||||
|
return WeightedRoundRobinScheduler().select_agent(agents, context)
|
||||||
|
|
||||||
|
# 在匹配的 Agent 中使用加权轮询
|
||||||
|
return WeightedRoundRobinScheduler().select_agent(matching_agents, context)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentScheduler:
|
||||||
|
"""Agent 调度器工厂"""
|
||||||
|
|
||||||
|
STRATEGIES = {
|
||||||
|
'round_robin': RoundRobinScheduler,
|
||||||
|
'weighted_round_robin': WeightedRoundRobinScheduler,
|
||||||
|
'least_connections': LeastConnectionsScheduler,
|
||||||
|
'least_response_time': LeastResponseTimeScheduler,
|
||||||
|
'capability_match': CapabilityMatchScheduler,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, strategy: str = 'weighted_round_robin'):
|
||||||
|
self._strategy_name = strategy
|
||||||
|
self._scheduler = self.STRATEGIES.get(strategy, WeightedRoundRobinScheduler)()
|
||||||
|
|
||||||
|
def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]:
|
||||||
|
"""选择 Agent"""
|
||||||
|
return self._scheduler.select_agent(agents, context)
|
||||||
|
|
||||||
|
def get_strategy(self) -> str:
|
||||||
|
"""获取当前策略"""
|
||||||
|
return self._strategy_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_available_strategies(cls) -> list:
|
||||||
|
"""获取可用策略列表"""
|
||||||
|
return list(cls.STRATEGIES.keys())
|
||||||
168
app/services/session_service.py
Normal file
168
app/services/session_service.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""
|
||||||
|
会话服务
|
||||||
|
处理会话相关的业务逻辑
|
||||||
|
"""
|
||||||
|
from typing import Optional, List
|
||||||
|
from datetime import datetime
|
||||||
|
from app.models import db, Session, Agent
|
||||||
|
from app.services.scheduler import AgentScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class SessionService:
|
||||||
|
"""会话服务"""
|
||||||
|
|
||||||
|
def __init__(self, scheduler_strategy: str = 'weighted_round_robin'):
|
||||||
|
self.scheduler = AgentScheduler(scheduler_strategy)
|
||||||
|
|
||||||
|
def create_session(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
title: str = None,
|
||||||
|
agent_id: str = None,
|
||||||
|
channel_type: str = 'web',
|
||||||
|
context: dict = None
|
||||||
|
) -> Session:
|
||||||
|
"""创建新会话"""
|
||||||
|
# 如果没有指定 Agent,调度一个
|
||||||
|
if not agent_id:
|
||||||
|
from app.models import Agent
|
||||||
|
agents = Agent.query.all()
|
||||||
|
selected_agent = self.scheduler.select_agent(agents, context)
|
||||||
|
agent_id = selected_agent.id if selected_agent else None
|
||||||
|
|
||||||
|
# 创建会话
|
||||||
|
session = Session(
|
||||||
|
user_id=user_id,
|
||||||
|
primary_agent_id=agent_id,
|
||||||
|
title=title or f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
|
||||||
|
channel_type=channel_type,
|
||||||
|
status='active'
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(session)
|
||||||
|
|
||||||
|
# 更新 Agent 会话数
|
||||||
|
if agent_id:
|
||||||
|
agent = Agent.query.get(agent_id)
|
||||||
|
if agent:
|
||||||
|
agent.current_sessions += 1
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def close_session(self, session_id: str, user_id: str = None) -> Optional[Session]:
|
||||||
|
"""关闭会话"""
|
||||||
|
session = Session.query.get(session_id)
|
||||||
|
if not session:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 验证权限
|
||||||
|
if user_id and session.user_id != user_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
session.status = 'closed'
|
||||||
|
session.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
# 更新 Agent 会话数
|
||||||
|
if session.primary_agent_id:
|
||||||
|
agent = Agent.query.get(session.primary_agent_id)
|
||||||
|
if agent and agent.current_sessions > 0:
|
||||||
|
agent.current_sessions -= 1
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def get_user_sessions(self, user_id: str, status: str = None) -> List[Session]:
|
||||||
|
"""获取用户会话列表"""
|
||||||
|
query = Session.query.filter_by(user_id=user_id)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
query = query.filter_by(status=status)
|
||||||
|
|
||||||
|
return query.order_by(Session.created_at.desc()).all()
|
||||||
|
|
||||||
|
def get_session(self, session_id: str, user_id: str = None) -> Optional[Session]:
|
||||||
|
"""获取会话详情"""
|
||||||
|
session = Session.query.get(session_id)
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if user_id and session.user_id != user_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def update_session_activity(self, session_id: str) -> bool:
|
||||||
|
"""更新会话活跃时间"""
|
||||||
|
session = Session.query.get(session_id)
|
||||||
|
if not session:
|
||||||
|
return False
|
||||||
|
|
||||||
|
session.last_active_at = datetime.utcnow()
|
||||||
|
session.message_count += 1
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def assign_agent(self, session_id: str, agent_id: str) -> Optional[Session]:
|
||||||
|
"""分配 Agent 到会话"""
|
||||||
|
session = Session.query.get(session_id)
|
||||||
|
if not session:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 验证 Agent 是否在线
|
||||||
|
agent = Agent.query.get(agent_id)
|
||||||
|
if not agent or agent.status != 'online':
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 更新会话
|
||||||
|
old_agent_id = session.primary_agent_id
|
||||||
|
session.primary_agent_id = agent_id
|
||||||
|
|
||||||
|
# 更新新旧 Agent 的会话数
|
||||||
|
if old_agent_id:
|
||||||
|
old_agent = Agent.query.get(old_agent_id)
|
||||||
|
if old_agent and old_agent.current_sessions > 0:
|
||||||
|
old_agent.current_sessions -= 1
|
||||||
|
|
||||||
|
agent.current_sessions += 1
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def add_participant(self, session_id: str, agent_id: str) -> Optional[Session]:
|
||||||
|
"""添加参与 Agent"""
|
||||||
|
session = Session.query.get(session_id)
|
||||||
|
if not session:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取当前参与列表
|
||||||
|
participants = session.participating_agent_ids or []
|
||||||
|
|
||||||
|
if agent_id not in participants:
|
||||||
|
participants.append(agent_id)
|
||||||
|
session.participating_agent_ids = participants
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def remove_participant(self, session_id: str, agent_id: str) -> Optional[Session]:
|
||||||
|
"""移除参与 Agent"""
|
||||||
|
session = Session.query.get(session_id)
|
||||||
|
if not session:
|
||||||
|
return None
|
||||||
|
|
||||||
|
participants = session.participating_agent_ids or []
|
||||||
|
|
||||||
|
if agent_id in participants:
|
||||||
|
participants.remove(agent_id)
|
||||||
|
session.participating_agent_ids = participants
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return session
|
||||||
Reference in New Issue
Block a user