diff --git a/app/socketio/chat_handlers.py b/app/socketio/chat_handlers.py new file mode 100644 index 0000000..59a4c41 --- /dev/null +++ b/app/socketio/chat_handlers.py @@ -0,0 +1,391 @@ +""" +聊天 WebSocket 事件处理器 +实现智队机器人的聊天功能 +""" +from flask_socketio import emit, join_room, leave_room, rooms +from flask import request +from datetime import datetime +from typing import Optional, Dict, Any + +from app.extensions import db, redis_client +from app.models import User, Session, Message, Bot, Agent, Connection +from app.services.bot_service import BotService +from app.services.session_service import SessionService + + +class ChatConnectionManager: + """聊天连接管理器""" + + def __init__(self): + # socket_id -> {user_id, current_session_id} + self.connections: Dict[str, Dict[str, Any]] = {} + + def add_connection(self, socket_id: str, user_id: str): + """添加连接""" + self.connections[socket_id] = { + 'user_id': user_id, + 'current_session_id': None, + 'connected_at': datetime.utcnow() + } + + def set_current_session(self, socket_id: str, session_id: Optional[str]): + """设置当前会话""" + if socket_id in self.connections: + self.connections[socket_id]['current_session_id'] = session_id + + def get_user_id(self, socket_id: str) -> Optional[str]: + """获取用户 ID""" + return self.connections.get(socket_id, {}).get('user_id') + + def get_current_session(self, socket_id: str) -> Optional[str]: + """获取当前会话 ID""" + return self.connections.get(socket_id, {}).get('current_session_id') + + def remove_connection(self, socket_id: str): + """移除连接""" + if socket_id in self.connections: + del self.connections[socket_id] + + +# 全局聊天连接管理器 +chat_manager = ChatConnectionManager() + + +def get_user_from_socket(socket_id: str) -> Optional[User]: + """从 socket_id 获取用户""" + user_id = chat_manager.get_user_id(socket_id) + if user_id: + return User.query.get(user_id) + return None + + +def emit_chat_error(message: str, code: str = 'CHAT_ERROR', session_id: Optional[str] = None): + """发送聊天错误""" + error_data = { + 'code': code, + 'message': message + } + if session_id: + error_data['session_id'] = session_id + emit('chat_error', error_data) + + +def register_chat_handlers(socketio): + """注册聊天事件处理器""" + + # ==================== 创建会话 ==================== + @socketio.on('chat.send.create') + def handle_chat_create(data): + """ + 创建聊天会话 + + 参数: + - bot_id: str - 机器人 ID(必填) + - title: str - 会话标题(可选) + """ + sid = request.sid + user = get_user_from_socket(sid) + + if not user: + return emit_chat_error('User not authenticated', 'AUTH_REQUIRED') + + bot_id = data.get('bot_id') + if not bot_id: + return emit_chat_error('bot_id is required', 'MISSING_BOT_ID') + + # 获取 Bot + bot = BotService.get_bot_by_id(bot_id) + if not bot: + return emit_chat_error('Bot not found', 'BOT_NOT_FOUND') + + # 检查权限 + if not BotService.check_permission(user, bot, 'use'): + return emit_chat_error('Permission denied', 'PERMISSION_DENIED') + + # 检查 Bot 是否绑定了 Agent + if not bot.agent_id: + return emit_chat_error('Bot has no agent bound', 'NO_AGENT_BOUND') + + # 检查 Agent 是否在线 + agent = Agent.query.get(bot.agent_id) + if not agent or agent.status != 'online': + return emit_chat_error('Agent is offline', 'AGENT_OFFLINE') + + # 创建会话 + title = data.get('title', f'Chat with {bot.display_name or bot.name}') + session = Session( + user_id=user.id, + bot_id=bot.id, + primary_agent_id=agent.id, + title=title, + channel_type='pit-bot', + status='active', + created_at=datetime.utcnow(), + last_active_at=datetime.utcnow() + ) + db.session.add(session) + db.session.commit() + + # 加入房间 + join_room(session.id) + chat_manager.set_current_session(sid, session.id) + + # 返回创建结果 + emit('chat.created', { + 'session_id': session.id, + 'bot': bot.to_dict(), + 'agent': { + 'id': agent.id, + 'name': agent.name, + 'display_name': agent.display_name + }, + 'title': title, + 'created_at': session.created_at.isoformat() + }) + + # ==================== 加入会话 ==================== + @socketio.on('chat.send.join') + def handle_chat_join(data): + """ + 加入会话 + + 参数: + - session_id: str - 会话 ID + """ + sid = request.sid + user = get_user_from_socket(sid) + + if not user: + return emit_chat_error('User not authenticated', 'AUTH_REQUIRED') + + session_id = data.get('session_id') + if not session_id: + return emit_chat_error('session_id is required', 'MISSING_SESSION_ID') + + # 获取会话 + session = Session.query.get(session_id) + if not session: + return emit_chat_error('Session not found', 'SESSION_NOT_FOUND', session_id) + + # 检查权限 + if session.user_id != user.id: + return emit_chat_error('Permission denied', 'PERMISSION_DENIED', session_id) + + # 获取 Bot 信息 + bot = None + if session.bot_id: + bot = BotService.get_bot_by_id(session.bot_id) + + # 获取历史消息 + messages = Message.query.filter_by(session_id=session_id)\ + .order_by(Message.created_at.desc())\ + .limit(50)\ + .all() + + # 加入房间 + join_room(session_id) + chat_manager.set_current_session(sid, session_id) + + # 返回加入结果 + emit('chat.joined', { + 'session_id': session_id, + 'bot': bot.to_dict() if bot else None, + 'messages': [m.to_dict() for m in reversed(messages)], + 'message_count': len(messages) + }) + + # ==================== 离开会话 ==================== + @socketio.on('chat.send.leave') + def handle_chat_leave(data): + """ + 离开会话 + + 参数: + - session_id: str - 会话 ID + """ + sid = request.sid + session_id = data.get('session_id') + + if not session_id: + return emit_chat_error('session_id is required', 'MISSING_SESSION_ID') + + # 离开房间 + leave_room(session_id) + chat_manager.set_current_session(sid, None) + + emit('chat.left', {'session_id': session_id}) + + # ==================== 发送消息 ==================== + @socketio.on('chat.send.message') + def handle_chat_message(data): + """ + 发送消息 + + 参数: + - session_id: str - 会话 ID + - content: str - 消息内容 + - reply_to: str - 回复的消息 ID(可选) + """ + sid = request.sid + user = get_user_from_socket(sid) + + if not user: + return emit_chat_error('User not authenticated', 'AUTH_REQUIRED') + + session_id = data.get('session_id') + content = data.get('content') + reply_to = data.get('reply_to') + + if not session_id: + return emit_chat_error('session_id is required', 'MISSING_SESSION_ID') + + if not content or not content.strip(): + return emit_chat_error('content is required', 'MISSING_CONTENT', session_id) + + # 获取会话 + session = Session.query.get(session_id) + if not session: + return emit_chat_error('Session not found', 'SESSION_NOT_FOUND', session_id) + + # 检查权限 + if session.user_id != user.id: + return emit_chat_error('Permission denied', 'PERMISSION_DENIED', session_id) + + # 获取 Bot 信息 + bot = None + sender_name = user.nickname or user.username + if session.bot_id: + bot = BotService.get_bot_by_id(session.bot_id) + + # 创建用户消息 + message = Message( + session_id=session_id, + sender_type='user', + sender_id=user.id, + sender_name=sender_name, + bot_id=session.bot_id, + message_type='text', + content=content.strip(), + content_type='markdown', + reply_to=reply_to, + status='sent', + ack_status='pending', + created_at=datetime.utcnow() + ) + db.session.add(message) + + # 更新会话 + session.message_count += 1 + session.last_active_at = datetime.utcnow() + session.updated_at = datetime.utcnow() + + db.session.commit() + + # 广播消息到房间(用户端) + emit('chat.message', { + 'message_id': message.id, + 'session_id': session_id, + 'sender_type': 'user', + 'sender_id': user.id, + 'sender_name': sender_name, + 'bot_id': session.bot_id, + 'content': content.strip(), + 'content_type': 'markdown', + 'reply_to': reply_to, + 'timestamp': message.created_at.isoformat() + }, room=session_id) + + # TODO: 转发消息给 Agent(通过 PIT Channel 协议) + # 这里需要实现将消息转发给绑定的 Agent + # 使用 session.primary_agent_id 获取 Agent + # 然后通过 Agent 的 socket_id 发送消息 + + # ==================== 正在输入 ==================== + @socketio.on('chat.send.typing') + def handle_chat_typing(data): + """ + 正在输入状态 + + 参数: + - session_id: str - 会话 ID + - is_typing: bool - 是否正在输入 + """ + sid = request.sid + user = get_user_from_socket(sid) + + if not user: + return + + session_id = data.get('session_id') + is_typing = data.get('is_typing', False) + + if not session_id: + return + + # 广播输入状态到房间(除了发送者) + emit('chat.typing', { + 'session_id': session_id, + 'user_id': user.id, + 'user_name': user.nickname or user.username, + 'is_typing': is_typing + }, room=session_id, include_self=False) + + # ==================== 消息已读 ==================== + @socketio.on('chat.send.read') + def handle_chat_read(data): + """ + 标记消息已读 + + 参数: + - session_id: str - 会话 ID + - message_ids: list - 消息 ID 列表 + """ + sid = request.sid + user = get_user_from_socket(sid) + + if not user: + return emit_chat_error('User not authenticated', 'AUTH_REQUIRED') + + session_id = data.get('session_id') + message_ids = data.get('message_ids', []) + + if not session_id: + return emit_chat_error('session_id is required', 'MISSING_SESSION_ID') + + # 更新消息状态 + for msg_id in message_ids: + message = Message.query.get(msg_id) + if message and message.session_id == session_id: + message.status = 'read' + + # 更新会话未读数 + session = Session.query.get(session_id) + if session: + session.unread_count = 0 + + db.session.commit() + + # 返回已读确认 + emit('chat.read', { + 'session_id': session_id, + 'message_ids': message_ids + }) + + # ==================== 关闭会话 ==================== + def close_chat_session(session_id: str, reason: str = 'closed'): + """关闭聊天会话(内部方法)""" + session = Session.query.get(session_id) + if session: + session.status = 'closed' + session.updated_at = datetime.utcnow() + db.session.commit() + + # 通知房间内的所有用户 + emit('chat.closed', { + 'session_id': session_id, + 'reason': reason + }, room=session_id) + + +# 导出注册函数 +__all__ = ['register_chat_handlers', 'chat_manager'] diff --git a/app/socketio/handlers.py b/app/socketio/handlers.py index 29db808..2b11a51 100644 --- a/app/socketio/handlers.py +++ b/app/socketio/handlers.py @@ -20,6 +20,7 @@ connection_manager = ConnectionManager() def register_handlers(socketio): """注册 Socket.IO 事件处理器""" + # ==================== 连接事件 ==================== @socketio.on('connect') def handle_connect(): """客户端连接""" @@ -38,10 +39,13 @@ def register_handlers(socketio): if sid in connection_manager.socket_sessions: del connection_manager.socket_sessions[sid] + # ==================== 认证事件 ==================== @socketio.on('auth') def handle_auth(data): """处理认证""" from flask import request + from app.models import User + sid = request.sid token = data.get('token') @@ -71,11 +75,13 @@ def register_handlers(socketio): except Exception as e: emit('auth_error', {'code': 'INVALID_TOKEN', 'message': str(e)}) + # ==================== 心跳事件 ==================== @socketio.on('ping') def handle_ping(data): """心跳响应""" emit('pong', {'timestamp': datetime.utcnow().timestamp()}) + # ==================== 会话事件 ==================== @socketio.on('session.create') def handle_session_create(data): """创建会话""" @@ -124,6 +130,7 @@ def register_handlers(socketio): 'participants': [session.user_id, session.primary_agent_id] }) + # ==================== 消息事件 ==================== @socketio.on('message.send') def handle_message_send(data): """发送消息""" @@ -180,3 +187,7 @@ def register_handlers(socketio): db.session.commit() emit('message.acked', {'message_id': message_id, 'status': status}) + + # ==================== 聊天事件 (Step 4) ==================== + from app.socketio.chat_handlers import register_chat_handlers + register_chat_handlers(socketio)