183 lines
5.5 KiB
Python
183 lines
5.5 KiB
Python
"""
|
||
Socket.IO 事件处理
|
||
"""
|
||
import json
import logging
from datetime import datetime

from flask_jwt_extended import decode_token
from flask_socketio import emit, join_room, leave_room

from app.extensions import redis_client
|
||
|
||
# Connection manager: in-memory registry of live Socket.IO connections.
class ConnectionManager:
    """Process-local bookkeeping for live Socket.IO connections.

    Holds three plain dicts that the event handlers read and write. Because
    they are ordinary in-memory dicts, each server process has its own copy.
    """

    def __init__(self):
        # Socket -> chat session that socket most recently joined.
        self.socket_sessions: dict = {}  # socket_id -> session_id
        # Agent -> socket it is reachable on (not written anywhere visible here).
        self.agent_sockets: dict = {}  # agent_id -> socket_id
        # Authenticated user -> socket it is reachable on.
        self.user_sockets: dict = {}  # user_id -> socket_id


# Module-wide instance shared by every event handler.
connection_manager = ConnectionManager()
|
||
|
||
|
||
logger = logging.getLogger(__name__)

# Lifetime of the per-socket Redis record written on successful auth (24 hours).
_SOCKET_TTL_SECONDS = 86400


def _current_sid():
    """Return the Socket.IO id of the client that sent the event being handled."""
    from flask import request

    return request.sid


def _payload(data):
    """Return the event payload as a dict.

    Clients may emit an event with no argument or with a non-object argument;
    treating those as an empty payload lets handlers degrade gracefully instead
    of crashing on ``data.get``.
    """
    return data if isinstance(data, dict) else {}


def _drop_socket(mapping, sid):
    """Delete every entry of *mapping* whose value is *sid*."""
    # Collect the keys first: a dict must not change size while iterated.
    for key in [k for k, v in mapping.items() if v == sid]:
        del mapping[key]


def _commit(db):
    """Commit the current DB transaction, rolling back before re-raising on failure.

    Without the rollback, a failed commit can leave the scoped SQLAlchemy
    session in a failed state until someone rolls it back.
    """
    try:
        db.session.commit()
    except Exception:
        db.session.rollback()
        raise


def _handle_connect():
    """Log the new connection and ask the client to authenticate."""
    logger.info('Client connected')
    emit('auth', {'message': 'Please authenticate'})


def _handle_disconnect():
    """Forget everything recorded for the disconnecting socket."""
    sid = _current_sid()
    logger.info('Client disconnected: %s', sid)

    connection_manager.socket_sessions.pop(sid, None)
    # Only entries still pointing at *this* socket are removed, so a user who
    # has since re-authenticated on a newer socket keeps that mapping.
    _drop_socket(connection_manager.user_sockets, sid)
    _drop_socket(connection_manager.agent_sockets, sid)

    # Best effort: a Redis error must not break disconnect handling; the key
    # also expires on its own after _SOCKET_TTL_SECONDS.
    try:
        redis_client.delete(f'socket:{sid}')
    except Exception:
        logger.exception('Failed to delete Redis record for socket %s', sid)


def _handle_auth(data=None):
    """Authenticate the socket with a JWT sent as ``{'token': ...}``.

    Emits ``authenticated`` on success, or ``auth_error`` with code
    ``MISSING_TOKEN``, ``INVALID_TOKEN`` or ``SERVER_ERROR``.
    """
    sid = _current_sid()

    token = _payload(data).get('token')
    if not token:
        emit('auth_error', {'code': 'MISSING_TOKEN', 'message': 'Token is required'})
        return

    # decode_token raises assorted PyJWT / flask_jwt_extended exception types
    # for bad tokens, so catch broadly -- but only around decoding, so that
    # infrastructure failures below are not misreported as a bad token.
    try:
        decoded = decode_token(token)
        user_id = decoded['sub']
    except Exception as e:
        emit('auth_error', {'code': 'INVALID_TOKEN', 'message': str(e)})
        return

    key = f'socket:{sid}'
    try:
        redis_client.hset(key, mapping={
            'user_id': user_id,
            'connected_at': datetime.utcnow().isoformat(),
        })
        redis_client.expire(key, _SOCKET_TTL_SECONDS)
    except Exception:
        logger.exception('Failed to store Redis record for socket %s', sid)
        emit('auth_error', {'code': 'SERVER_ERROR', 'message': 'Could not register connection'})
        return

    # Record the in-memory mapping only once the connection is fully registered.
    connection_manager.user_sockets[user_id] = sid

    emit('authenticated', {
        'user_id': user_id,
        'socket_id': sid,
    })


def _handle_ping(data=None):
    """Answer a heartbeat with the current UTC timestamp; the payload is ignored."""
    emit('pong', {'timestamp': datetime.utcnow().timestamp()})


def _handle_session_create(data=None):
    """Create a chat session, assigning an online agent when none is given."""
    from app.models import db, Session, Agent

    payload = _payload(data)
    # NOTE(review): user_id is taken from the client payload, not from the
    # authenticated socket, so any client can create sessions for any user.
    user_id = payload.get('user_id')
    title = payload.get('title', 'New Session')
    agent_id = payload.get('agent_id')

    # No agent requested: pick the first online one (may remain None).
    if not agent_id:
        agent = Agent.query.filter_by(status='online').first()
        if agent:
            agent_id = agent.id

    session = Session(
        user_id=user_id,
        primary_agent_id=agent_id,
        title=title,
    )
    db.session.add(session)
    _commit(db)

    emit('session.created', {
        'session_id': session.id,
        'agent_id': agent_id,
    })


def _handle_session_join(data=None):
    """Join the room of an existing session, leaving the previously joined one.

    Emits ``session.joined`` on success or ``session.error`` when the session
    id is missing or unknown.
    """
    from app.models import Session

    sid = _current_sid()
    session_id = _payload(data).get('session_id')
    if session_id is None:
        emit('session.error', {'code': 'MISSING_SESSION_ID', 'message': 'session_id is required'})
        return

    # Validate before joining so a socket never sits in a room for a
    # nonexistent session.
    # NOTE(review): there is no check that this socket's user may access it.
    session = Session.query.get(session_id)
    if session is None:
        emit('session.error', {'code': 'SESSION_NOT_FOUND', 'message': 'Session not found'})
        return

    # socket_sessions tracks one session per socket; leave the old room so the
    # socket stops receiving broadcasts for a session it no longer tracks.
    previous = connection_manager.socket_sessions.get(sid)
    if previous is not None and previous != session_id:
        leave_room(previous)

    join_room(session_id)
    connection_manager.socket_sessions[sid] = session_id

    emit('session.joined', {
        'session_id': session_id,
        'participants': [session.user_id, session.primary_agent_id],
    })


def _handle_message_send(data=None):
    """Persist a message and broadcast it to the session's room.

    Emits ``message.error`` (and stores nothing) when the session does not
    exist or the content is missing.
    """
    from app.models import db, Message, Session

    payload = _payload(data)
    session_id = payload.get('session_id')
    content = payload.get('content')
    # NOTE(review): sender identity is client-supplied and not checked against
    # the authenticated socket.
    sender_type = payload.get('sender_type', 'user')
    sender_id = payload.get('sender_id')

    session = Session.query.get(session_id) if session_id is not None else None
    if session is None:
        emit('message.error', {'code': 'SESSION_NOT_FOUND', 'message': 'Session not found'})
        return
    if content is None or content == '':
        emit('message.error', {'code': 'EMPTY_MESSAGE', 'message': 'Message content is required'})
        return

    # One timestamp for both the stored activity time and the broadcast.
    now = datetime.utcnow()

    message = Message(
        session_id=session_id,
        sender_type=sender_type,
        sender_id=sender_id,
        content=content,
        status='sent',
    )
    db.session.add(message)

    session.message_count = (session.message_count or 0) + 1
    session.last_active_at = now
    _commit(db)

    emit('message', {
        'message_id': message.id,
        'session_id': session_id,
        'sender': {'type': sender_type, 'id': sender_id},
        'content': content,
        'timestamp': now.isoformat(),
    }, room=session_id)


def _handle_message_ack(data=None):
    """Record a client acknowledgement for a stored message.

    Status ``'acknowledged'`` also marks the message delivered. Emits
    ``message.acked`` on success or ``message.error`` for unknown messages.
    """
    from app.models import db, Message

    payload = _payload(data)
    message_id = payload.get('message_id')
    status = payload.get('status')

    message = Message.query.get(message_id) if message_id is not None else None
    if message is None:
        emit('message.error', {'code': 'MESSAGE_NOT_FOUND', 'message': 'Message not found'})
        return

    message.ack_status = status
    if status == 'acknowledged':
        message.status = 'delivered'
        message.delivered_at = datetime.utcnow()
    _commit(db)

    emit('message.acked', {'message_id': message_id, 'status': status})


def register_handlers(socketio):
    """Register all Socket.IO event handlers on *socketio*.

    Each handler is attached via ``socketio.on(event)(handler)``, which is
    exactly what the ``@socketio.on(event)`` decorator does.
    """
    handlers = (
        ('connect', _handle_connect),
        ('disconnect', _handle_disconnect),
        ('auth', _handle_auth),
        ('ping', _handle_ping),
        ('session.create', _handle_session_create),
        ('session.join', _handle_session_join),
        ('message.send', _handle_message_send),
        ('message.ack', _handle_message_ack),
    )
    for event, handler in handlers:
        socketio.on(event)(handler)