Files
pit-router/app/services/session_service.py

169 lines
5.0 KiB
Python
Raw Normal View History

"""
会话服务
处理会话相关的业务逻辑
"""
from typing import Optional, List
from datetime import datetime
from app.models import db, Session, Agent
from app.services.scheduler import AgentScheduler
class SessionService:
    """Session service: business logic for creating, querying and closing sessions.

    Besides the ``Session`` rows themselves, the service maintains the
    ``Agent.current_sessions`` counter of each session's primary agent:
    it is incremented when an active session is bound to an agent and
    decremented (never below zero) when the session is closed or moved to
    another agent.
    """

    def __init__(self, scheduler_strategy: str = 'weighted_round_robin'):
        """Create the service.

        Args:
            scheduler_strategy: Strategy name handed to ``AgentScheduler``;
                the scheduler picks an agent in ``create_session`` when the
                caller does not supply one.
        """
        self.scheduler = AgentScheduler(scheduler_strategy)

    # -- internal helpers -------------------------------------------------

    @staticmethod
    def _copy_with(items: Optional[List[str]], item: str) -> List[str]:
        """Return a NEW list holding *items* (None = empty) plus *item* if absent."""
        result = list(items or [])
        if item not in result:
            result.append(item)
        return result

    @staticmethod
    def _copy_without(items: Optional[List[str]], item: str) -> List[str]:
        """Return a NEW list holding *items* (None = empty) minus the first *item*."""
        result = list(items or [])
        if item in result:
            result.remove(item)
        return result

    @staticmethod
    def _release_agent_slot(agent_id: Optional[str]) -> None:
        """Decrement ``current_sessions`` of agent *agent_id*, never below zero.

        Does nothing when *agent_id* is falsy or the agent does not exist.
        The caller is responsible for committing.
        """
        if not agent_id:
            return
        agent = Agent.query.get(agent_id)
        if agent and agent.current_sessions > 0:
            agent.current_sessions -= 1

    # -- public API -------------------------------------------------------

    def create_session(
        self,
        user_id: str,
        title: Optional[str] = None,
        agent_id: Optional[str] = None,
        channel_type: str = 'web',
        context: Optional[dict] = None,
    ) -> Session:
        """Create and commit a new ``'active'`` session for *user_id*.

        Args:
            user_id: Owner of the session.
            title: Display title; defaults to ``"Session YYYY-MM-DD HH:MM"``
                built from the current UTC time.
            agent_id: Primary agent. When omitted, the scheduler selects one
                among agents whose status is ``'online'``; if it selects none,
                the session is created without a primary agent.
            channel_type: Channel the session was opened on.
            context: Passed to ``scheduler.select_agent`` only; it is not
                stored on the session.

        Returns:
            The committed ``Session``.
        """
        if agent_id:
            agent = Agent.query.get(agent_id)
        else:
            # Only online agents are eligible, matching the check that
            # assign_agent applies to explicit assignments.
            candidates = Agent.query.filter_by(status='online').all()
            agent = self.scheduler.select_agent(candidates, context)
            agent_id = agent.id if agent else None

        session = Session(
            user_id=user_id,
            primary_agent_id=agent_id,
            title=title or f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
            channel_type=channel_type,
            status='active',
        )
        db.session.add(session)

        if agent:
            agent.current_sessions += 1

        db.session.commit()
        return session

    def close_session(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]:
        """Mark a session ``'closed'`` and release its primary agent's slot.

        Args:
            session_id: Session to close.
            user_id: When given, the session must belong to this user.

        Returns:
            The session, or ``None`` if it does not exist or belongs to a
            different user. Closing an already-closed session is a no-op
            that returns the session unchanged.
        """
        session = self.get_session(session_id, user_id)
        if session is None:
            return None

        if session.status == 'closed':
            # Without this guard a repeated close would decrement the agent's
            # counter a second time.
            return session

        session.status = 'closed'
        session.updated_at = datetime.utcnow()
        self._release_agent_slot(session.primary_agent_id)

        db.session.commit()
        return session

    def get_user_sessions(self, user_id: str, status: Optional[str] = None) -> List[Session]:
        """Return *user_id*'s sessions, newest first, optionally filtered by *status*."""
        query = Session.query.filter_by(user_id=user_id)
        if status:
            query = query.filter_by(status=status)
        return query.order_by(Session.created_at.desc()).all()

    def get_session(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]:
        """Return the session, or ``None`` if missing or not owned by *user_id* (when given)."""
        session = Session.query.get(session_id)
        if not session:
            return None
        if user_id and session.user_id != user_id:
            return None
        return session

    def update_session_activity(self, session_id: str) -> bool:
        """Touch ``last_active_at`` and bump ``message_count``; ``False`` if the session is missing."""
        session = Session.query.get(session_id)
        if not session:
            return False

        session.last_active_at = datetime.utcnow()
        # Treat a NULL count (e.g. a legacy row) as zero instead of raising TypeError.
        session.message_count = (session.message_count or 0) + 1

        db.session.commit()
        return True

    def assign_agent(self, session_id: str, agent_id: str) -> Optional[Session]:
        """Make *agent_id* the session's primary agent.

        Returns ``None`` if the session does not exist or the agent is missing
        or not ``'online'``. Session counters move from the old agent to the
        new one only when the agent actually changes and the session is not
        closed.
        """
        session = Session.query.get(session_id)
        if not session:
            return None

        agent = Agent.query.get(agent_id)
        if not agent or agent.status != 'online':
            return None

        old_agent_id = session.primary_agent_id
        session.primary_agent_id = agent_id

        # A closed session already released its slot in close_session, and
        # re-assigning to the same agent must not churn the counter.
        if old_agent_id != agent_id and session.status != 'closed':
            self._release_agent_slot(old_agent_id)
            agent.current_sessions += 1

        db.session.commit()
        return session

    def add_participant(self, session_id: str, agent_id: str) -> Optional[Session]:
        """Add *agent_id* to the session's participating agents (no duplicates).

        Returns the session, or ``None`` if it does not exist.
        """
        session = Session.query.get(session_id)
        if not session:
            return None

        current = session.participating_agent_ids or []
        if agent_id not in current:
            # Assign a fresh list: mutating the loaded list in place and
            # assigning the same object back is not seen as a change by
            # SQLAlchemy without mutation tracking, so the UPDATE could be lost.
            session.participating_agent_ids = self._copy_with(current, agent_id)

        db.session.commit()
        return session

    def remove_participant(self, session_id: str, agent_id: str) -> Optional[Session]:
        """Remove *agent_id* from the session's participating agents.

        Returns the session, or ``None`` if it does not exist.
        """
        session = Session.query.get(session_id)
        if not session:
            return None

        current = session.participating_agent_ids or []
        if agent_id in current:
            # Fresh list for the same change-detection reason as add_participant.
            session.participating_agent_ids = self._copy_without(current, agent_id)

        db.session.commit()
        return session