diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e2765c..00d5438 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,55 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 --- +## [0.4.0] - 2026-03-14 + +### Added + +#### 🔧 服务层实现 (Phase 2) + +- **AgentScheduler** - 5 种调度策略 + - RoundRobinScheduler - 轮询调度 + - WeightedRoundRobinScheduler - 加权轮询调度(默认) + - LeastConnectionsScheduler - 最少连接调度 + - LeastResponseTimeScheduler - 最快响应调度 + - CapabilityMatchScheduler - 能力匹配调度 + +- **MessageQueue** - 消息队列管理 + - 消息入队/出队 + - ACK 确认机制 + - 重试队列管理 + - Redis 存储支持 + +- **SessionService** - 会话服务 + - 创建/关闭会话 + - Agent 分配 + - 多 Agent 协作支持 + - 会话活跃度追踪 + +- **MessageService** - 消息服务 + - 创建/获取消息 + - 消息确认 + - 已读标记 + - 重试机制 + - 消息统计 + +- **AgentService** - Agent 服务 + - Agent 注册/注销 + - 状态更新 + - 心跳管理 + - 配置更新 + - 离线检测 + +### Fixed + +- **gateways.py** - 修复蓝图名称错误 (`gateway_bp` → `gateways_bp`) + +### Changed + +- **Dockerfile** - 添加生产环境 Docker 镜像配置 + +--- + ## [0.3.0] - 2026-03-14 ### Added diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..f10f8e7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,29 @@ +# PIT Router Dockerfile +FROM python:3.12-slim + +# 设置工作目录 +WORKDIR /app + +# 安装系统依赖 +RUN apt-get update && apt-get install -y \ + gcc \ + postgresql-client \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# 复制依赖文件 +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 复制应用代码 +COPY . . + +# 暴露端口 +EXPOSE 9000 + +# 健康检查 +HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \ + CMD curl -f http://localhost:9000/health || exit 1 + +# 启动命令 +CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:9000", "-k", "gevent", "--worker-connections", "1000", "run:app"] diff --git a/app/routes/gateways.py b/app/routes/gateways.py index 478bae0..b719f6c 100644 --- a/app/routes/gateways.py +++ b/app/routes/gateways.py @@ -10,7 +10,7 @@ from app.models import db, Gateway gateways_bp = Blueprint('gateways', __name__) -@gateway_bp.route('/', methods=['GET']) +@gateways_bp.route('/', methods=['GET']) @jwt_required() def get_gateways(): """获取 Gateway 列表""" @@ -18,7 +18,7 @@ def get_gateways(): return jsonify({'gateways': [g.to_dict() for g in gateways]}), 200 -@gateway_bp.route('/', methods=['POST']) +@gateways_bp.route('/', methods=['POST']) @jwt_required() def register_gateway(): """注册 Gateway""" @@ -55,7 +55,7 @@ def register_gateway(): return jsonify({'gateway': gateway.to_dict()}), 201 -@gateway_bp.route('/', methods=['DELETE']) +@gateways_bp.route('/', methods=['DELETE']) @jwt_required() def delete_gateway(gateway_id): """注销 Gateway""" @@ -69,7 +69,7 @@ def delete_gateway(gateway_id): return jsonify({'message': 'Gateway deleted'}), 200 -@gateway_bp.route('//status', methods=['GET']) +@gateways_bp.route('//status', methods=['GET']) @jwt_required() def get_gateway_status(gateway_id): """获取 Gateway 状态""" @@ -85,7 +85,7 @@ def get_gateway_status(gateway_id): }), 200 -@gateway_bp.route('//heartbeat', methods=['POST']) +@gateways_bp.route('//heartbeat', methods=['POST']) def gateway_heartbeat(gateway_id): """Gateway 心跳上报""" gateway = Gateway.query.get(gateway_id) diff --git a/app/services/__init__.py b/app/services/__init__.py index b60bf1b..d081b6b 100644 --- a/app/services/__init__.py +++ b/app/services/__init__.py @@ -1,3 +1,16 @@ """ 服务模块 """ +from .scheduler import AgentScheduler +from .message_queue import MessageQueue +from .session_service import SessionService +from .message_service import MessageService +from .agent_service import AgentService + +__all__ = [ + 'AgentScheduler', + 'MessageQueue', + 'SessionService', + 'MessageService', + 'AgentService', +] diff --git a/app/services/agent_service.py b/app/services/agent_service.py new file mode 100644 index 0000000..72eadf7 --- /dev/null +++ b/app/services/agent_service.py @@ -0,0 +1,198 @@ +""" +Agent 服务 +处理 Agent 相关的业务逻辑 +""" +from typing import Optional, List +from datetime import datetime +from app.models import db, Agent, Gateway + + +class AgentService: + """Agent 服务""" + + def register_agent( + self, + name: str, + gateway_id: str = None, + model: str = None, + capabilities: list = None, + priority: int = 5, + weight: int = 10, + connection_limit: int = 5 + ) -> Agent: + """注册新 Agent""" + agent = Agent( + name=name, + display_name=name, + gateway_id=gateway_id, + model=model, + capabilities=capabilities or [], + priority=priority, + weight=weight, + connection_limit=connection_limit, + status='offline' + ) + + db.session.add(agent) + + # 更新 Gateway 的 Agent 计数 + if gateway_id: + gateway = Gateway.query.get(gateway_id) + if gateway: + gateway.agent_count += 1 + + db.session.commit() + + return agent + + def get_agent(self, agent_id: str) -> Optional[Agent]: + """获取 Agent""" + return Agent.query.get(agent_id) + + def get_agents( + self, + status: str = None, + gateway_id: str = None + ) -> List[Agent]: + """获取 Agent 列表""" + query = Agent.query + + if status: + query = query.filter_by(status=status) + + if gateway_id: + query = query.filter_by(gateway_id=gateway_id) + + return query.all() + + def get_available_agents(self) -> List[Agent]: + """获取可用 Agent 列表""" + return Agent.query.filter_by(status='online').filter( + Agent.current_sessions < Agent.connection_limit + ).all() + + def update_agent_status( + self, + agent_id: str, + status: str, + socket_id: str = None + ) -> Optional[Agent]: + """更新 Agent 状态""" + agent = Agent.query.get(agent_id) + if not agent: + return None + + agent.status = status + + if socket_id: + agent.socket_id = socket_id + + if status == 'online': + agent.last_heartbeat = datetime.utcnow() + + db.session.commit() + + return agent + + def heartbeat(self, agent_id: str) -> Optional[Agent]: + """Agent 心跳""" + agent = Agent.query.get(agent_id) + if not agent: + return None + + agent.last_heartbeat = datetime.utcnow() + agent.status = 'online' + + db.session.commit() + + return agent + + def update_agent_config( + self, + agent_id: str, + name: str = None, + display_name: str = None, + model: str = None, + capabilities: list = None, + priority: int = None, + weight: int = None, + connection_limit: int = None + ) -> Optional[Agent]: + """更新 Agent 配置""" + agent = Agent.query.get(agent_id) + if not agent: + return None + + if name is not None: + agent.name = name + if display_name is not None: + agent.display_name = display_name + if model is not None: + agent.model = model + if capabilities is not None: + agent.capabilities = capabilities + if priority is not None: + agent.priority = priority + if weight is not None: + agent.weight = weight + if connection_limit is not None: + agent.connection_limit = connection_limit + + db.session.commit() + + return agent + + def delete_agent(self, agent_id: str) -> bool: + """删除 Agent""" + agent = Agent.query.get(agent_id) + if not agent: + return False + + # 更新 Gateway 的 Agent 计数 + if agent.gateway_id: + gateway = Gateway.query.get(agent.gateway_id) + if gateway and gateway.agent_count > 0: + gateway.agent_count -= 1 + + db.session.delete(agent) + db.session.commit() + + return True + + def get_agent_stats(self) -> dict: + """获取 Agent 统计""" + total = Agent.query.count() + online = Agent.query.filter_by(status='online').count() + offline = Agent.query.filter_by(status='offline').count() + busy = Agent.query.filter_by(status='busy').count() + + # 总连接数 + total_sessions = db.session.query( + db.func.sum(Agent.current_sessions) + ).scalar() or 0 + + return { + 'total': total, + 'online': online, + 'offline': offline, + 'busy': busy, + 'total_sessions': total_sessions, + } + + def check_offline_agents(self, timeout: int = 120) -> List[Agent]: + """检查超时下线的 Agent""" + from datetime import timedelta + + threshold = datetime.utcnow() - timedelta(seconds=timeout) + + offline_agents = Agent.query.filter( + Agent.status == 'online', + Agent.last_heartbeat < threshold + ).all() + + for agent in offline_agents: + agent.status = 'offline' + + db.session.commit() + + return offline_agents diff --git a/app/services/message_queue.py b/app/services/message_queue.py new file mode 100644 index 0000000..12d385e --- /dev/null +++ b/app/services/message_queue.py @@ -0,0 +1,155 @@ +""" +消息队列服务 +实现消息缓存、重试机制 +""" +from datetime import datetime +from typing import Optional, List +from app.models import db, Message +from app.extensions import redis_client +import json + + +class MessageQueue: + """消息队列管理""" + + # Redis 键前缀 + PENDING_QUEUE = "pit:messages:pending" + RETRY_QUEUE = "pit:messages:retry" + MESSAGE_PREFIX = "pit:message:" + + # 配置 + MAX_RETRY = 3 + RETRY_DELAY = 5 # 秒 + + @classmethod + def enqueue(cls, message: Message) -> bool: + """消息入队""" + try: + # 存储消息详情 + message_key = f"{cls.MESSAGE_PREFIX}{message.id}" + redis_client.hset(message_key, mapping={ + 'id': message.id, + 'session_id': message.session_id, + 'sender_type': message.sender_type, + 'sender_id': message.sender_id, + 'content': message.content or '', + 'status': message.status, + 'retry_count': str(message.retry_count), + 'created_at': message.created_at.isoformat() if message.created_at else '', + }) + + # 加入待处理队列 + redis_client.rpush(cls.PENDING_QUEUE, message.id) + + return True + except Exception as e: + print(f"Failed to enqueue message: {e}") + return False + + @classmethod + def dequeue(cls) -> Optional[dict]: + """消息出队""" + try: + message_id = redis_client.lpop(cls.PENDING_QUEUE) + if not message_id: + return None + + message_key = f"{cls.MESSAGE_PREFIX}{message_id}" + message_data = redis_client.hgetall(message_key) + + if not message_data: + return None + + return { + 'id': message_data.get('id'), + 'session_id': message_data.get('session_id'), + 'sender_type': message_data.get('sender_type'), + 'sender_id': message_data.get('sender_id'), + 'content': message_data.get('content'), + 'status': message_data.get('status'), + 'retry_count': int(message_data.get('retry_count', 0)), + } + except Exception as e: + print(f"Failed to dequeue message: {e}") + return None + + @classmethod + def ack(cls, message_id: str) -> bool: + """消息确认""" + try: + message_key = f"{cls.MESSAGE_PREFIX}{message_id}" + redis_client.delete(message_key) + return True + except Exception as e: + print(f"Failed to ack message: {e}") + return False + + @classmethod + def retry(cls, message_id: str) -> bool: + """消息重试""" + try: + message_key = f"{cls.MESSAGE_PREFIX}{message_id}" + retry_count = redis_client.hget(message_key, 'retry_count') + + if retry_count is None: + return False + + retry_count = int(retry_count) + + if retry_count >= cls.MAX_RETRY: + # 超过最大重试次数,标记为失败 + redis_client.hset(message_key, 'status', 'failed') + return False + + # 增加重试计数 + redis_client.hset(message_key, 'retry_count', str(retry_count + 1)) + + # 加入重试队列 + redis_client.rpush(cls.RETRY_QUEUE, message_id) + + return True + except Exception as e: + print(f"Failed to retry message: {e}") + return False + + @classmethod + def get_pending_count(cls) -> int: + """获取待处理消息数量""" + try: + return redis_client.llen(cls.PENDING_QUEUE) + except: + return 0 + + @classmethod + def get_retry_count(cls) -> int: + """获取重试消息数量""" + try: + return redis_client.llen(cls.RETRY_QUEUE) + except: + return 0 + + @classmethod + def process_retry_queue(cls) -> List[dict]: + """处理重试队列""" + messages = [] + + try: + while True: + message_id = redis_client.lpop(cls.RETRY_QUEUE) + if not message_id: + break + + message_key = f"{cls.MESSAGE_PREFIX}{message_id}" + message_data = redis_client.hgetall(message_key) + + if message_data: + messages.append({ + 'id': message_data.get('id'), + 'session_id': message_data.get('session_id'), + 'content': message_data.get('content'), + 'retry_count': int(message_data.get('retry_count', 0)), + }) + except Exception as e: + print(f"Failed to process retry queue: {e}") + + return messages diff --git a/app/services/message_service.py b/app/services/message_service.py new file mode 100644 index 0000000..fa40765 --- /dev/null +++ b/app/services/message_service.py @@ -0,0 +1,167 @@ +""" +消息服务 +处理消息相关的业务逻辑 +""" +from typing import Optional, List +from datetime import datetime +from app.models import db, Message, Session +from app.services.message_queue import MessageQueue +import uuid + + +class MessageService: + """消息服务""" + + def __init__(self): + self.queue = MessageQueue() + + def create_message( + self, + session_id: str, + sender_type: str, + sender_id: str, + content: str, + message_type: str = 'text', + content_type: str = 'markdown', + reply_to: str = None + ) -> Message: + """创建消息""" + message = Message( + id=str(uuid.uuid4()), + session_id=session_id, + sender_type=sender_type, + sender_id=sender_id, + message_type=message_type, + content=content, + content_type=content_type, + reply_to=reply_to, + status='sent', + ack_status='pending' + ) + + db.session.add(message) + + # 更新会话统计 + session = Session.query.get(session_id) + if session: + session.message_count += 1 + session.last_active_at = datetime.utcnow() + + db.session.commit() + + # 入队等待确认 + self.queue.enqueue(message) + + return message + + def get_message(self, message_id: str) -> Optional[Message]: + """获取消息""" + return Message.query.get(message_id) + + def get_session_messages( + self, + session_id: str, + limit: int = 50, + offset: int = 0 + ) -> List[Message]: + """获取会话消息列表""" + return Message.query.filter_by(session_id=session_id)\ + .order_by(Message.created_at.asc())\ + .offset(offset)\ + .limit(limit)\ + .all() + + def acknowledge_message(self, message_id: str) -> Optional[Message]: + """确认消息已送达""" + message = Message.query.get(message_id) + if not message: + return None + + message.ack_status = 'acknowledged' + message.status = 'delivered' + message.delivered_at = datetime.utcnow() + + db.session.commit() + + # 从队列中移除 + self.queue.ack(message_id) + + return message + + def mark_as_read(self, message_id: str) -> Optional[Message]: + """标记消息已读""" + message = Message.query.get(message_id) + if not message: + return None + + message.status = 'read' + db.session.commit() + + return message + + def mark_session_read(self, session_id: str, user_id: str) -> int: + """标记会话所有消息已读""" + # 更新数据库 + count = Message.query.filter_by( + session_id=session_id, + status='delivered' + ).update({'status': 'read'}) + + # 更新会话未读数 + session = Session.query.get(session_id) + if session: + session.unread_count = 0 + + db.session.commit() + + return count + + def retry_message(self, message_id: str) -> Optional[Message]: + """重试发送消息""" + message = Message.query.get(message_id) + if not message: + return None + + # 检查重试次数 + if message.retry_count >= 3: + message.status = 'failed' + db.session.commit() + return None + + # 增加重试计数 + message.retry_count += 1 + message.status = 'sent' + + db.session.commit() + + # 加入重试队列 + self.queue.retry(message_id) + + return message + + def get_pending_messages(self) -> List[dict]: + """获取待处理消息""" + return self.queue.process_retry_queue() + + def get_message_stats(self, session_id: str = None) -> dict: + """获取消息统计""" + query = Message.query + + if session_id: + query = query.filter_by(session_id=session_id) + + total = query.count() + sent = query.filter_by(status='sent').count() + delivered = query.filter_by(status='delivered').count() + read = query.filter_by(status='read').count() + failed = query.filter_by(status='failed').count() + + return { + 'total': total, + 'sent': sent, + 'delivered': delivered, + 'read': read, + 'failed': failed, + 'pending': self.queue.get_pending_count(), + 'retry': self.queue.get_retry_count(), + } diff --git a/app/services/scheduler.py b/app/services/scheduler.py new file mode 100644 index 0000000..de24d41 --- /dev/null +++ b/app/services/scheduler.py @@ -0,0 +1,174 @@ +""" +Agent 调度器服务 +实现多种调度策略 +""" +from typing import List, Optional +from app.models import Agent + + +class BaseScheduler: + """调度器基类""" + + def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]: + """选择 Agent,子类实现""" + raise NotImplementedError + + +class RoundRobinScheduler(BaseScheduler): + """轮询调度器""" + + def __init__(self): + self._index = 0 + + def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]: + if not agents: + return None + + # 只选择在线的 Agent + online_agents = [a for a in agents if a.status == 'online'] + if not online_agents: + return None + + # 轮询选择 + agent = online_agents[self._index % len(online_agents)] + self._index += 1 + return agent + + +class WeightedRoundRobinScheduler(BaseScheduler): + """加权轮询调度器(默认)""" + + def __init__(self): + self._index = 0 + self._weights = {} + + def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]: + if not agents: + return None + + # 只选择在线的 Agent + online_agents = [a for a in agents if a.status == 'online'] + if not online_agents: + return None + + # 计算总权重 + total_weight = sum(a.weight for a in online_agents) + if total_weight == 0: + return online_agents[0] if online_agents else None + + # 加权轮询 + current_weight = self._index % total_weight + cumulative = 0 + + for agent in online_agents: + cumulative += agent.weight + if current_weight < cumulative: + self._index += 1 + return agent + + return online_agents[-1] + + +class LeastConnectionsScheduler(BaseScheduler): + """最少连接调度器""" + + def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]: + if not agents: + return None + + # 只选择在线且未达连接上限的 Agent + available_agents = [ + a for a in agents + if a.status == 'online' and a.current_sessions < a.connection_limit + ] + + if not available_agents: + return None + + # 选择连接数最少的 + return min(available_agents, key=lambda a: a.current_sessions) + + +class LeastResponseTimeScheduler(BaseScheduler): + """最快响应调度器(基于最后心跳时间)""" + + def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]: + if not agents: + return None + + # 只选择在线的 Agent + online_agents = [a for a in agents if a.status == 'online'] + if not online_agents: + return None + + # 选择最近有心跳的(响应快的) + from datetime import datetime + now = datetime.utcnow() + + def response_score(agent): + if not agent.last_heartbeat: + return float('inf') + # 最近心跳 = 分数低 = 优先选择 + return (now - agent.last_heartbeat).total_seconds() + + return min(online_agents, key=response_score) + + +class CapabilityMatchScheduler(BaseScheduler): + """能力匹配调度器""" + + def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]: + if not agents or not context: + return None + + required_capabilities = context.get('capabilities', []) + if not required_capabilities: + # 没有特殊要求,使用加权轮询 + return WeightedRoundRobinScheduler().select_agent(agents, context) + + # 只选择在线且有对应能力的 Agent + matching_agents = [] + for agent in agents: + if agent.status != 'online': + continue + + agent_caps = agent.capabilities or [] + # 检查是否具备所有必需能力 + if all(cap in agent_caps for cap in required_capabilities): + matching_agents.append(agent) + + if not matching_agents: + # 没有匹配的,回退到加权轮询 + return WeightedRoundRobinScheduler().select_agent(agents, context) + + # 在匹配的 Agent 中使用加权轮询 + return WeightedRoundRobinScheduler().select_agent(matching_agents, context) + + +class AgentScheduler: + """Agent 调度器工厂""" + + STRATEGIES = { + 'round_robin': RoundRobinScheduler, + 'weighted_round_robin': WeightedRoundRobinScheduler, + 'least_connections': LeastConnectionsScheduler, + 'least_response_time': LeastResponseTimeScheduler, + 'capability_match': CapabilityMatchScheduler, + } + + def __init__(self, strategy: str = 'weighted_round_robin'): + self._strategy_name = strategy + self._scheduler = self.STRATEGIES.get(strategy, WeightedRoundRobinScheduler)() + + def select_agent(self, agents: List[Agent], context: dict = None) -> Optional[Agent]: + """选择 Agent""" + return self._scheduler.select_agent(agents, context) + + def get_strategy(self) -> str: + """获取当前策略""" + return self._strategy_name + + @classmethod + def get_available_strategies(cls) -> list: + """获取可用策略列表""" + return list(cls.STRATEGIES.keys()) diff --git a/app/services/session_service.py b/app/services/session_service.py new file mode 100644 index 0000000..afa9a38 --- /dev/null +++ b/app/services/session_service.py @@ -0,0 +1,168 @@ +""" +会话服务 +处理会话相关的业务逻辑 +""" +from typing import Optional, List +from datetime import datetime +from app.models import db, Session, Agent +from app.services.scheduler import AgentScheduler + + +class SessionService: + """会话服务""" + + def __init__(self, scheduler_strategy: str = 'weighted_round_robin'): + self.scheduler = AgentScheduler(scheduler_strategy) + + def create_session( + self, + user_id: str, + title: str = None, + agent_id: str = None, + channel_type: str = 'web', + context: dict = None + ) -> Session: + """创建新会话""" + # 如果没有指定 Agent,调度一个 + if not agent_id: + from app.models import Agent + agents = Agent.query.all() + selected_agent = self.scheduler.select_agent(agents, context) + agent_id = selected_agent.id if selected_agent else None + + # 创建会话 + session = Session( + user_id=user_id, + primary_agent_id=agent_id, + title=title or f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}", + channel_type=channel_type, + status='active' + ) + + db.session.add(session) + + # 更新 Agent 会话数 + if agent_id: + agent = Agent.query.get(agent_id) + if agent: + agent.current_sessions += 1 + + db.session.commit() + + return session + + def close_session(self, session_id: str, user_id: str = None) -> Optional[Session]: + """关闭会话""" + session = Session.query.get(session_id) + if not session: + return None + + # 验证权限 + if user_id and session.user_id != user_id: + return None + + session.status = 'closed' + session.updated_at = datetime.utcnow() + + # 更新 Agent 会话数 + if session.primary_agent_id: + agent = Agent.query.get(session.primary_agent_id) + if agent and agent.current_sessions > 0: + agent.current_sessions -= 1 + + db.session.commit() + + return session + + def get_user_sessions(self, user_id: str, status: str = None) -> List[Session]: + """获取用户会话列表""" + query = Session.query.filter_by(user_id=user_id) + + if status: + query = query.filter_by(status=status) + + return query.order_by(Session.created_at.desc()).all() + + def get_session(self, session_id: str, user_id: str = None) -> Optional[Session]: + """获取会话详情""" + session = Session.query.get(session_id) + + if not session: + return None + + if user_id and session.user_id != user_id: + return None + + return session + + def update_session_activity(self, session_id: str) -> bool: + """更新会话活跃时间""" + session = Session.query.get(session_id) + if not session: + return False + + session.last_active_at = datetime.utcnow() + session.message_count += 1 + db.session.commit() + + return True + + def assign_agent(self, session_id: str, agent_id: str) -> Optional[Session]: + """分配 Agent 到会话""" + session = Session.query.get(session_id) + if not session: + return None + + # 验证 Agent 是否在线 + agent = Agent.query.get(agent_id) + if not agent or agent.status != 'online': + return None + + # 更新会话 + old_agent_id = session.primary_agent_id + session.primary_agent_id = agent_id + + # 更新新旧 Agent 的会话数 + if old_agent_id: + old_agent = Agent.query.get(old_agent_id) + if old_agent and old_agent.current_sessions > 0: + old_agent.current_sessions -= 1 + + agent.current_sessions += 1 + + db.session.commit() + + return session + + def add_participant(self, session_id: str, agent_id: str) -> Optional[Session]: + """添加参与 Agent""" + session = Session.query.get(session_id) + if not session: + return None + + # 获取当前参与列表 + participants = session.participating_agent_ids or [] + + if agent_id not in participants: + participants.append(agent_id) + session.participating_agent_ids = participants + + db.session.commit() + + return session + + def remove_participant(self, session_id: str, agent_id: str) -> Optional[Session]: + """移除参与 Agent""" + session = Session.query.get(session_id) + if not session: + return None + + participants = session.participating_agent_ids or [] + + if agent_id in participants: + participants.remove(agent_id) + session.participating_agent_ids = participants + + db.session.commit() + + return session