feat: Phase 1 - 核心功能实现
This commit is contained in:
39
.env.example
Normal file
39
.env.example
Normal file
@@ -0,0 +1,39 @@
|
||||
# PIT Router 配置文件
|
||||
|
||||
# Flask 配置
|
||||
FLASK_ENV=development
|
||||
FLASK_APP=run.py
|
||||
SECRET_KEY=your-secret-key-here-change-in-production
|
||||
JWT_SECRET_KEY=your-jwt-secret-key-here-change-in-production
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URL=postgresql://user:password@localhost:5432/pit_router
|
||||
# 开发环境可使用 SQLite
|
||||
# DATABASE_URL=sqlite:///app.db
|
||||
|
||||
# Redis 配置
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
|
||||
# 服务器配置
|
||||
HOST=0.0.0.0
|
||||
PORT=9000
|
||||
DEBUG=True
|
||||
|
||||
# JWT 配置
|
||||
JWT_ACCESS_TOKEN_EXPIRES=86400
|
||||
JWT_REFRESH_TOKEN_EXPIRES=604800
|
||||
|
||||
# WebSocket 配置
|
||||
SOCKETIO_PING_INTERVAL=25000
|
||||
SOCKETIO_PING_TIMEOUT=10000
|
||||
|
||||
# Agent 调度配置
|
||||
SCHEDULER_STRATEGY=weighted_round_robin
|
||||
SCHEDULER_TIMEOUT=30
|
||||
|
||||
# 安全配置
|
||||
RATE_LIMIT=100
|
||||
CORS_ORIGINS=*
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL=INFO
|
||||
117
app/__init__.py
Normal file
117
app/__init__.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
PIT Router Flask 应用工厂
|
||||
"""
|
||||
from flask import Flask
|
||||
from flask_socketio import SocketIO
|
||||
import logging
|
||||
|
||||
from app.config import config
|
||||
from app.extensions import (
|
||||
db, migrate, jwt, login_manager, cors, limiter, init_redis
|
||||
)
|
||||
|
||||
# Socket.IO 实例
|
||||
socketio = SocketIO(cors_allowed_origins="*", async_mode='threading')
|
||||
|
||||
|
||||
def create_app(config_name='default'):
|
||||
"""创建 Flask 应用实例"""
|
||||
app = Flask(__name__)
|
||||
app.config.from_object(config[config_name])
|
||||
|
||||
# 初始化扩展
|
||||
_init_extensions(app)
|
||||
|
||||
# 注册蓝图
|
||||
_register_blueprints(app)
|
||||
|
||||
# 注册 Socket.IO 事件
|
||||
_register_socketio_events()
|
||||
|
||||
# 配置日志
|
||||
_configure_logging(app)
|
||||
|
||||
# 注册错误处理
|
||||
_register_error_handlers(app)
|
||||
|
||||
# 健康检查端点
|
||||
@app.route('/health')
|
||||
def health_check():
|
||||
return {'status': 'ok', 'service': 'pit-router'}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _init_extensions(app):
|
||||
"""初始化 Flask 扩展"""
|
||||
db.init_app(app)
|
||||
migrate.init_app(app, db)
|
||||
jwt.init_app(app)
|
||||
login_manager.init_app(app)
|
||||
cors.init_app(app)
|
||||
limiter.init_app(app)
|
||||
socketio.init_app(app)
|
||||
|
||||
# 初始化 Redis
|
||||
init_redis(app)
|
||||
|
||||
# 创建数据库表
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
|
||||
|
||||
def _register_blueprints(app):
|
||||
"""注册蓝图"""
|
||||
from app.routes.auth import auth_bp
|
||||
from app.routes.sessions import sessions_bp
|
||||
from app.routes.agents import agents_bp
|
||||
from app.routes.gateways import gateways_bp
|
||||
from app.routes.messages import messages_bp
|
||||
from app.routes.stats import stats_bp
|
||||
|
||||
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
||||
app.register_blueprint(sessions_bp, url_prefix='/api/sessions')
|
||||
app.register_blueprint(agents_bp, url_prefix='/api/agents')
|
||||
app.register_blueprint(gateways_bp, url_prefix='/api/gateways')
|
||||
app.register_blueprint(messages_bp, url_prefix='/api/messages')
|
||||
app.register_blueprint(stats_bp, url_prefix='/api/stats')
|
||||
|
||||
|
||||
def _register_socketio_events():
|
||||
"""注册 Socket.IO 事件处理器"""
|
||||
from app.socketio.handlers import register_handlers
|
||||
register_handlers(socketio)
|
||||
|
||||
|
||||
def _configure_logging(app):
|
||||
"""配置日志"""
|
||||
log_level = app.config.get('LOG_LEVEL', 'INFO')
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level),
|
||||
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s'
|
||||
)
|
||||
|
||||
|
||||
def _register_error_handlers(app):
|
||||
"""注册错误处理器"""
|
||||
from flask import jsonify
|
||||
|
||||
@app.errorhandler(400)
|
||||
def bad_request(error):
|
||||
return jsonify({'error': 'Bad Request', 'message': str(error)}), 400
|
||||
|
||||
@app.errorhandler(401)
|
||||
def unauthorized(error):
|
||||
return jsonify({'error': 'Unauthorized', 'message': str(error)}), 401
|
||||
|
||||
@app.errorhandler(403)
|
||||
def forbidden(error):
|
||||
return jsonify({'error': 'Forbidden', 'message': str(error)}), 403
|
||||
|
||||
@app.errorhandler(404)
|
||||
def not_found(error):
|
||||
return jsonify({'error': 'Not Found', 'message': str(error)}), 404
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(error):
|
||||
return jsonify({'error': 'Internal Server Error', 'message': str(error)}), 500
|
||||
72
app/config.py
Normal file
72
app/config.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
PIT Router 配置管理
|
||||
"""
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Config:
|
||||
"""基础配置"""
|
||||
SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production')
|
||||
JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY', 'dev-jwt-secret-key-change-in-production')
|
||||
|
||||
# 数据库
|
||||
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL', 'sqlite:///pit_router.db')
|
||||
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
||||
SQLALCHEMY_ECHO = False
|
||||
|
||||
# Redis
|
||||
REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/0')
|
||||
|
||||
# JWT
|
||||
JWT_ACCESS_TOKEN_EXPIRES = timedelta(
|
||||
seconds=int(os.environ.get('JWT_ACCESS_TOKEN_EXPIRES', 86400))
|
||||
)
|
||||
JWT_REFRESH_TOKEN_EXPIRES = timedelta(
|
||||
seconds=int(os.environ.get('JWT_REFRESH_TOKEN_EXPIRES', 604800))
|
||||
)
|
||||
|
||||
# Socket.IO
|
||||
SOCKETIO_PING_INTERVAL = int(os.environ.get('SOCKETIO_PING_INTERVAL', 25000))
|
||||
SOCKETIO_PING_TIMEOUT = int(os.environ.get('SOCKETIO_PING_TIMEOUT', 10000))
|
||||
|
||||
# Agent 调度
|
||||
SCHEDULER_STRATEGY = os.environ.get('SCHEDULER_STRATEGY', 'weighted_round_robin')
|
||||
SCHEDULER_TIMEOUT = int(os.environ.get('SCHEDULER_TIMEOUT', 30))
|
||||
|
||||
# 安全
|
||||
RATE_LIMIT = int(os.environ.get('RATE_LIMIT', 100))
|
||||
CORS_ORIGINS = os.environ.get('CORS_ORIGINS', '*')
|
||||
|
||||
|
||||
class DevelopmentConfig(Config):
|
||||
"""开发环境配置"""
|
||||
DEBUG = True
|
||||
SQLALCHEMY_ECHO = True
|
||||
|
||||
|
||||
class ProductionConfig(Config):
|
||||
"""生产环境配置"""
|
||||
DEBUG = False
|
||||
|
||||
# 生产环境强制使用 PostgreSQL
|
||||
if not os.environ.get('DATABASE_URL'):
|
||||
raise ValueError("DATABASE_URL must be set in production")
|
||||
|
||||
|
||||
class TestingConfig(Config):
|
||||
"""测试环境配置"""
|
||||
TESTING = True
|
||||
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
|
||||
WTF_CSRF_ENABLED = False
|
||||
|
||||
|
||||
config = {
|
||||
'development': DevelopmentConfig,
|
||||
'production': ProductionConfig,
|
||||
'testing': TestingConfig,
|
||||
'default': DevelopmentConfig
|
||||
}
|
||||
32
app/extensions.py
Normal file
32
app/extensions.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Flask 扩展初始化
|
||||
"""
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_login import LoginManager
|
||||
from flask_cors import CORS
|
||||
from flask_limiter import Limiter
|
||||
from flask_limiter.util import get_remote_address
|
||||
import redis
|
||||
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
jwt = JWTManager()
|
||||
login_manager = LoginManager()
|
||||
cors = CORS()
|
||||
limiter = Limiter(
|
||||
key_func=get_remote_address,
|
||||
default_limits=["100 per minute"]
|
||||
)
|
||||
|
||||
# Redis 客户端
|
||||
redis_client = None
|
||||
|
||||
|
||||
def init_redis(app):
|
||||
"""初始化 Redis 客户端"""
|
||||
global redis_client
|
||||
redis_url = app.config.get('REDIS_URL', 'redis://localhost:6379/0')
|
||||
redis_client = redis.from_url(redis_url)
|
||||
return redis_client
|
||||
251
app/models/__init__.py
Normal file
251
app/models/__init__.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
PIT Router 数据模型
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from sqlalchemy import String, DateTime, Integer, Text, JSON, ForeignKey, Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
import uuid
|
||||
|
||||
db = SQLAlchemy()
|
||||
|
||||
|
||||
def generate_uuid() -> str:
|
||||
"""生成 UUID"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class User(db.Model):
|
||||
"""用户模型"""
|
||||
__tablename__ = 'users'
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
|
||||
username: Mapped[str] = mapped_column(String(80), unique=True, nullable=False)
|
||||
password_hash: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
email: Mapped[str] = mapped_column(String(120), unique=True, nullable=False)
|
||||
nickname: Mapped[Optional[str]] = mapped_column(String(80), nullable=True)
|
||||
role: Mapped[str] = mapped_column(String(20), default='user')
|
||||
status: Mapped[str] = mapped_column(String(20), default='active')
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# 关联
|
||||
sessions = relationship('Session', back_populates='user', cascade='all, delete-orphan')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<User {self.username}>'
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'username': self.username,
|
||||
'email': self.email,
|
||||
'nickname': self.nickname,
|
||||
'role': self.role,
|
||||
'status': self.status,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'last_login_at': self.last_login_at.isoformat() if self.last_login_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Gateway(db.Model):
|
||||
"""Gateway 模型"""
|
||||
__tablename__ = 'gateways'
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
|
||||
name: Mapped[str] = mapped_column(String(80), unique=True, nullable=False)
|
||||
url: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
token_hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default='offline')
|
||||
agent_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
connection_limit: Mapped[int] = mapped_column(Integer, default=10)
|
||||
heartbeat_interval: Mapped[int] = mapped_column(Integer, default=60)
|
||||
allowed_ips: Mapped[Optional[str]] = mapped_column(JSON, nullable=True)
|
||||
last_heartbeat: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# 关联
|
||||
agents = relationship('Agent', back_populates='gateway', cascade='all, delete-orphan')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Gateway {self.name}>'
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'url': self.url,
|
||||
'status': self.status,
|
||||
'agent_count': self.agent_count,
|
||||
'connection_limit': self.connection_limit,
|
||||
'heartbeat_interval': self.heartbeat_interval,
|
||||
'last_heartbeat': self.last_heartbeat.isoformat() if self.last_heartbeat else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Agent(db.Model):
|
||||
"""Agent 模型"""
|
||||
__tablename__ = 'agents'
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
|
||||
name: Mapped[str] = mapped_column(String(80), nullable=False)
|
||||
display_name: Mapped[Optional[str]] = mapped_column(String(80), nullable=True)
|
||||
gateway_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey('gateways.id'), nullable=True)
|
||||
socket_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
model: Mapped[Optional[str]] = mapped_column(String(80), nullable=True)
|
||||
capabilities: Mapped[Optional[str]] = mapped_column(JSON, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default='offline')
|
||||
priority: Mapped[int] = mapped_column(Integer, default=5)
|
||||
weight: Mapped[int] = mapped_column(Integer, default=10)
|
||||
connection_limit: Mapped[int] = mapped_column(Integer, default=5)
|
||||
current_sessions: Mapped[int] = mapped_column(Integer, default=0)
|
||||
last_heartbeat: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# 关联
|
||||
gateway = relationship('Gateway', back_populates='agents')
|
||||
sessions = relationship('Session', back_populates='agent')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Agent {self.name}>'
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'display_name': self.display_name,
|
||||
'gateway_id': self.gateway_id,
|
||||
'socket_id': self.socket_id,
|
||||
'model': self.model,
|
||||
'capabilities': self.capabilities,
|
||||
'status': self.status,
|
||||
'priority': self.priority,
|
||||
'weight': self.weight,
|
||||
'connection_limit': self.connection_limit,
|
||||
'current_sessions': self.current_sessions,
|
||||
'last_heartbeat': self.last_heartbeat.isoformat() if self.last_heartbeat else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Session(db.Model):
|
||||
"""会话模型"""
|
||||
__tablename__ = 'sessions'
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
|
||||
user_id: Mapped[str] = mapped_column(String(36), ForeignKey('users.id'), nullable=False)
|
||||
primary_agent_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey('agents.id'), nullable=True)
|
||||
participating_agent_ids: Mapped[Optional[str]] = mapped_column(JSON, nullable=True)
|
||||
user_socket_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
title: Mapped[Optional[str]] = mapped_column(String(200), nullable=True)
|
||||
channel_type: Mapped[str] = mapped_column(String(20), default='web')
|
||||
status: Mapped[str] = mapped_column(String(20), default='active')
|
||||
message_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
unread_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
last_active_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# 关联
|
||||
user = relationship('User', back_populates='sessions')
|
||||
agent = relationship('Agent', back_populates='sessions')
|
||||
messages = relationship('Message', back_populates='session', cascade='all, delete-orphan')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Session {self.id}>'
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'user_id': self.user_id,
|
||||
'primary_agent_id': self.primary_agent_id,
|
||||
'participating_agent_ids': self.participating_agent_ids,
|
||||
'user_socket_id': self.user_socket_id,
|
||||
'title': self.title,
|
||||
'channel_type': self.channel_type,
|
||||
'status': self.status,
|
||||
'message_count': self.message_count,
|
||||
'unread_count': self.unread_count,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
||||
'last_active_at': self.last_active_at.isoformat() if self.last_active_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Message(db.Model):
|
||||
"""消息模型"""
|
||||
__tablename__ = 'messages'
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
|
||||
session_id: Mapped[str] = mapped_column(String(36), ForeignKey('sessions.id'), nullable=False)
|
||||
sender_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
sender_id: Mapped[str] = mapped_column(String(36), nullable=False)
|
||||
message_type: Mapped[str] = mapped_column(String(20), default='text')
|
||||
content: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
content_type: Mapped[str] = mapped_column(String(20), default='markdown')
|
||||
reply_to: Mapped[Optional[str]] = mapped_column(String(36), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default='sent')
|
||||
ack_status: Mapped[str] = mapped_column(String(20), default='pending')
|
||||
retry_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
delivered_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# 关联
|
||||
session = relationship('Session', back_populates='messages')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Message {self.id}>'
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'session_id': self.session_id,
|
||||
'sender_type': self.sender_type,
|
||||
'sender_id': self.sender_id,
|
||||
'message_type': self.message_type,
|
||||
'content': self.content,
|
||||
'content_type': self.content_type,
|
||||
'reply_to': self.reply_to,
|
||||
'status': self.status,
|
||||
'ack_status': self.ack_status,
|
||||
'retry_count': self.retry_count,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'delivered_at': self.delivered_at.isoformat() if self.delivered_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Connection(db.Model):
|
||||
"""连接模型"""
|
||||
__tablename__ = 'connections'
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
|
||||
socket_id: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
|
||||
connection_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
entity_id: Mapped[str] = mapped_column(String(36), nullable=False)
|
||||
entity_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
ip_address: Mapped[Optional[str]] = mapped_column(String(45), nullable=True)
|
||||
user_agent: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default='connected')
|
||||
auth_token: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
||||
connected_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
last_activity: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
disconnected_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Connection {self.socket_id}>'
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'socket_id': self.socket_id,
|
||||
'connection_type': self.connection_type,
|
||||
'entity_id': self.entity_id,
|
||||
'entity_type': self.entity_type,
|
||||
'ip_address': self.ip_address,
|
||||
'user_agent': self.user_agent,
|
||||
'status': self.status,
|
||||
'connected_at': self.connected_at.isoformat() if self.connected_at else None,
|
||||
'last_activity': self.last_activity.isoformat() if self.last_activity else None,
|
||||
}
|
||||
3
app/routes/__init__.py
Normal file
3
app/routes/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
路由模块
|
||||
"""
|
||||
91
app/routes/agents.py
Normal file
91
app/routes/agents.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Agent 路由
|
||||
"""
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required
|
||||
from datetime import datetime
|
||||
from app.models import db, Agent
|
||||
|
||||
agents_bp = Blueprint('agents', __name__)
|
||||
|
||||
|
||||
@agents_bp.route('/', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_agents():
|
||||
"""获取 Agent 列表"""
|
||||
agents = Agent.query.all()
|
||||
return jsonify({'agents': [a.to_dict() for a in agents]}), 200
|
||||
|
||||
|
||||
@agents_bp.route('/<agent_id>', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_agent(agent_id):
|
||||
"""获取 Agent 详情"""
|
||||
agent = Agent.query.get(agent_id)
|
||||
if not agent:
|
||||
return jsonify({'error': 'Agent not found'}), 404
|
||||
return jsonify({'agent': agent.to_dict()}), 200
|
||||
|
||||
|
||||
@agents_bp.route('/<agent_id>/status', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_agent_status(agent_id):
|
||||
"""获取 Agent 实时状态"""
|
||||
agent = Agent.query.get(agent_id)
|
||||
if not agent:
|
||||
return jsonify({'error': 'Agent not found'}), 404
|
||||
|
||||
return jsonify({
|
||||
'agent_id': agent.id,
|
||||
'status': agent.status,
|
||||
'current_sessions': agent.current_sessions,
|
||||
'last_heartbeat': agent.last_heartbeat.isoformat() if agent.last_heartbeat else None
|
||||
}), 200
|
||||
|
||||
|
||||
@agents_bp.route('/<agent_id>/config', methods=['PUT'])
|
||||
@jwt_required()
|
||||
def update_agent_config(agent_id):
|
||||
"""更新 Agent 配置"""
|
||||
agent = Agent.query.get(agent_id)
|
||||
if not agent:
|
||||
return jsonify({'error': 'Agent not found'}), 404
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if 'name' in data:
|
||||
agent.name = data['name']
|
||||
if 'display_name' in data:
|
||||
agent.display_name = data['display_name']
|
||||
if 'capabilities' in data:
|
||||
agent.capabilities = data['capabilities']
|
||||
if 'priority' in data:
|
||||
agent.priority = data['priority']
|
||||
if 'weight' in data:
|
||||
agent.weight = data['weight']
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'agent': agent.to_dict()}), 200
|
||||
|
||||
|
||||
@agents_bp.route('/<agent_id>/heartbeat', methods=['POST'])
|
||||
def agent_heartbeat(agent_id):
|
||||
"""Agent 心跳上报"""
|
||||
agent = Agent.query.get(agent_id)
|
||||
if not agent:
|
||||
return jsonify({'error': 'Agent not found'}), 404
|
||||
|
||||
agent.last_heartbeat = datetime.utcnow()
|
||||
agent.status = 'online'
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'status': 'ok'}), 200
|
||||
|
||||
|
||||
@agents_bp.route('/available', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_available_agents():
|
||||
"""获取可用 Agent 列表"""
|
||||
agents = Agent.query.filter_by(status='online').all()
|
||||
return jsonify({'agents': [a.to_dict() for a in agents]}), 200
|
||||
142
app/routes/auth.py
Normal file
142
app/routes/auth.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
认证路由
|
||||
"""
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import (
|
||||
create_access_token, create_refresh_token,
|
||||
jwt_required, get_jwt_identity
|
||||
)
|
||||
from datetime import datetime
|
||||
import bcrypt
|
||||
from app.models import db, User
|
||||
|
||||
auth_bp = Blueprint('auth', __name__)
|
||||
|
||||
|
||||
@auth_bp.route('/register', methods=['POST'])
|
||||
def register():
|
||||
"""注册新用户"""
|
||||
data = request.get_json()
|
||||
|
||||
username = data.get('username')
|
||||
email = data.get('email')
|
||||
password = data.get('password')
|
||||
nickname = data.get('nickname')
|
||||
|
||||
if not all([username, email, password]):
|
||||
return jsonify({'error': 'Missing required fields'}), 400
|
||||
|
||||
# 检查用户是否已存在
|
||||
if User.query.filter_by(username=username).first():
|
||||
return jsonify({'error': 'Username already exists'}), 400
|
||||
|
||||
if User.query.filter_by(email=email).first():
|
||||
return jsonify({'error': 'Email already exists'}), 400
|
||||
|
||||
# 密码哈希
|
||||
password_hash = bcrypt.hashpw(
|
||||
password.encode('utf-8'),
|
||||
bcrypt.gensalt()
|
||||
).decode('utf-8')
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
nickname=nickname or username
|
||||
)
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': 'User registered successfully',
|
||||
'user': user.to_dict()
|
||||
}), 201
|
||||
|
||||
|
||||
@auth_bp.route('/login', methods=['POST'])
|
||||
def login():
|
||||
"""用户登录"""
|
||||
data = request.get_json()
|
||||
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
|
||||
if not all([username, password]):
|
||||
return jsonify({'error': 'Missing username or password'}), 400
|
||||
|
||||
# 查找用户
|
||||
user = User.query.filter_by(username=username).first()
|
||||
if not user:
|
||||
return jsonify({'error': 'Invalid username or password'}), 401
|
||||
|
||||
# 验证密码
|
||||
if not bcrypt.checkpw(password.encode('utf-8'), user.password_hash.encode('utf-8')):
|
||||
return jsonify({'error': 'Invalid username or password'}), 401
|
||||
|
||||
# 检查用户状态
|
||||
if user.status != 'active':
|
||||
return jsonify({'error': 'Account is disabled'}), 403
|
||||
|
||||
# 更新最后登录时间
|
||||
user.last_login_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# 生成 Token
|
||||
access_token = create_access_token(identity=user.id)
|
||||
refresh_token = create_refresh_token(identity=user.id)
|
||||
|
||||
return jsonify({
|
||||
'access_token': access_token,
|
||||
'refresh_token': refresh_token,
|
||||
'user': user.to_dict()
|
||||
}), 200
|
||||
|
||||
|
||||
@auth_bp.route('/refresh', methods=['POST'])
|
||||
@jwt_required(refresh=True)
|
||||
def refresh():
|
||||
"""刷新 Token"""
|
||||
identity = get_jwt_identity()
|
||||
access_token = create_access_token(identity=identity)
|
||||
|
||||
return jsonify({
|
||||
'access_token': access_token
|
||||
}), 200
|
||||
|
||||
|
||||
@auth_bp.route('/me', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_current_user():
|
||||
"""获取当前用户信息"""
|
||||
user_id = get_jwt_identity()
|
||||
user = User.query.get(user_id)
|
||||
|
||||
if not user:
|
||||
return jsonify({'error': 'User not found'}), 404
|
||||
|
||||
return jsonify({'user': user.to_dict()}), 200
|
||||
|
||||
|
||||
@auth_bp.route('/logout', methods=['POST'])
|
||||
@jwt_required()
|
||||
def logout():
|
||||
"""用户登出"""
|
||||
return jsonify({'message': 'Logged out successfully'}), 200
|
||||
|
||||
|
||||
@auth_bp.route('/verify', methods=['POST'])
|
||||
@jwt_required()
|
||||
def verify_token():
|
||||
"""验证 Token 有效性"""
|
||||
user_id = get_jwt_identity()
|
||||
user = User.query.get(user_id)
|
||||
|
||||
if not user:
|
||||
return jsonify({'valid': False}), 401
|
||||
|
||||
return jsonify({
|
||||
'valid': True,
|
||||
'user': user.to_dict()
|
||||
}), 200
|
||||
99
app/routes/gateways.py
Normal file
99
app/routes/gateways.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Gateway 路由
|
||||
"""
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required
|
||||
from datetime import datetime
|
||||
import bcrypt
|
||||
from app.models import db, Gateway
|
||||
|
||||
gateways_bp = Blueprint('gateways', __name__)
|
||||
|
||||
|
||||
@gateway_bp.route('/', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_gateways():
|
||||
"""获取 Gateway 列表"""
|
||||
gateways = Gateway.query.all()
|
||||
return jsonify({'gateways': [g.to_dict() for g in gateways]}), 200
|
||||
|
||||
|
||||
@gateway_bp.route('/', methods=['POST'])
|
||||
@jwt_required()
|
||||
def register_gateway():
|
||||
"""注册 Gateway"""
|
||||
data = request.get_json()
|
||||
|
||||
name = data.get('name')
|
||||
url = data.get('url')
|
||||
token = data.get('token')
|
||||
|
||||
if not all([name, url]):
|
||||
return jsonify({'error': 'Missing required fields'}), 400
|
||||
|
||||
if Gateway.query.filter_by(name=name).first():
|
||||
return jsonify({'error': 'Gateway name already exists'}), 400
|
||||
|
||||
# Token 哈希
|
||||
token_hash = None
|
||||
if token:
|
||||
token_hash = bcrypt.hashpw(
|
||||
token.encode('utf-8'),
|
||||
bcrypt.gensalt()
|
||||
).decode('utf-8')
|
||||
|
||||
gateway = Gateway(
|
||||
name=name,
|
||||
url=url,
|
||||
token_hash=token_hash,
|
||||
status='offline'
|
||||
)
|
||||
|
||||
db.session.add(gateway)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'gateway': gateway.to_dict()}), 201
|
||||
|
||||
|
||||
@gateway_bp.route('/<gateway_id>', methods=['DELETE'])
|
||||
@jwt_required()
|
||||
def delete_gateway(gateway_id):
|
||||
"""注销 Gateway"""
|
||||
gateway = Gateway.query.get(gateway_id)
|
||||
if not gateway:
|
||||
return jsonify({'error': 'Gateway not found'}), 404
|
||||
|
||||
db.session.delete(gateway)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'message': 'Gateway deleted'}), 200
|
||||
|
||||
|
||||
@gateway_bp.route('/<gateway_id>/status', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_gateway_status(gateway_id):
|
||||
"""获取 Gateway 状态"""
|
||||
gateway = Gateway.query.get(gateway_id)
|
||||
if not gateway:
|
||||
return jsonify({'error': 'Gateway not found'}), 404
|
||||
|
||||
return jsonify({
|
||||
'gateway_id': gateway.id,
|
||||
'status': gateway.status,
|
||||
'agent_count': gateway.agent_count,
|
||||
'last_heartbeat': gateway.last_heartbeat.isoformat() if gateway.last_heartbeat else None
|
||||
}), 200
|
||||
|
||||
|
||||
@gateway_bp.route('/<gateway_id>/heartbeat', methods=['POST'])
|
||||
def gateway_heartbeat(gateway_id):
|
||||
"""Gateway 心跳上报"""
|
||||
gateway = Gateway.query.get(gateway_id)
|
||||
if not gateway:
|
||||
return jsonify({'error': 'Gateway not found'}), 404
|
||||
|
||||
gateway.last_heartbeat = datetime.utcnow()
|
||||
gateway.status = 'online'
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'status': 'ok'}), 200
|
||||
91
app/routes/messages.py
Normal file
91
app/routes/messages.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
消息路由
|
||||
"""
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required, get_jwt_identity
|
||||
from datetime import datetime
|
||||
from app.models import db, Message, Session
|
||||
|
||||
messages_bp = Blueprint('messages', __name__)
|
||||
|
||||
|
||||
@messages_bp.route('/', methods=['POST'])
|
||||
@jwt_required()
|
||||
def send_message():
|
||||
"""发送消息 (HTTP 方式)"""
|
||||
user_id = get_jwt_identity()
|
||||
data = request.get_json()
|
||||
|
||||
session_id = data.get('session_id')
|
||||
content = data.get('content')
|
||||
message_type = data.get('type', 'text')
|
||||
reply_to = data.get('reply_to')
|
||||
|
||||
if not all([session_id, content]):
|
||||
return jsonify({'error': 'Missing required fields'}), 400
|
||||
|
||||
# 验证会话
|
||||
session = Session.query.filter_by(id=session_id, user_id=user_id).first()
|
||||
if not session:
|
||||
return jsonify({'error': 'Session not found'}), 404
|
||||
|
||||
# 创建消息
|
||||
message = Message(
|
||||
session_id=session_id,
|
||||
sender_type='user',
|
||||
sender_id=user_id,
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
reply_to=reply_to,
|
||||
status='sent',
|
||||
ack_status='pending'
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
|
||||
# 更新会话
|
||||
session.message_count += 1
|
||||
session.last_active_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'message': message.to_dict()}), 201
|
||||
|
||||
|
||||
@messages_bp.route('/<message_id>', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_message(message_id):
|
||||
"""获取单条消息"""
|
||||
message = Message.query.get(message_id)
|
||||
if not message:
|
||||
return jsonify({'error': 'Message not found'}), 404
|
||||
|
||||
return jsonify({'message': message.to_dict()}), 200
|
||||
|
||||
|
||||
@messages_bp.route('/<message_id>/ack', methods=['PUT'])
|
||||
def acknowledge_message(message_id):
|
||||
"""确认消息已送达"""
|
||||
message = Message.query.get(message_id)
|
||||
if not message:
|
||||
return jsonify({'error': 'Message not found'}), 404
|
||||
|
||||
message.ack_status = 'acknowledged'
|
||||
message.delivered_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'message': message.to_dict()}), 200
|
||||
|
||||
|
||||
@messages_bp.route('/<message_id>/read', methods=['PUT'])
|
||||
@jwt_required()
|
||||
def mark_message_read(message_id):
|
||||
"""标记消息已读"""
|
||||
message = Message.query.get(message_id)
|
||||
if not message:
|
||||
return jsonify({'error': 'Message not found'}), 404
|
||||
|
||||
message.status = 'read'
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'message': message.to_dict()}), 200
|
||||
102
app/routes/sessions.py
Normal file
102
app/routes/sessions.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
会话路由
|
||||
"""
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required, get_jwt_identity
|
||||
from datetime import datetime
|
||||
from app.models import db, Session, Agent
|
||||
|
||||
sessions_bp = Blueprint('sessions', __name__)
|
||||
|
||||
|
||||
@sessions_bp.route('/', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_sessions():
|
||||
"""获取会话列表"""
|
||||
user_id = get_jwt_identity()
|
||||
sessions = Session.query.filter_by(user_id=user_id).all()
|
||||
|
||||
return jsonify({
|
||||
'sessions': [s.to_dict() for s in sessions]
|
||||
}), 200
|
||||
|
||||
|
||||
@sessions_bp.route('/', methods=['POST'])
|
||||
@jwt_required()
|
||||
def create_session():
|
||||
"""创建会话"""
|
||||
user_id = get_jwt_identity()
|
||||
data = request.get_json()
|
||||
|
||||
title = data.get('title', 'New Session')
|
||||
agent_id = data.get('agent_id')
|
||||
priority = data.get('priority', 5)
|
||||
|
||||
# 如果没有指定 Agent,分配一个
|
||||
if not agent_id:
|
||||
agent = Agent.query.filter_by(status='online').order_by(
|
||||
Agent.current_sessions.asc()
|
||||
).first()
|
||||
if agent:
|
||||
agent_id = agent.id
|
||||
|
||||
session = Session(
|
||||
user_id=user_id,
|
||||
primary_agent_id=agent_id,
|
||||
title=title,
|
||||
status='active'
|
||||
)
|
||||
|
||||
db.session.add(session)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'session': session.to_dict()
|
||||
}), 201
|
||||
|
||||
|
||||
@sessions_bp.route('/<session_id>', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_session(session_id):
|
||||
"""获取会话详情"""
|
||||
user_id = get_jwt_identity()
|
||||
session = Session.query.filter_by(id=session_id, user_id=user_id).first()
|
||||
|
||||
if not session:
|
||||
return jsonify({'error': 'Session not found'}), 404
|
||||
|
||||
return jsonify({'session': session.to_dict()}), 200
|
||||
|
||||
|
||||
@sessions_bp.route('/<session_id>/close', methods=['PUT'])
|
||||
@jwt_required()
|
||||
def close_session(session_id):
|
||||
"""关闭会话"""
|
||||
user_id = get_jwt_identity()
|
||||
session = Session.query.filter_by(id=session_id, user_id=user_id).first()
|
||||
|
||||
if not session:
|
||||
return jsonify({'error': 'Session not found'}), 404
|
||||
|
||||
session.status = 'closed'
|
||||
session.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({'message': 'Session closed', 'session': session.to_dict()}), 200
|
||||
|
||||
|
||||
@sessions_bp.route('/<session_id>/messages', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_session_messages(session_id):
|
||||
"""获取会话消息"""
|
||||
user_id = get_jwt_identity()
|
||||
session = Session.query.filter_by(id=session_id, user_id=user_id).first()
|
||||
|
||||
if not session:
|
||||
return jsonify({'error': 'Session not found'}), 404
|
||||
|
||||
messages = session.messages.order_by('created_at').all()
|
||||
|
||||
return jsonify({
|
||||
'messages': [m.to_dict() for m in messages]
|
||||
}), 200
|
||||
67
app/routes/stats.py
Normal file
67
app/routes/stats.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
统计路由
|
||||
"""
|
||||
from flask import Blueprint, jsonify
|
||||
from flask_jwt_extended import jwt_required
|
||||
from app.models import db, User, Session, Agent, Gateway, Message
|
||||
from sqlalchemy import func
|
||||
|
||||
stats_bp = Blueprint('stats', __name__)
|
||||
|
||||
|
||||
@stats_bp.route('/', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_stats():
|
||||
"""获取系统统计信息"""
|
||||
stats = {
|
||||
'users': User.query.count(),
|
||||
'sessions': Session.query.filter_by(status='active').count(),
|
||||
'agents': Agent.query.filter_by(status='online').count(),
|
||||
'gateways': Gateway.query.filter_by(status='online').count(),
|
||||
'messages': Message.query.count(),
|
||||
}
|
||||
|
||||
return jsonify({'stats': stats}), 200
|
||||
|
||||
|
||||
@stats_bp.route('/sessions', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_session_stats():
|
||||
"""获取会话统计"""
|
||||
stats = {
|
||||
'total': Session.query.count(),
|
||||
'active': Session.query.filter_by(status='active').count(),
|
||||
'closed': Session.query.filter_by(status='closed').count(),
|
||||
'paused': Session.query.filter_by(status='paused').count(),
|
||||
}
|
||||
|
||||
return jsonify({'stats': stats}), 200
|
||||
|
||||
|
||||
@stats_bp.route('/messages', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_message_stats():
|
||||
"""获取消息统计"""
|
||||
stats = {
|
||||
'total': Message.query.count(),
|
||||
'sent': Message.query.filter_by(status='sent').count(),
|
||||
'delivered': Message.query.filter_by(status='delivered').count(),
|
||||
'read': Message.query.filter_by(status='read').count(),
|
||||
'failed': Message.query.filter_by(status='failed').count(),
|
||||
}
|
||||
|
||||
return jsonify({'stats': stats}), 200
|
||||
|
||||
|
||||
@stats_bp.route('/agents', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_agent_stats():
|
||||
"""获取 Agent 统计"""
|
||||
stats = {
|
||||
'total': Agent.query.count(),
|
||||
'online': Agent.query.filter_by(status='online').count(),
|
||||
'offline': Agent.query.filter_by(status='offline').count(),
|
||||
'busy': Agent.query.filter_by(status='busy').count(),
|
||||
}
|
||||
|
||||
return jsonify({'stats': stats}), 200
|
||||
3
app/services/__init__.py
Normal file
3
app/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
服务模块
|
||||
"""
|
||||
3
app/socketio/__init__.py
Normal file
3
app/socketio/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Socket.IO 模块
|
||||
"""
|
||||
54
app/socketio/events.py
Normal file
54
app/socketio/events.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
Socket.IO 事件定义
|
||||
"""
|
||||
# 认证事件
|
||||
AUTH_EVENTS = [
|
||||
'auth', # C→S: 认证请求
|
||||
'authenticated', # S→C: 认证成功
|
||||
'auth_error', # S→C: 认证失败
|
||||
]
|
||||
|
||||
# 心跳事件
|
||||
HEARTBEAT_EVENTS = [
|
||||
'ping', # C→S: 心跳请求
|
||||
'pong', # S→C: 心跳响应
|
||||
'heartbeat_timeout', # S→C: 心跳超时
|
||||
]
|
||||
|
||||
# 会话事件
|
||||
SESSION_EVENTS = [
|
||||
'session.create', # C→S: 创建会话
|
||||
'session.created', # S→C: 会话已创建
|
||||
'session.join', # C→S: 加入会话
|
||||
'session.joined', # S→C: 已加入会话
|
||||
'session.leave', # C→S: 离开会话
|
||||
'session.left', # S→C: 已离开会话
|
||||
'session.closed', # S→C: 会话被关闭
|
||||
'session.assigned', # S→C: Agent 分配通知
|
||||
]
|
||||
|
||||
# 消息事件
|
||||
MESSAGE_EVENTS = [
|
||||
'message.send', # C→S: 发送消息
|
||||
'message', # S→C: 收到消息
|
||||
'message.ack', # C→S: 消息确认
|
||||
'message.acked', # S→C: 确认已收到
|
||||
'message.read', # C→S: 消息已读
|
||||
'message.stream', # S→C: 流式消息
|
||||
'typing', # C→S: 正在输入
|
||||
]
|
||||
|
||||
# 错误事件
|
||||
ERROR_EVENTS = [
|
||||
'error', # S→C: 通用错误
|
||||
'session_error', # S→C: 会话错误
|
||||
'message_error', # S→C: 消息错误
|
||||
]
|
||||
|
||||
ALL_EVENTS = (
|
||||
AUTH_EVENTS +
|
||||
HEARTBEAT_EVENTS +
|
||||
SESSION_EVENTS +
|
||||
MESSAGE_EVENTS +
|
||||
ERROR_EVENTS
|
||||
)
|
||||
182
app/socketio/handlers.py
Normal file
182
app/socketio/handlers.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Socket.IO 事件处理
|
||||
"""
|
||||
from flask_socketio import emit, join_room, leave_room
|
||||
from flask_jwt_extended import decode_token
|
||||
from datetime import datetime
|
||||
import json
|
||||
from app.extensions import redis_client
|
||||
|
||||
# 连接管理器
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.user_sockets = {} # user_id -> socket_id
|
||||
self.agent_sockets = {} # agent_id -> socket_id
|
||||
self.socket_sessions = {} # socket_id -> session_id
|
||||
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
|
||||
def register_handlers(socketio):
|
||||
"""注册 Socket.IO 事件处理器"""
|
||||
|
||||
@socketio.on('connect')
|
||||
def handle_connect():
|
||||
"""客户端连接"""
|
||||
print('Client connected')
|
||||
# 发送认证请求
|
||||
emit('auth', {'message': 'Please authenticate'})
|
||||
|
||||
@socketio.on('disconnect')
|
||||
def handle_disconnect():
|
||||
"""客户端断开连接"""
|
||||
from flask import request
|
||||
sid = request.sid
|
||||
print(f'Client disconnected: {sid}')
|
||||
|
||||
# 清理连接
|
||||
if sid in connection_manager.socket_sessions:
|
||||
del connection_manager.socket_sessions[sid]
|
||||
|
||||
@socketio.on('auth')
|
||||
def handle_auth(data):
|
||||
"""处理认证"""
|
||||
from flask import request
|
||||
sid = request.sid
|
||||
|
||||
token = data.get('token')
|
||||
if not token:
|
||||
emit('auth_error', {'code': 'MISSING_TOKEN', 'message': 'Token is required'})
|
||||
return
|
||||
|
||||
try:
|
||||
# 验证 JWT Token
|
||||
decoded = decode_token(token)
|
||||
user_id = decoded['sub']
|
||||
|
||||
# 保存连接信息
|
||||
connection_manager.user_sockets[user_id] = sid
|
||||
|
||||
# 存储到 Redis
|
||||
redis_client.hset(f'socket:{sid}', mapping={
|
||||
'user_id': user_id,
|
||||
'connected_at': datetime.utcnow().isoformat()
|
||||
})
|
||||
redis_client.expire(f'socket:{sid}', 86400)
|
||||
|
||||
emit('authenticated', {
|
||||
'user_id': user_id,
|
||||
'socket_id': sid
|
||||
})
|
||||
except Exception as e:
|
||||
emit('auth_error', {'code': 'INVALID_TOKEN', 'message': str(e)})
|
||||
|
||||
@socketio.on('ping')
|
||||
def handle_ping(data):
|
||||
"""心跳响应"""
|
||||
emit('pong', {'timestamp': datetime.utcnow().timestamp()})
|
||||
|
||||
@socketio.on('session.create')
|
||||
def handle_session_create(data):
|
||||
"""创建会话"""
|
||||
from app.models import db, Session, Agent
|
||||
|
||||
user_id = data.get('user_id')
|
||||
title = data.get('title', 'New Session')
|
||||
agent_id = data.get('agent_id')
|
||||
|
||||
# 如果没有指定 Agent,分配一个
|
||||
if not agent_id:
|
||||
agent = Agent.query.filter_by(status='online').first()
|
||||
if agent:
|
||||
agent_id = agent.id
|
||||
|
||||
session = Session(
|
||||
user_id=user_id,
|
||||
primary_agent_id=agent_id,
|
||||
title=title
|
||||
)
|
||||
db.session.add(session)
|
||||
db.session.commit()
|
||||
|
||||
emit('session.created', {
|
||||
'session_id': session.id,
|
||||
'agent_id': agent_id
|
||||
})
|
||||
|
||||
@socketio.on('session.join')
|
||||
def handle_session_join(data):
|
||||
"""加入会话"""
|
||||
from flask import request
|
||||
from app.models import Session
|
||||
|
||||
sid = request.sid
|
||||
session_id = data.get('session_id')
|
||||
|
||||
# 加入房间
|
||||
join_room(session_id)
|
||||
connection_manager.socket_sessions[sid] = session_id
|
||||
|
||||
session = Session.query.get(session_id)
|
||||
if session:
|
||||
emit('session.joined', {
|
||||
'session_id': session_id,
|
||||
'participants': [session.user_id, session.primary_agent_id]
|
||||
})
|
||||
|
||||
@socketio.on('message.send')
|
||||
def handle_message_send(data):
|
||||
"""发送消息"""
|
||||
from flask import request
|
||||
from app.models import db, Message, Session
|
||||
|
||||
sid = request.sid
|
||||
session_id = data.get('session_id')
|
||||
content = data.get('content')
|
||||
sender_type = data.get('sender_type', 'user')
|
||||
sender_id = data.get('sender_id')
|
||||
|
||||
# 创建消息
|
||||
message = Message(
|
||||
session_id=session_id,
|
||||
sender_type=sender_type,
|
||||
sender_id=sender_id,
|
||||
content=content,
|
||||
status='sent'
|
||||
)
|
||||
db.session.add(message)
|
||||
|
||||
# 更新会话
|
||||
session = Session.query.get(session_id)
|
||||
if session:
|
||||
session.message_count += 1
|
||||
session.last_active_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# 广播消息到房间
|
||||
emit('message', {
|
||||
'message_id': message.id,
|
||||
'session_id': session_id,
|
||||
'sender': {'type': sender_type, 'id': sender_id},
|
||||
'content': content,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}, room=session_id)
|
||||
|
||||
@socketio.on('message.ack')
|
||||
def handle_message_ack(data):
|
||||
"""消息确认"""
|
||||
from app.models import db, Message
|
||||
|
||||
message_id = data.get('message_id')
|
||||
status = data.get('status')
|
||||
|
||||
message = Message.query.get(message_id)
|
||||
if message:
|
||||
message.ack_status = status
|
||||
if status == 'acknowledged':
|
||||
message.status = 'delivered'
|
||||
message.delivered_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
emit('message.acked', {'message_id': message_id, 'status': status})
|
||||
3
app/utils/__init__.py
Normal file
3
app/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
工具模块
|
||||
"""
|
||||
80
docker-compose.yaml
Normal file
80
docker-compose.yaml
Normal file
@@ -0,0 +1,80 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
pit-router:
|
||||
build: .
|
||||
ports:
|
||||
- "9000:9000"
|
||||
environment:
|
||||
- FLASK_ENV=production
|
||||
- SECRET_KEY=${SECRET_KEY}
|
||||
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
||||
- DATABASE_URL=postgresql://postgres:${DB_PASSWORD}@postgres:5432/pit_router
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
volumes:
|
||||
- pit-data:/app/data
|
||||
- pit-logs:/app/logs
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import requests; requests.get('http://localhost:9000/health')"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 40s
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '2.0'
|
||||
memory: 2G
|
||||
reservations:
|
||||
cpus: '0.5'
|
||||
memory: 512M
|
||||
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
environment:
|
||||
- POSTGRES_USER=postgres
|
||||
- POSTGRES_PASSWORD=${DB_PASSWORD}
|
||||
- POSTGRES_DB=pit_router
|
||||
volumes:
|
||||
- postgres-data:/var/lib/postgresql/data
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres -d pit_router"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
nginx:
|
||||
image: nginx:alpine
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ./nginx.conf:/etc/nginx/nginx.conf:ro
|
||||
- ./ssl:/etc/nginx/ssl:ro
|
||||
depends_on:
|
||||
- pit-router
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
pit-data:
|
||||
pit-logs:
|
||||
postgres-data:
|
||||
redis-data:
|
||||
7
requirements-dev.txt
Normal file
7
requirements-dev.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
pytest==8.0.2
|
||||
pytest-flask==1.3.0
|
||||
pytest-cov==4.1.0
|
||||
black==24.2.0
|
||||
flake8==7.0.0
|
||||
mypy==1.8.0
|
||||
ipython==8.22.1
|
||||
20
requirements.txt
Normal file
20
requirements.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
alembic==1.13.1
|
||||
Flask==3.0.2
|
||||
Flask-CORS==4.0.0
|
||||
Flask-JWT-Extended==4.6.0
|
||||
Flask-Limiter==3.5.0
|
||||
Flask-Login==0.6.3
|
||||
Flask-Migrate==4.0.5
|
||||
Flask-SocketIO==5.3.6
|
||||
Flask-SQLAlchemy==3.1.1
|
||||
gunicorn==21.2.0
|
||||
psycopg2-binary==2.9.9
|
||||
PyJWT==2.8.0
|
||||
python-dotenv==1.0.0
|
||||
redis==5.0.1
|
||||
SQLAlchemy==2.0.27
|
||||
Werkzeug==3.0.1
|
||||
bcrypt==4.1.2
|
||||
marshmallow==3.21.0
|
||||
marshmallow-sqlalchemy==1.0.0
|
||||
python-dateutil==2.9.0
|
||||
Reference in New Issue
Block a user