169 lines
5.0 KiB
Python
169 lines
5.0 KiB
Python
|
|
"""
会话服务 (Session service)

处理会话相关的业务逻辑 — business logic for chat sessions.
"""
|
|||
|
|
from typing import Optional, List
|
|||
|
|
from datetime import datetime
|
|||
|
|
from app.models import db, Session, Agent
|
|||
|
|
from app.services.scheduler import AgentScheduler
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SessionService:
    """Session service: business logic for creating, closing and updating sessions.

    Besides persisting ``Session`` rows, the service keeps each ``Agent``'s
    ``current_sessions`` counter in step with the sessions assigned to it.

    * The counter is incremented when a session is created for an agent, or
      moved to it.
    * It is decremented (never below zero) when a session is closed, or
      moved away from the agent.
    * Closed sessions no longer count toward any agent.
    """

    def __init__(self, scheduler_strategy: str = 'weighted_round_robin'):
        """Create the service.

        Args:
            scheduler_strategy: Strategy name passed to ``AgentScheduler``.
                The scheduler is used by ``create_session`` to pick an agent
                when the caller does not supply one.
        """
        self.scheduler = AgentScheduler(scheduler_strategy)

    # ------------------------------------------------------------------
    # Agent load bookkeeping
    # ------------------------------------------------------------------

    @staticmethod
    def _increment_agent_sessions(agent_id: Optional[str]) -> None:
        """Add one to the agent's ``current_sessions`` if such an agent exists."""
        if not agent_id:
            return
        agent = Agent.query.get(agent_id)
        if agent:
            agent.current_sessions += 1

    @staticmethod
    def _decrement_agent_sessions(agent_id: Optional[str]) -> None:
        """Subtract one from the agent's ``current_sessions``, never going below zero."""
        if not agent_id:
            return
        agent = Agent.query.get(agent_id)
        if agent and agent.current_sessions > 0:
            agent.current_sessions -= 1

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_session(
        self,
        user_id: str,
        title: Optional[str] = None,
        agent_id: Optional[str] = None,
        channel_type: str = 'web',
        context: Optional[dict] = None
    ) -> Session:
        """Create, persist and return a new ``'active'`` session.

        Args:
            user_id: Owner of the session.
            title: Display title. When empty or ``None``, a default of the form
                ``"Session YYYY-MM-DD HH:MM"`` (UTC) is used.
            agent_id: Primary agent. When omitted, the scheduler selects one
                from all agents. If the scheduler returns nothing, the session
                has no primary agent.
            channel_type: Channel the session was opened from.
            context: Extra data handed to the scheduler when selecting an
                agent. It is not stored on the session.

        Returns:
            The committed ``Session``.
        """
        # No agent requested: let the scheduler choose one.
        if not agent_id:
            agents = Agent.query.all()
            selected_agent = self.scheduler.select_agent(agents, context)
            agent_id = selected_agent.id if selected_agent else None

        # NOTE(review): an explicitly passed agent_id is not validated. If no
        # such agent exists, the id is still stored but no counter changes.
        session = Session(
            user_id=user_id,
            primary_agent_id=agent_id,
            title=title or f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
            channel_type=channel_type,
            status='active'
        )
        db.session.add(session)

        # The new session counts toward its agent's load.
        self._increment_agent_sessions(agent_id)

        db.session.commit()
        return session

    def close_session(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]:
        """Mark a session as closed and release its agent's slot.

        Closing is idempotent. Closing an already-closed session returns it
        unchanged, so the agent's counter is not decremented a second time.

        Args:
            session_id: Id of the session to close.
            user_id: When given, the session must belong to this user.

        Returns:
            The session, or ``None`` if it does not exist or belongs to
            another user.
        """
        session = self.get_session(session_id, user_id)
        if not session:
            return None

        # Already closed: its agent slot was released on the first close.
        if session.status == 'closed':
            return session

        session.status = 'closed'
        session.updated_at = datetime.utcnow()
        self._decrement_agent_sessions(session.primary_agent_id)

        db.session.commit()
        return session

    def get_user_sessions(self, user_id: str, status: Optional[str] = None) -> List[Session]:
        """Return a user's sessions, newest first (by ``created_at``).

        Args:
            user_id: Owner whose sessions are listed.
            status: When given, only sessions with this status are returned.
        """
        query = Session.query.filter_by(user_id=user_id)
        if status:
            query = query.filter_by(status=status)
        return query.order_by(Session.created_at.desc()).all()

    def get_session(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]:
        """Look up a session by id.

        Args:
            session_id: Id of the session.
            user_id: When given, the session must belong to this user.

        Returns:
            The session, or ``None`` if it does not exist or belongs to
            another user.
        """
        session = Session.query.get(session_id)
        if not session:
            return None
        if user_id and session.user_id != user_id:
            return None
        return session

    def update_session_activity(self, session_id: str) -> bool:
        """Record activity on a session.

        Sets ``last_active_at`` to the current UTC time and increments
        ``message_count``.

        Returns:
            ``True`` if the session was found and updated, ``False`` otherwise.
        """
        # NOTE(review): the session status is not checked, so closed sessions
        # are updated too.
        session = Session.query.get(session_id)
        if not session:
            return False

        session.last_active_at = datetime.utcnow()
        session.message_count += 1
        db.session.commit()
        return True

    def assign_agent(self, session_id: str, agent_id: str) -> Optional[Session]:
        """Make ``agent_id`` the session's primary agent.

        Agent counters are adjusted only when both of these hold:

        * the agent actually changes, and
        * the session is not closed (closed sessions were already released
          by ``close_session``).

        In that case the previous agent's counter is decremented and the new
        agent's counter is incremented.

        Returns:
            The updated session, or ``None`` if the session does not exist,
            or the agent does not exist or is not ``'online'``.
        """
        session = Session.query.get(session_id)
        if not session:
            return None

        # Only online agents can take over a session.
        agent = Agent.query.get(agent_id)
        if not agent or agent.status != 'online':
            return None

        old_agent_id = session.primary_agent_id
        session.primary_agent_id = agent_id

        # Move the load from the old agent to the new one. Skip this when the
        # agent is unchanged, which avoids +1 drift if the old counter was
        # already 0. Also skip it for closed sessions, which no longer count.
        if old_agent_id != agent_id and session.status != 'closed':
            self._decrement_agent_sessions(old_agent_id)
            agent.current_sessions += 1

        db.session.commit()
        return session

    def add_participant(self, session_id: str, agent_id: str) -> Optional[Session]:
        """Add an agent to the session's participating agents, without duplicates.

        Returns:
            The session, or ``None`` if it does not exist.
        """
        session = Session.query.get(session_id)
        if not session:
            return None

        participants = session.participating_agent_ids or []
        if agent_id not in participants:
            # Assign a *new* list. Unless the column uses mutation tracking
            # (sqlalchemy.ext.mutable), SQLAlchemy compares the assigned value
            # with the previously loaded one. Mutating that list in place and
            # assigning the same object back therefore looks unchanged, and no
            # UPDATE is issued.
            session.participating_agent_ids = [*participants, agent_id]

        db.session.commit()
        return session

    def remove_participant(self, session_id: str, agent_id: str) -> Optional[Session]:
        """Remove an agent from the session's participating agents.

        Only the first occurrence is removed; a missing agent is ignored.

        Returns:
            The session, or ``None`` if it does not exist.
        """
        session = Session.query.get(session_id)
        if not session:
            return None

        participants = session.participating_agent_ids or []
        if agent_id in participants:
            # Work on a copy and assign it, for the same change-detection
            # reason as in add_participant. In-place removal on the loaded
            # list would not be persisted.
            updated = list(participants)
            updated.remove(agent_id)
            session.participating_agent_ids = updated

        db.session.commit()
        return session
|