""" 聊天 WebSocket 事件处理器 实现智队机器人的聊天功能 """ from flask_socketio import emit, join_room, leave_room, rooms from flask import request from datetime import datetime from typing import Optional, Dict, Any from app.extensions import db, redis_client from app.models import User, Session, Message, Bot, Agent, Connection from app.services.bot_service import BotService from app.services.session_service import SessionService class ChatConnectionManager: """聊天连接管理器""" def __init__(self): # socket_id -> {user_id, current_session_id} self.connections: Dict[str, Dict[str, Any]] = {} def add_connection(self, socket_id: str, user_id: str): """添加连接""" self.connections[socket_id] = { 'user_id': user_id, 'current_session_id': None, 'connected_at': datetime.utcnow() } def set_current_session(self, socket_id: str, session_id: Optional[str]): """设置当前会话""" if socket_id in self.connections: self.connections[socket_id]['current_session_id'] = session_id def get_user_id(self, socket_id: str) -> Optional[str]: """获取用户 ID""" return self.connections.get(socket_id, {}).get('user_id') def get_current_session(self, socket_id: str) -> Optional[str]: """获取当前会话 ID""" return self.connections.get(socket_id, {}).get('current_session_id') def remove_connection(self, socket_id: str): """移除连接""" if socket_id in self.connections: del self.connections[socket_id] # 全局聊天连接管理器 chat_manager = ChatConnectionManager() def get_user_from_socket(socket_id: str) -> Optional[User]: """从 socket_id 获取用户""" user_id = chat_manager.get_user_id(socket_id) if user_id: return User.query.get(user_id) return None def emit_chat_error(message: str, code: str = 'CHAT_ERROR', session_id: Optional[str] = None): """发送聊天错误""" error_data = { 'code': code, 'message': message } if session_id: error_data['session_id'] = session_id emit('chat_error', error_data) def register_chat_handlers(socketio): """注册聊天事件处理器""" # ==================== 创建会话 ==================== @socketio.on('chat.send.create') def handle_chat_create(data): """ 创建聊天会话 参数: - bot_id: str - 机器人 ID(必填) - title: str - 会话标题(可选) """ sid = request.sid user = get_user_from_socket(sid) if not user: return emit_chat_error('User not authenticated', 'AUTH_REQUIRED') bot_id = data.get('bot_id') if not bot_id: return emit_chat_error('bot_id is required', 'MISSING_BOT_ID') # 获取 Bot bot = BotService.get_bot_by_id(bot_id) if not bot: return emit_chat_error('Bot not found', 'BOT_NOT_FOUND') # 检查权限 if not BotService.check_permission(user, bot, 'use'): return emit_chat_error('Permission denied', 'PERMISSION_DENIED') # 检查 Bot 是否绑定了 Agent if not bot.agent_id: return emit_chat_error('Bot has no agent bound', 'NO_AGENT_BOUND') # 检查 Agent 是否在线 agent = Agent.query.get(bot.agent_id) if not agent or agent.status != 'online': return emit_chat_error('Agent is offline', 'AGENT_OFFLINE') # 创建会话 title = data.get('title', f'Chat with {bot.display_name or bot.name}') session = Session( user_id=user.id, bot_id=bot.id, primary_agent_id=agent.id, title=title, channel_type='pit-bot', status='active', created_at=datetime.utcnow(), last_active_at=datetime.utcnow() ) db.session.add(session) db.session.commit() # 加入房间 join_room(session.id) chat_manager.set_current_session(sid, session.id) # 返回创建结果 emit('chat.created', { 'session_id': session.id, 'bot': bot.to_dict(), 'agent': { 'id': agent.id, 'name': agent.name, 'display_name': agent.display_name }, 'title': title, 'created_at': session.created_at.isoformat() }) # ==================== 加入会话 ==================== @socketio.on('chat.send.join') def handle_chat_join(data): """ 加入会话 参数: - session_id: str - 会话 ID """ sid = request.sid user = get_user_from_socket(sid) if not user: return emit_chat_error('User not authenticated', 'AUTH_REQUIRED') session_id = data.get('session_id') if not session_id: return emit_chat_error('session_id is required', 'MISSING_SESSION_ID') # 获取会话 session = Session.query.get(session_id) if not session: return emit_chat_error('Session not found', 'SESSION_NOT_FOUND', session_id) # 检查权限 if session.user_id != user.id: return emit_chat_error('Permission denied', 'PERMISSION_DENIED', session_id) # 获取 Bot 信息 bot = None if session.bot_id: bot = BotService.get_bot_by_id(session.bot_id) # 获取历史消息 messages = Message.query.filter_by(session_id=session_id)\ .order_by(Message.created_at.desc())\ .limit(50)\ .all() # 加入房间 join_room(session_id) chat_manager.set_current_session(sid, session_id) # 返回加入结果 emit('chat.joined', { 'session_id': session_id, 'bot': bot.to_dict() if bot else None, 'messages': [m.to_dict() for m in reversed(messages)], 'message_count': len(messages) }) # ==================== 离开会话 ==================== @socketio.on('chat.send.leave') def handle_chat_leave(data): """ 离开会话 参数: - session_id: str - 会话 ID """ sid = request.sid session_id = data.get('session_id') if not session_id: return emit_chat_error('session_id is required', 'MISSING_SESSION_ID') # 离开房间 leave_room(session_id) chat_manager.set_current_session(sid, None) emit('chat.left', {'session_id': session_id}) # ==================== 发送消息 ==================== @socketio.on('chat.send.message') def handle_chat_message(data): """ 发送消息 参数: - session_id: str - 会话 ID - content: str - 消息内容 - reply_to: str - 回复的消息 ID(可选) """ sid = request.sid user = get_user_from_socket(sid) if not user: return emit_chat_error('User not authenticated', 'AUTH_REQUIRED') session_id = data.get('session_id') content = data.get('content') reply_to = data.get('reply_to') if not session_id: return emit_chat_error('session_id is required', 'MISSING_SESSION_ID') if not content or not content.strip(): return emit_chat_error('content is required', 'MISSING_CONTENT', session_id) # 获取会话 session = Session.query.get(session_id) if not session: return emit_chat_error('Session not found', 'SESSION_NOT_FOUND', session_id) # 检查权限 if session.user_id != user.id: return emit_chat_error('Permission denied', 'PERMISSION_DENIED', session_id) # 获取 Bot 信息 bot = None sender_name = user.nickname or user.username if session.bot_id: bot = BotService.get_bot_by_id(session.bot_id) # 创建用户消息 message = Message( session_id=session_id, sender_type='user', sender_id=user.id, sender_name=sender_name, bot_id=session.bot_id, message_type='text', content=content.strip(), content_type='markdown', reply_to=reply_to, status='sent', ack_status='pending', created_at=datetime.utcnow() ) db.session.add(message) # 更新会话 session.message_count += 1 session.last_active_at = datetime.utcnow() session.updated_at = datetime.utcnow() db.session.commit() # 广播消息到房间(用户端) emit('chat.message', { 'message_id': message.id, 'session_id': session_id, 'sender_type': 'user', 'sender_id': user.id, 'sender_name': sender_name, 'bot_id': session.bot_id, 'content': content.strip(), 'content_type': 'markdown', 'reply_to': reply_to, 'timestamp': message.created_at.isoformat() }, room=session_id) # TODO: 转发消息给 Agent(通过 PIT Channel 协议) # 这里需要实现将消息转发给绑定的 Agent # 使用 session.primary_agent_id 获取 Agent # 然后通过 Agent 的 socket_id 发送消息 # ==================== 正在输入 ==================== @socketio.on('chat.send.typing') def handle_chat_typing(data): """ 正在输入状态 参数: - session_id: str - 会话 ID - is_typing: bool - 是否正在输入 """ sid = request.sid user = get_user_from_socket(sid) if not user: return session_id = data.get('session_id') is_typing = data.get('is_typing', False) if not session_id: return # 广播输入状态到房间(除了发送者) emit('chat.typing', { 'session_id': session_id, 'user_id': user.id, 'user_name': user.nickname or user.username, 'is_typing': is_typing }, room=session_id, include_self=False) # ==================== 消息已读 ==================== @socketio.on('chat.send.read') def handle_chat_read(data): """ 标记消息已读 参数: - session_id: str - 会话 ID - message_ids: list - 消息 ID 列表 """ sid = request.sid user = get_user_from_socket(sid) if not user: return emit_chat_error('User not authenticated', 'AUTH_REQUIRED') session_id = data.get('session_id') message_ids = data.get('message_ids', []) if not session_id: return emit_chat_error('session_id is required', 'MISSING_SESSION_ID') # 更新消息状态 for msg_id in message_ids: message = Message.query.get(msg_id) if message and message.session_id == session_id: message.status = 'read' # 更新会话未读数 session = Session.query.get(session_id) if session: session.unread_count = 0 db.session.commit() # 返回已读确认 emit('chat.read', { 'session_id': session_id, 'message_ids': message_ids }) # ==================== 关闭会话 ==================== def close_chat_session(session_id: str, reason: str = 'closed'): """关闭聊天会话(内部方法)""" session = Session.query.get(session_id) if session: session.status = 'closed' session.updated_at = datetime.utcnow() db.session.commit() # 通知房间内的所有用户 emit('chat.closed', { 'session_id': session_id, 'reason': reason }, room=session_id) # 导出注册函数 __all__ = ['register_chat_handlers', 'chat_manager']