- 创建 chat_handlers.py 聊天事件处理器 - 实现 6 个 C→S 事件: - chat.send.create - 创建聊天会话 - chat.send.join - 加入会话 - chat.send.leave - 离开会话 - chat.send.message - 发送消息 - chat.send.typing - 正在输入 - chat.send.read - 消息已读 - 实现 7 个 S→C 事件: - chat.created - 会话已创建 - chat.joined - 已加入会话 - chat.left - 已离开会话 - chat.message - 收到消息 - chat.typing - 对方正在输入 - chat.read - 消息已读确认 - chat.closed - 会话被关闭 - 创建 ChatConnectionManager 管理连接 - 注册聊天事件处理器 - 更新版本号到 0.9.3
392 lines
12 KiB
Python
392 lines
12 KiB
Python
"""
|
||
聊天 WebSocket 事件处理器
|
||
实现智队机器人的聊天功能
|
||
"""
|
||
from flask_socketio import emit, join_room, leave_room, rooms
|
||
from flask import request
|
||
from datetime import datetime
|
||
from typing import Optional, Dict, Any
|
||
|
||
from app.extensions import db, redis_client
|
||
from app.models import User, Session, Message, Bot, Agent, Connection
|
||
from app.services.bot_service import BotService
|
||
from app.services.session_service import SessionService
|
||
|
||
|
||
class ChatConnectionManager:
    """In-memory registry of chat socket connections.

    Maps each Socket.IO ``socket_id`` to a small record holding the owning
    ``user_id``, the session the socket currently has open
    (``current_session_id``) and the time it was registered
    (``connected_at``, a naive UTC ``datetime`` from ``datetime.utcnow``).
    """

    def __init__(self):
        # socket_id -> {'user_id', 'current_session_id', 'connected_at'}
        self.connections: Dict[str, Dict[str, Any]] = {}

    def add_connection(self, socket_id: str, user_id: str):
        """Register ``socket_id`` for ``user_id``, replacing any existing record.

        The new record starts with no current session.
        """
        record: Dict[str, Any] = dict(
            user_id=user_id,
            current_session_id=None,
            connected_at=datetime.utcnow(),
        )
        self.connections[socket_id] = record

    def set_current_session(self, socket_id: str, session_id: Optional[str]):
        """Record ``session_id`` (or ``None``) as the socket's open session.

        Sockets that were never registered are silently ignored.
        """
        try:
            record = self.connections[socket_id]
        except KeyError:
            return
        record['current_session_id'] = session_id

    def get_user_id(self, socket_id: str) -> Optional[str]:
        """Return the user bound to ``socket_id``, or ``None`` if unregistered."""
        try:
            record = self.connections[socket_id]
        except KeyError:
            return None
        return record.get('user_id')

    def get_current_session(self, socket_id: str) -> Optional[str]:
        """Return the socket's current session ID, or ``None`` if unset/unregistered."""
        if socket_id not in self.connections:
            return None
        return self.connections[socket_id].get('current_session_id')

    def remove_connection(self, socket_id: str):
        """Forget ``socket_id``; a no-op when it is not registered."""
        self.connections.pop(socket_id, None)
|
||
|
||
|
||
# Module-level (per-process, in-memory) registry of chat connections, used by
# get_user_from_socket and by the chat event handlers below.
chat_manager: ChatConnectionManager = ChatConnectionManager()
|
||
|
||
|
||
def get_user_from_socket(socket_id: str) -> Optional[User]:
    """Look up the ``User`` registered for ``socket_id``.

    Returns ``None`` when ``chat_manager`` has no (truthy) user id for the
    socket; otherwise returns the result of ``User.query.get`` for that id,
    which is itself ``None`` if no such row exists.
    """
    user_id = chat_manager.get_user_id(socket_id)
    if not user_id:
        return None
    return User.query.get(user_id)
|
||
|
||
|
||
def emit_chat_error(message: str, code: str = 'CHAT_ERROR', session_id: Optional[str] = None):
    """Emit a ``chat_error`` event via the context-aware ``emit``.

    The payload always carries ``code`` and ``message``; ``session_id`` is
    appended only when a truthy value is passed. Returns ``None``, so handlers
    can write ``return emit_chat_error(...)`` to bail out early.
    """
    extra = {'session_id': session_id} if session_id else {}
    emit('chat_error', {'code': code, 'message': message, **extra})
|
||
|
||
|
||
# Number of most recent messages replayed to a client on chat.send.join.
_JOIN_HISTORY_LIMIT = 50


def _coerce_payload(data: Any) -> Dict[str, Any]:
    """Return ``data`` when it is a dict, otherwise an empty dict.

    A client may emit an event with no payload (``None``) or with a non-object
    payload; normalising those to ``{}`` makes the handlers report their usual
    "... is required" error instead of raising ``AttributeError`` on ``.get``.
    """
    return data if isinstance(data, dict) else {}


def _load_owned_session(user, session_id):
    """Fetch session ``session_id`` and check that ``user`` owns it.

    Returns the ``Session`` on success. When the session does not exist or
    belongs to another user, emits the matching ``chat_error``
    (``SESSION_NOT_FOUND`` / ``PERMISSION_DENIED``) and returns ``None``.
    """
    session = Session.query.get(session_id)
    if not session:
        emit_chat_error('Session not found', 'SESSION_NOT_FOUND', session_id)
        return None
    if session.user_id != user.id:
        emit_chat_error('Permission denied', 'PERMISSION_DENIED', session_id)
        return None
    return session


def register_chat_handlers(socketio):
    """Register the chat Socket.IO event handlers on ``socketio``.

    Client -> server events handled: chat.send.create, chat.send.join,
    chat.send.leave, chat.send.message, chat.send.typing, chat.send.read.
    Every handler returns ``None`` (Flask-SocketIO would otherwise send a
    handler's return value to the client as an acknowledgement); failures are
    reported to the sender as a ``chat_error`` event.
    """

    # ==================== Create session ====================
    @socketio.on('chat.send.create')
    def handle_chat_create(data):
        """Create a chat session with a bot and join its room.

        Payload:
            - bot_id: str - bot ID (required)
            - title: str - session title (optional)

        Emits ``chat.created`` on success, otherwise ``chat_error``.
        """
        data = _coerce_payload(data)
        sid = request.sid
        user = get_user_from_socket(sid)
        if not user:
            return emit_chat_error('User not authenticated', 'AUTH_REQUIRED')

        bot_id = data.get('bot_id')
        if not bot_id:
            return emit_chat_error('bot_id is required', 'MISSING_BOT_ID')

        bot = BotService.get_bot_by_id(bot_id)
        if not bot:
            return emit_chat_error('Bot not found', 'BOT_NOT_FOUND')

        if not BotService.check_permission(user, bot, 'use'):
            return emit_chat_error('Permission denied', 'PERMISSION_DENIED')

        # The bot must be bound to an agent, and that agent must be online.
        if not bot.agent_id:
            return emit_chat_error('Bot has no agent bound', 'NO_AGENT_BOUND')

        agent = Agent.query.get(bot.agent_id)
        if not agent or agent.status != 'online':
            return emit_chat_error('Agent is offline', 'AGENT_OFFLINE')

        # `or` rather than a .get() default so an explicit empty/None title
        # also falls back to the generated one.
        title = data.get('title') or f'Chat with {bot.display_name or bot.name}'
        now = datetime.utcnow()
        session = Session(
            user_id=user.id,
            bot_id=bot.id,
            primary_agent_id=agent.id,
            title=title,
            channel_type='pit-bot',
            status='active',
            created_at=now,
            last_active_at=now,
        )
        db.session.add(session)
        db.session.commit()

        join_room(session.id)
        chat_manager.set_current_session(sid, session.id)

        emit('chat.created', {
            'session_id': session.id,
            'bot': bot.to_dict(),
            'agent': {
                'id': agent.id,
                'name': agent.name,
                'display_name': agent.display_name,
            },
            'title': title,
            'created_at': session.created_at.isoformat(),
        })

    # ==================== Join session ====================
    @socketio.on('chat.send.join')
    def handle_chat_join(data):
        """Join an existing session owned by the caller and replay recent history.

        Payload:
            - session_id: str - session ID (required)

        Emits ``chat.joined`` with up to ``_JOIN_HISTORY_LIMIT`` of the newest
        messages in chronological order, otherwise ``chat_error``.
        """
        data = _coerce_payload(data)
        sid = request.sid
        user = get_user_from_socket(sid)
        if not user:
            return emit_chat_error('User not authenticated', 'AUTH_REQUIRED')

        session_id = data.get('session_id')
        if not session_id:
            return emit_chat_error('session_id is required', 'MISSING_SESSION_ID')

        session = _load_owned_session(user, session_id)
        if session is None:
            return None

        bot = BotService.get_bot_by_id(session.bot_id) if session.bot_id else None

        # Fetch the newest N (descending); reversed below into oldest-first order.
        messages = (
            Message.query.filter_by(session_id=session_id)
            .order_by(Message.created_at.desc())
            .limit(_JOIN_HISTORY_LIMIT)
            .all()
        )

        join_room(session_id)
        chat_manager.set_current_session(sid, session_id)

        emit('chat.joined', {
            'session_id': session_id,
            'bot': bot.to_dict() if bot else None,
            'messages': [m.to_dict() for m in reversed(messages)],
            'message_count': len(messages),
        })

    # ==================== Leave session ====================
    @socketio.on('chat.send.leave')
    def handle_chat_leave(data):
        """Leave a session's room.

        Payload:
            - session_id: str - session ID (required)

        Emits ``chat.left`` on success, otherwise ``chat_error``.
        """
        data = _coerce_payload(data)
        sid = request.sid
        session_id = data.get('session_id')
        if not session_id:
            return emit_chat_error('session_id is required', 'MISSING_SESSION_ID')

        leave_room(session_id)
        # Only forget the tracked session if it is the one being left; leaving
        # some other room must not clear the session the client still has open.
        if chat_manager.get_current_session(sid) == session_id:
            chat_manager.set_current_session(sid, None)

        emit('chat.left', {'session_id': session_id})

    # ==================== Send message ====================
    @socketio.on('chat.send.message')
    def handle_chat_message(data):
        """Persist a user message and broadcast it to the session room.

        Payload:
            - session_id: str - session ID (required)
            - content: str - message text (required, must be non-blank)
            - reply_to: str - ID of the message being replied to (optional)

        Emits ``chat.message`` to the room on success, otherwise ``chat_error``
        to the sender.
        """
        data = _coerce_payload(data)
        sid = request.sid
        user = get_user_from_socket(sid)
        if not user:
            return emit_chat_error('User not authenticated', 'AUTH_REQUIRED')

        session_id = data.get('session_id')
        content = data.get('content')
        reply_to = data.get('reply_to')

        if not session_id:
            return emit_chat_error('session_id is required', 'MISSING_SESSION_ID')

        # Non-string content would crash on .strip(); treat it as missing.
        text = content.strip() if isinstance(content, str) else ''
        if not text:
            return emit_chat_error('content is required', 'MISSING_CONTENT', session_id)

        session = _load_owned_session(user, session_id)
        if session is None:
            return None

        sender_name = user.nickname or user.username
        now = datetime.utcnow()
        message = Message(
            session_id=session_id,
            sender_type='user',
            sender_id=user.id,
            sender_name=sender_name,
            bot_id=session.bot_id,
            message_type='text',
            content=text,
            content_type='markdown',
            reply_to=reply_to,
            status='sent',
            ack_status='pending',
            created_at=now,
        )
        db.session.add(message)

        # Guard against a NULL counter (assumption: the column may be nullable).
        session.message_count = (session.message_count or 0) + 1
        session.last_active_at = now
        session.updated_at = now
        db.session.commit()

        emit('chat.message', {
            'message_id': message.id,
            'session_id': session_id,
            'sender_type': 'user',
            'sender_id': user.id,
            'sender_name': sender_name,
            'bot_id': session.bot_id,
            'content': text,
            'content_type': 'markdown',
            'reply_to': reply_to,
            'timestamp': message.created_at.isoformat(),
        }, room=session_id)

        # TODO: forward the message to the bound Agent (via the PIT Channel
        # protocol): resolve the Agent from session.primary_agent_id and send
        # it through that Agent's socket_id.

    # ==================== Typing indicator ====================
    @socketio.on('chat.send.typing')
    def handle_chat_typing(data):
        """Relay the caller's typing state to the other members of a session room.

        Payload:
            - session_id: str - session ID (required)
            - is_typing: bool - whether the user is typing (default False)

        Invalid requests are dropped silently (no error event).
        """
        data = _coerce_payload(data)
        sid = request.sid
        user = get_user_from_socket(sid)
        if not user:
            return None

        session_id = data.get('session_id')
        is_typing = data.get('is_typing', False)
        if not session_id:
            return None

        # Only relay for rooms this socket actually joined (create/join check
        # ownership), so a client cannot inject typing events into other
        # users' sessions.
        if session_id not in rooms():
            return None

        emit('chat.typing', {
            'session_id': session_id,
            'user_id': user.id,
            'user_name': user.nickname or user.username,
            'is_typing': is_typing,
        }, room=session_id, include_self=False)

    # ==================== Mark as read ====================
    @socketio.on('chat.send.read')
    def handle_chat_read(data):
        """Mark messages in a session owned by the caller as read.

        Payload:
            - session_id: str - session ID (required)
            - message_ids: list - IDs of the messages to mark read

        Only messages belonging to ``session_id`` are updated, and the
        session's ``unread_count`` is reset to 0. Emits ``chat.read`` echoing
        the requested ``message_ids``, otherwise ``chat_error``.
        """
        data = _coerce_payload(data)
        sid = request.sid
        user = get_user_from_socket(sid)
        if not user:
            return emit_chat_error('User not authenticated', 'AUTH_REQUIRED')

        session_id = data.get('session_id')
        message_ids = data.get('message_ids') or []
        if not session_id:
            return emit_chat_error('session_id is required', 'MISSING_SESSION_ID')

        # A bare string here would otherwise be iterated character by character.
        if not isinstance(message_ids, list):
            return emit_chat_error('message_ids must be a list', 'INVALID_MESSAGE_IDS', session_id)

        # Without this ownership check any authenticated user could mark
        # messages read and reset unread_count in someone else's session.
        session = _load_owned_session(user, session_id)
        if session is None:
            return None

        if message_ids:
            # One query for the whole batch instead of a lookup per id; the
            # session filter leaves messages of other sessions untouched.
            matched = Message.query.filter(
                Message.session_id == session_id,
                Message.id.in_(message_ids),
            ).all()
            for message in matched:
                message.status = 'read'

        session.unread_count = 0
        db.session.commit()

        emit('chat.read', {
            'session_id': session_id,
            'message_ids': message_ids,
        })

    # ==================== Close session ====================
    def close_chat_session(session_id: str, reason: str = 'closed'):
        """Mark a session closed and broadcast ``chat.closed`` to its room.

        NOTE(review): this helper is local to register_chat_handlers and is not
        called or returned anywhere in this module.
        """
        session = Session.query.get(session_id)
        if session:
            session.status = 'closed'
            session.updated_at = datetime.utcnow()
            db.session.commit()

        # Server-level socketio.emit (not the context-aware emit) so this also
        # works when called outside a Socket.IO event handler. The room is
        # notified even if the row is gone so connected clients learn the
        # session ended.
        socketio.emit('chat.closed', {
            'session_id': session_id,
            'reason': reason,
        }, room=session_id)
|
||
|
||
|
||
# Public API of this module: the handler registration entry point and the
# shared connection manager.
__all__ = [
    'register_chat_handlers',
    'chat_manager',
]
|