From bf245ee8cb71977d6fd7c9dadb971d0a135663bc Mon Sep 17 00:00:00 2001 From: "feifei.xu" <307327147@qq.com> Date: Sat, 14 Mar 2026 19:41:36 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20Phase=201=20-=20=E6=A0=B8=E5=BF=83?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 39 ++++++ app/__init__.py | 117 ++++++++++++++++++ app/config.py | 72 +++++++++++ app/extensions.py | 32 +++++ app/models/__init__.py | 251 +++++++++++++++++++++++++++++++++++++++ app/routes/__init__.py | 3 + app/routes/agents.py | 91 ++++++++++++++ app/routes/auth.py | 142 ++++++++++++++++++++++ app/routes/gateways.py | 99 +++++++++++++++ app/routes/messages.py | 91 ++++++++++++++ app/routes/sessions.py | 102 ++++++++++++++++ app/routes/stats.py | 67 +++++++++++ app/services/__init__.py | 3 + app/socketio/__init__.py | 3 + app/socketio/events.py | 54 +++++++++ app/socketio/handlers.py | 182 ++++++++++++++++++++++++++++ app/utils/__init__.py | 3 + docker-compose.yaml | 80 +++++++++++++ requirements-dev.txt | 7 ++ requirements.txt | 20 ++++ run.py | 10 ++ 21 files changed, 1468 insertions(+) create mode 100644 .env.example create mode 100644 app/__init__.py create mode 100644 app/config.py create mode 100644 app/extensions.py create mode 100644 app/models/__init__.py create mode 100644 app/routes/__init__.py create mode 100644 app/routes/agents.py create mode 100644 app/routes/auth.py create mode 100644 app/routes/gateways.py create mode 100644 app/routes/messages.py create mode 100644 app/routes/sessions.py create mode 100644 app/routes/stats.py create mode 100644 app/services/__init__.py create mode 100644 app/socketio/__init__.py create mode 100644 app/socketio/events.py create mode 100644 app/socketio/handlers.py create mode 100644 app/utils/__init__.py create mode 100644 docker-compose.yaml create mode 100644 requirements-dev.txt create mode 100644 requirements.txt create mode 100644 run.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..c4df737 --- /dev/null +++ b/.env.example @@ -0,0 +1,39 @@ +# PIT Router 配置文件 + +# Flask 配置 +FLASK_ENV=development +FLASK_APP=run.py +SECRET_KEY=your-secret-key-here-change-in-production +JWT_SECRET_KEY=your-jwt-secret-key-here-change-in-production + +# 数据库配置 +DATABASE_URL=postgresql://user:password@localhost:5432/pit_router +# 开发环境可使用 SQLite +# DATABASE_URL=sqlite:///app.db + +# Redis 配置 +REDIS_URL=redis://localhost:6379/0 + +# 服务器配置 +HOST=0.0.0.0 +PORT=9000 +DEBUG=True + +# JWT 配置 +JWT_ACCESS_TOKEN_EXPIRES=86400 +JWT_REFRESH_TOKEN_EXPIRES=604800 + +# WebSocket 配置 +SOCKETIO_PING_INTERVAL=25000 +SOCKETIO_PING_TIMEOUT=10000 + +# Agent 调度配置 +SCHEDULER_STRATEGY=weighted_round_robin +SCHEDULER_TIMEOUT=30 + +# 安全配置 +RATE_LIMIT=100 +CORS_ORIGINS=* + +# 日志配置 +LOG_LEVEL=INFO diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..52dd6b3 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,117 @@ +""" +PIT Router Flask 应用工厂 +""" +from flask import Flask +from flask_socketio import SocketIO +import logging + +from app.config import config +from app.extensions import ( + db, migrate, jwt, login_manager, cors, limiter, init_redis +) + +# Socket.IO 实例 +socketio = SocketIO(cors_allowed_origins="*", async_mode='threading') + + +def create_app(config_name='default'): + """创建 Flask 应用实例""" + app = Flask(__name__) + app.config.from_object(config[config_name]) + + # 初始化扩展 + _init_extensions(app) + + # 注册蓝图 + _register_blueprints(app) + + # 注册 Socket.IO 事件 + _register_socketio_events() + + # 配置日志 + _configure_logging(app) + + # 注册错误处理 + _register_error_handlers(app) + + # 健康检查端点 + @app.route('/health') + def health_check(): + return {'status': 'ok', 'service': 'pit-router'} + + return app + + +def _init_extensions(app): + """初始化 Flask 扩展""" + db.init_app(app) + migrate.init_app(app, db) + jwt.init_app(app) + login_manager.init_app(app) + cors.init_app(app) + limiter.init_app(app) + socketio.init_app(app) + + # 初始化 Redis + init_redis(app) + + # 创建数据库表 + with app.app_context(): + db.create_all() + + +def _register_blueprints(app): + """注册蓝图""" + from app.routes.auth import auth_bp + from app.routes.sessions import sessions_bp + from app.routes.agents import agents_bp + from app.routes.gateways import gateways_bp + from app.routes.messages import messages_bp + from app.routes.stats import stats_bp + + app.register_blueprint(auth_bp, url_prefix='/api/auth') + app.register_blueprint(sessions_bp, url_prefix='/api/sessions') + app.register_blueprint(agents_bp, url_prefix='/api/agents') + app.register_blueprint(gateways_bp, url_prefix='/api/gateways') + app.register_blueprint(messages_bp, url_prefix='/api/messages') + app.register_blueprint(stats_bp, url_prefix='/api/stats') + + +def _register_socketio_events(): + """注册 Socket.IO 事件处理器""" + from app.socketio.handlers import register_handlers + register_handlers(socketio) + + +def _configure_logging(app): + """配置日志""" + log_level = app.config.get('LOG_LEVEL', 'INFO') + logging.basicConfig( + level=getattr(logging, log_level), + format='%(asctime)s [%(levelname)s] %(name)s: %(message)s' + ) + + +def _register_error_handlers(app): + """注册错误处理器""" + from flask import jsonify + + @app.errorhandler(400) + def bad_request(error): + return jsonify({'error': 'Bad Request', 'message': str(error)}), 400 + + @app.errorhandler(401) + def unauthorized(error): + return jsonify({'error': 'Unauthorized', 'message': str(error)}), 401 + + @app.errorhandler(403) + def forbidden(error): + return jsonify({'error': 'Forbidden', 'message': str(error)}), 403 + + @app.errorhandler(404) + def not_found(error): + return jsonify({'error': 'Not Found', 'message': str(error)}), 404 + + @app.errorhandler(500) + def internal_error(error): + return jsonify({'error': 'Internal Server Error', 'message': str(error)}), 500 diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..4bcb92f --- /dev/null +++ b/app/config.py @@ -0,0 +1,72 @@ +""" +PIT Router 配置管理 +""" +import os +from datetime import timedelta +from dotenv import load_dotenv + +load_dotenv() + + +class Config: + """基础配置""" + SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production') + JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY', 'dev-jwt-secret-key-change-in-production') + + # 数据库 + SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL', 'sqlite:///pit_router.db') + SQLALCHEMY_TRACK_MODIFICATIONS = False + SQLALCHEMY_ECHO = False + + # Redis + REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/0') + + # JWT + JWT_ACCESS_TOKEN_EXPIRES = timedelta( + seconds=int(os.environ.get('JWT_ACCESS_TOKEN_EXPIRES', 86400)) + ) + JWT_REFRESH_TOKEN_EXPIRES = timedelta( + seconds=int(os.environ.get('JWT_REFRESH_TOKEN_EXPIRES', 604800)) + ) + + # Socket.IO + SOCKETIO_PING_INTERVAL = int(os.environ.get('SOCKETIO_PING_INTERVAL', 25000)) + SOCKETIO_PING_TIMEOUT = int(os.environ.get('SOCKETIO_PING_TIMEOUT', 10000)) + + # Agent 调度 + SCHEDULER_STRATEGY = os.environ.get('SCHEDULER_STRATEGY', 'weighted_round_robin') + SCHEDULER_TIMEOUT = int(os.environ.get('SCHEDULER_TIMEOUT', 30)) + + # 安全 + RATE_LIMIT = int(os.environ.get('RATE_LIMIT', 100)) + CORS_ORIGINS = os.environ.get('CORS_ORIGINS', '*') + + +class DevelopmentConfig(Config): + """开发环境配置""" + DEBUG = True + SQLALCHEMY_ECHO = True + + +class ProductionConfig(Config): + """生产环境配置""" + DEBUG = False + + # 生产环境强制使用 PostgreSQL + if not os.environ.get('DATABASE_URL'): + raise ValueError("DATABASE_URL must be set in production") + + +class TestingConfig(Config): + """测试环境配置""" + TESTING = True + SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:' + WTF_CSRF_ENABLED = False + + +config = { + 'development': DevelopmentConfig, + 'production': ProductionConfig, + 'testing': TestingConfig, + 'default': DevelopmentConfig +} diff --git a/app/extensions.py b/app/extensions.py new file mode 100644 index 0000000..369f5dd --- /dev/null +++ b/app/extensions.py @@ -0,0 +1,32 @@ +""" +Flask 扩展初始化 +""" +from flask_sqlalchemy import SQLAlchemy +from flask_migrate import Migrate +from flask_jwt_extended import JWTManager +from flask_login import LoginManager +from flask_cors import CORS +from flask_limiter import Limiter +from flask_limiter.util import get_remote_address +import redis + +db = SQLAlchemy() +migrate = Migrate() +jwt = JWTManager() +login_manager = LoginManager() +cors = CORS() +limiter = Limiter( + key_func=get_remote_address, + default_limits=["100 per minute"] +) + +# Redis 客户端 +redis_client = None + + +def init_redis(app): + """初始化 Redis 客户端""" + global redis_client + redis_url = app.config.get('REDIS_URL', 'redis://localhost:6379/0') + redis_client = redis.from_url(redis_url) + return redis_client diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000..8dca7a6 --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,251 @@ +""" +PIT Router 数据模型 +""" +from datetime import datetime +from typing import Optional, List +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import String, DateTime, Integer, Text, JSON, ForeignKey, Boolean +from sqlalchemy.orm import Mapped, mapped_column, relationship +import uuid + +db = SQLAlchemy() + + +def generate_uuid() -> str: + """生成 UUID""" + return str(uuid.uuid4()) + + +class User(db.Model): + """用户模型""" + __tablename__ = 'users' + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + username: Mapped[str] = mapped_column(String(80), unique=True, nullable=False) + password_hash: Mapped[str] = mapped_column(String(256), nullable=False) + email: Mapped[str] = mapped_column(String(120), unique=True, nullable=False) + nickname: Mapped[Optional[str]] = mapped_column(String(80), nullable=True) + role: Mapped[str] = mapped_column(String(20), default='user') + status: Mapped[str] = mapped_column(String(20), default='active') + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + + # 关联 + sessions = relationship('Session', back_populates='user', cascade='all, delete-orphan') + + def __repr__(self): + return f'' + + def to_dict(self): + return { + 'id': self.id, + 'username': self.username, + 'email': self.email, + 'nickname': self.nickname, + 'role': self.role, + 'status': self.status, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'last_login_at': self.last_login_at.isoformat() if self.last_login_at else None, + } + + +class Gateway(db.Model): + """Gateway 模型""" + __tablename__ = 'gateways' + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + name: Mapped[str] = mapped_column(String(80), unique=True, nullable=False) + url: Mapped[str] = mapped_column(String(256), nullable=False) + token_hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) + status: Mapped[str] = mapped_column(String(20), default='offline') + agent_count: Mapped[int] = mapped_column(Integer, default=0) + connection_limit: Mapped[int] = mapped_column(Integer, default=10) + heartbeat_interval: Mapped[int] = mapped_column(Integer, default=60) + allowed_ips: Mapped[Optional[str]] = mapped_column(JSON, nullable=True) + last_heartbeat: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + # 关联 + agents = relationship('Agent', back_populates='gateway', cascade='all, delete-orphan') + + def __repr__(self): + return f'' + + def to_dict(self): + return { + 'id': self.id, + 'name': self.name, + 'url': self.url, + 'status': self.status, + 'agent_count': self.agent_count, + 'connection_limit': self.connection_limit, + 'heartbeat_interval': self.heartbeat_interval, + 'last_heartbeat': self.last_heartbeat.isoformat() if self.last_heartbeat else None, + 'created_at': self.created_at.isoformat() if self.created_at else None, + } + + +class Agent(db.Model): + """Agent 模型""" + __tablename__ = 'agents' + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + name: Mapped[str] = mapped_column(String(80), nullable=False) + display_name: Mapped[Optional[str]] = mapped_column(String(80), nullable=True) + gateway_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey('gateways.id'), nullable=True) + socket_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + model: Mapped[Optional[str]] = mapped_column(String(80), nullable=True) + capabilities: Mapped[Optional[str]] = mapped_column(JSON, nullable=True) + status: Mapped[str] = mapped_column(String(20), default='offline') + priority: Mapped[int] = mapped_column(Integer, default=5) + weight: Mapped[int] = mapped_column(Integer, default=10) + connection_limit: Mapped[int] = mapped_column(Integer, default=5) + current_sessions: Mapped[int] = mapped_column(Integer, default=0) + last_heartbeat: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + # 关联 + gateway = relationship('Gateway', back_populates='agents') + sessions = relationship('Session', back_populates='agent') + + def __repr__(self): + return f'' + + def to_dict(self): + return { + 'id': self.id, + 'name': self.name, + 'display_name': self.display_name, + 'gateway_id': self.gateway_id, + 'socket_id': self.socket_id, + 'model': self.model, + 'capabilities': self.capabilities, + 'status': self.status, + 'priority': self.priority, + 'weight': self.weight, + 'connection_limit': self.connection_limit, + 'current_sessions': self.current_sessions, + 'last_heartbeat': self.last_heartbeat.isoformat() if self.last_heartbeat else None, + 'created_at': self.created_at.isoformat() if self.created_at else None, + } + + +class Session(db.Model): + """会话模型""" + __tablename__ = 'sessions' + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + user_id: Mapped[str] = mapped_column(String(36), ForeignKey('users.id'), nullable=False) + primary_agent_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey('agents.id'), nullable=True) + participating_agent_ids: Mapped[Optional[str]] = mapped_column(JSON, nullable=True) + user_socket_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + title: Mapped[Optional[str]] = mapped_column(String(200), nullable=True) + channel_type: Mapped[str] = mapped_column(String(20), default='web') + status: Mapped[str] = mapped_column(String(20), default='active') + message_count: Mapped[int] = mapped_column(Integer, default=0) + unread_count: Mapped[int] = mapped_column(Integer, default=0) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + last_active_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + + # 关联 + user = relationship('User', back_populates='sessions') + agent = relationship('Agent', back_populates='sessions') + messages = relationship('Message', back_populates='session', cascade='all, delete-orphan') + + def __repr__(self): + return f'' + + def to_dict(self): + return { + 'id': self.id, + 'user_id': self.user_id, + 'primary_agent_id': self.primary_agent_id, + 'participating_agent_ids': self.participating_agent_ids, + 'user_socket_id': self.user_socket_id, + 'title': self.title, + 'channel_type': self.channel_type, + 'status': self.status, + 'message_count': self.message_count, + 'unread_count': self.unread_count, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'updated_at': self.updated_at.isoformat() if self.updated_at else None, + 'last_active_at': self.last_active_at.isoformat() if self.last_active_at else None, + } + + +class Message(db.Model): + """消息模型""" + __tablename__ = 'messages' + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + session_id: Mapped[str] = mapped_column(String(36), ForeignKey('sessions.id'), nullable=False) + sender_type: Mapped[str] = mapped_column(String(20), nullable=False) + sender_id: Mapped[str] = mapped_column(String(36), nullable=False) + message_type: Mapped[str] = mapped_column(String(20), default='text') + content: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + content_type: Mapped[str] = mapped_column(String(20), default='markdown') + reply_to: Mapped[Optional[str]] = mapped_column(String(36), nullable=True) + status: Mapped[str] = mapped_column(String(20), default='sent') + ack_status: Mapped[str] = mapped_column(String(20), default='pending') + retry_count: Mapped[int] = mapped_column(Integer, default=0) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + delivered_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + + # 关联 + session = relationship('Session', back_populates='messages') + + def __repr__(self): + return f'' + + def to_dict(self): + return { + 'id': self.id, + 'session_id': self.session_id, + 'sender_type': self.sender_type, + 'sender_id': self.sender_id, + 'message_type': self.message_type, + 'content': self.content, + 'content_type': self.content_type, + 'reply_to': self.reply_to, + 'status': self.status, + 'ack_status': self.ack_status, + 'retry_count': self.retry_count, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'delivered_at': self.delivered_at.isoformat() if self.delivered_at else None, + } + + +class Connection(db.Model): + """连接模型""" + __tablename__ = 'connections' + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + socket_id: Mapped[str] = mapped_column(String(100), unique=True, nullable=False) + connection_type: Mapped[str] = mapped_column(String(20), nullable=False) + entity_id: Mapped[str] = mapped_column(String(36), nullable=False) + entity_type: Mapped[str] = mapped_column(String(20), nullable=False) + ip_address: Mapped[Optional[str]] = mapped_column(String(45), nullable=True) + user_agent: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + status: Mapped[str] = mapped_column(String(20), default='connected') + auth_token: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + connected_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + last_activity: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + disconnected_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + + def __repr__(self): + return f'' + + def to_dict(self): + return { + 'id': self.id, + 'socket_id': self.socket_id, + 'connection_type': self.connection_type, + 'entity_id': self.entity_id, + 'entity_type': self.entity_type, + 'ip_address': self.ip_address, + 'user_agent': self.user_agent, + 'status': self.status, + 'connected_at': self.connected_at.isoformat() if self.connected_at else None, + 'last_activity': self.last_activity.isoformat() if self.last_activity else None, + } diff --git a/app/routes/__init__.py b/app/routes/__init__.py new file mode 100644 index 0000000..c19d652 --- /dev/null +++ b/app/routes/__init__.py @@ -0,0 +1,3 @@ +""" +路由模块 +""" diff --git a/app/routes/agents.py b/app/routes/agents.py new file mode 100644 index 0000000..ad511c1 --- /dev/null +++ b/app/routes/agents.py @@ -0,0 +1,91 @@ +""" +Agent 路由 +""" +from flask import Blueprint, request, jsonify +from flask_jwt_extended import jwt_required +from datetime import datetime +from app.models import db, Agent + +agents_bp = Blueprint('agents', __name__) + + +@agents_bp.route('/', methods=['GET']) +@jwt_required() +def get_agents(): + """获取 Agent 列表""" + agents = Agent.query.all() + return jsonify({'agents': [a.to_dict() for a in agents]}), 200 + + +@agents_bp.route('/', methods=['GET']) +@jwt_required() +def get_agent(agent_id): + """获取 Agent 详情""" + agent = Agent.query.get(agent_id) + if not agent: + return jsonify({'error': 'Agent not found'}), 404 + return jsonify({'agent': agent.to_dict()}), 200 + + +@agents_bp.route('//status', methods=['GET']) +@jwt_required() +def get_agent_status(agent_id): + """获取 Agent 实时状态""" + agent = Agent.query.get(agent_id) + if not agent: + return jsonify({'error': 'Agent not found'}), 404 + + return jsonify({ + 'agent_id': agent.id, + 'status': agent.status, + 'current_sessions': agent.current_sessions, + 'last_heartbeat': agent.last_heartbeat.isoformat() if agent.last_heartbeat else None + }), 200 + + +@agents_bp.route('//config', methods=['PUT']) +@jwt_required() +def update_agent_config(agent_id): + """更新 Agent 配置""" + agent = Agent.query.get(agent_id) + if not agent: + return jsonify({'error': 'Agent not found'}), 404 + + data = request.get_json() + + if 'name' in data: + agent.name = data['name'] + if 'display_name' in data: + agent.display_name = data['display_name'] + if 'capabilities' in data: + agent.capabilities = data['capabilities'] + if 'priority' in data: + agent.priority = data['priority'] + if 'weight' in data: + agent.weight = data['weight'] + + db.session.commit() + + return jsonify({'agent': agent.to_dict()}), 200 + + +@agents_bp.route('//heartbeat', methods=['POST']) +def agent_heartbeat(agent_id): + """Agent 心跳上报""" + agent = Agent.query.get(agent_id) + if not agent: + return jsonify({'error': 'Agent not found'}), 404 + + agent.last_heartbeat = datetime.utcnow() + agent.status = 'online' + db.session.commit() + + return jsonify({'status': 'ok'}), 200 + + +@agents_bp.route('/available', methods=['GET']) +@jwt_required() +def get_available_agents(): + """获取可用 Agent 列表""" + agents = Agent.query.filter_by(status='online').all() + return jsonify({'agents': [a.to_dict() for a in agents]}), 200 diff --git a/app/routes/auth.py b/app/routes/auth.py new file mode 100644 index 0000000..6c1a6c0 --- /dev/null +++ b/app/routes/auth.py @@ -0,0 +1,142 @@ +""" +认证路由 +""" +from flask import Blueprint, request, jsonify +from flask_jwt_extended import ( + create_access_token, create_refresh_token, + jwt_required, get_jwt_identity +) +from datetime import datetime +import bcrypt +from app.models import db, User + +auth_bp = Blueprint('auth', __name__) + + +@auth_bp.route('/register', methods=['POST']) +def register(): + """注册新用户""" + data = request.get_json() + + username = data.get('username') + email = data.get('email') + password = data.get('password') + nickname = data.get('nickname') + + if not all([username, email, password]): + return jsonify({'error': 'Missing required fields'}), 400 + + # 检查用户是否已存在 + if User.query.filter_by(username=username).first(): + return jsonify({'error': 'Username already exists'}), 400 + + if User.query.filter_by(email=email).first(): + return jsonify({'error': 'Email already exists'}), 400 + + # 密码哈希 + password_hash = bcrypt.hashpw( + password.encode('utf-8'), + bcrypt.gensalt() + ).decode('utf-8') + + # 创建用户 + user = User( + username=username, + email=email, + password_hash=password_hash, + nickname=nickname or username + ) + db.session.add(user) + db.session.commit() + + return jsonify({ + 'message': 'User registered successfully', + 'user': user.to_dict() + }), 201 + + +@auth_bp.route('/login', methods=['POST']) +def login(): + """用户登录""" + data = request.get_json() + + username = data.get('username') + password = data.get('password') + + if not all([username, password]): + return jsonify({'error': 'Missing username or password'}), 400 + + # 查找用户 + user = User.query.filter_by(username=username).first() + if not user: + return jsonify({'error': 'Invalid username or password'}), 401 + + # 验证密码 + if not bcrypt.checkpw(password.encode('utf-8'), user.password_hash.encode('utf-8')): + return jsonify({'error': 'Invalid username or password'}), 401 + + # 检查用户状态 + if user.status != 'active': + return jsonify({'error': 'Account is disabled'}), 403 + + # 更新最后登录时间 + user.last_login_at = datetime.utcnow() + db.session.commit() + + # 生成 Token + access_token = create_access_token(identity=user.id) + refresh_token = create_refresh_token(identity=user.id) + + return jsonify({ + 'access_token': access_token, + 'refresh_token': refresh_token, + 'user': user.to_dict() + }), 200 + + +@auth_bp.route('/refresh', methods=['POST']) +@jwt_required(refresh=True) +def refresh(): + """刷新 Token""" + identity = get_jwt_identity() + access_token = create_access_token(identity=identity) + + return jsonify({ + 'access_token': access_token + }), 200 + + +@auth_bp.route('/me', methods=['GET']) +@jwt_required() +def get_current_user(): + """获取当前用户信息""" + user_id = get_jwt_identity() + user = User.query.get(user_id) + + if not user: + return jsonify({'error': 'User not found'}), 404 + + return jsonify({'user': user.to_dict()}), 200 + + +@auth_bp.route('/logout', methods=['POST']) +@jwt_required() +def logout(): + """用户登出""" + return jsonify({'message': 'Logged out successfully'}), 200 + + +@auth_bp.route('/verify', methods=['POST']) +@jwt_required() +def verify_token(): + """验证 Token 有效性""" + user_id = get_jwt_identity() + user = User.query.get(user_id) + + if not user: + return jsonify({'valid': False}), 401 + + return jsonify({ + 'valid': True, + 'user': user.to_dict() + }), 200 diff --git a/app/routes/gateways.py b/app/routes/gateways.py new file mode 100644 index 0000000..478bae0 --- /dev/null +++ b/app/routes/gateways.py @@ -0,0 +1,99 @@ +""" +Gateway 路由 +""" +from flask import Blueprint, request, jsonify +from flask_jwt_extended import jwt_required +from datetime import datetime +import bcrypt +from app.models import db, Gateway + +gateways_bp = Blueprint('gateways', __name__) + + +@gateway_bp.route('/', methods=['GET']) +@jwt_required() +def get_gateways(): + """获取 Gateway 列表""" + gateways = Gateway.query.all() + return jsonify({'gateways': [g.to_dict() for g in gateways]}), 200 + + +@gateway_bp.route('/', methods=['POST']) +@jwt_required() +def register_gateway(): + """注册 Gateway""" + data = request.get_json() + + name = data.get('name') + url = data.get('url') + token = data.get('token') + + if not all([name, url]): + return jsonify({'error': 'Missing required fields'}), 400 + + if Gateway.query.filter_by(name=name).first(): + return jsonify({'error': 'Gateway name already exists'}), 400 + + # Token 哈希 + token_hash = None + if token: + token_hash = bcrypt.hashpw( + token.encode('utf-8'), + bcrypt.gensalt() + ).decode('utf-8') + + gateway = Gateway( + name=name, + url=url, + token_hash=token_hash, + status='offline' + ) + + db.session.add(gateway) + db.session.commit() + + return jsonify({'gateway': gateway.to_dict()}), 201 + + +@gateway_bp.route('/', methods=['DELETE']) +@jwt_required() +def delete_gateway(gateway_id): + """注销 Gateway""" + gateway = Gateway.query.get(gateway_id) + if not gateway: + return jsonify({'error': 'Gateway not found'}), 404 + + db.session.delete(gateway) + db.session.commit() + + return jsonify({'message': 'Gateway deleted'}), 200 + + +@gateway_bp.route('//status', methods=['GET']) +@jwt_required() +def get_gateway_status(gateway_id): + """获取 Gateway 状态""" + gateway = Gateway.query.get(gateway_id) + if not gateway: + return jsonify({'error': 'Gateway not found'}), 404 + + return jsonify({ + 'gateway_id': gateway.id, + 'status': gateway.status, + 'agent_count': gateway.agent_count, + 'last_heartbeat': gateway.last_heartbeat.isoformat() if gateway.last_heartbeat else None + }), 200 + + +@gateway_bp.route('//heartbeat', methods=['POST']) +def gateway_heartbeat(gateway_id): + """Gateway 心跳上报""" + gateway = Gateway.query.get(gateway_id) + if not gateway: + return jsonify({'error': 'Gateway not found'}), 404 + + gateway.last_heartbeat = datetime.utcnow() + gateway.status = 'online' + db.session.commit() + + return jsonify({'status': 'ok'}), 200 diff --git a/app/routes/messages.py b/app/routes/messages.py new file mode 100644 index 0000000..1475d2d --- /dev/null +++ b/app/routes/messages.py @@ -0,0 +1,91 @@ +""" +消息路由 +""" +from flask import Blueprint, request, jsonify +from flask_jwt_extended import jwt_required, get_jwt_identity +from datetime import datetime +from app.models import db, Message, Session + +messages_bp = Blueprint('messages', __name__) + + +@messages_bp.route('/', methods=['POST']) +@jwt_required() +def send_message(): + """发送消息 (HTTP 方式)""" + user_id = get_jwt_identity() + data = request.get_json() + + session_id = data.get('session_id') + content = data.get('content') + message_type = data.get('type', 'text') + reply_to = data.get('reply_to') + + if not all([session_id, content]): + return jsonify({'error': 'Missing required fields'}), 400 + + # 验证会话 + session = Session.query.filter_by(id=session_id, user_id=user_id).first() + if not session: + return jsonify({'error': 'Session not found'}), 404 + + # 创建消息 + message = Message( + session_id=session_id, + sender_type='user', + sender_id=user_id, + message_type=message_type, + content=content, + reply_to=reply_to, + status='sent', + ack_status='pending' + ) + + db.session.add(message) + + # 更新会话 + session.message_count += 1 + session.last_active_at = datetime.utcnow() + + db.session.commit() + + return jsonify({'message': message.to_dict()}), 201 + + +@messages_bp.route('/', methods=['GET']) +@jwt_required() +def get_message(message_id): + """获取单条消息""" + message = Message.query.get(message_id) + if not message: + return jsonify({'error': 'Message not found'}), 404 + + return jsonify({'message': message.to_dict()}), 200 + + +@messages_bp.route('//ack', methods=['PUT']) +def acknowledge_message(message_id): + """确认消息已送达""" + message = Message.query.get(message_id) + if not message: + return jsonify({'error': 'Message not found'}), 404 + + message.ack_status = 'acknowledged' + message.delivered_at = datetime.utcnow() + db.session.commit() + + return jsonify({'message': message.to_dict()}), 200 + + +@messages_bp.route('//read', methods=['PUT']) +@jwt_required() +def mark_message_read(message_id): + """标记消息已读""" + message = Message.query.get(message_id) + if not message: + return jsonify({'error': 'Message not found'}), 404 + + message.status = 'read' + db.session.commit() + + return jsonify({'message': message.to_dict()}), 200 diff --git a/app/routes/sessions.py b/app/routes/sessions.py new file mode 100644 index 0000000..6e45f37 --- /dev/null +++ b/app/routes/sessions.py @@ -0,0 +1,102 @@ +""" +会话路由 +""" +from flask import Blueprint, request, jsonify +from flask_jwt_extended import jwt_required, get_jwt_identity +from datetime import datetime +from app.models import db, Session, Agent + +sessions_bp = Blueprint('sessions', __name__) + + +@sessions_bp.route('/', methods=['GET']) +@jwt_required() +def get_sessions(): + """获取会话列表""" + user_id = get_jwt_identity() + sessions = Session.query.filter_by(user_id=user_id).all() + + return jsonify({ + 'sessions': [s.to_dict() for s in sessions] + }), 200 + + +@sessions_bp.route('/', methods=['POST']) +@jwt_required() +def create_session(): + """创建会话""" + user_id = get_jwt_identity() + data = request.get_json() + + title = data.get('title', 'New Session') + agent_id = data.get('agent_id') + priority = data.get('priority', 5) + + # 如果没有指定 Agent,分配一个 + if not agent_id: + agent = Agent.query.filter_by(status='online').order_by( + Agent.current_sessions.asc() + ).first() + if agent: + agent_id = agent.id + + session = Session( + user_id=user_id, + primary_agent_id=agent_id, + title=title, + status='active' + ) + + db.session.add(session) + db.session.commit() + + return jsonify({ + 'session': session.to_dict() + }), 201 + + +@sessions_bp.route('/', methods=['GET']) +@jwt_required() +def get_session(session_id): + """获取会话详情""" + user_id = get_jwt_identity() + session = Session.query.filter_by(id=session_id, user_id=user_id).first() + + if not session: + return jsonify({'error': 'Session not found'}), 404 + + return jsonify({'session': session.to_dict()}), 200 + + +@sessions_bp.route('//close', methods=['PUT']) +@jwt_required() +def close_session(session_id): + """关闭会话""" + user_id = get_jwt_identity() + session = Session.query.filter_by(id=session_id, user_id=user_id).first() + + if not session: + return jsonify({'error': 'Session not found'}), 404 + + session.status = 'closed' + session.updated_at = datetime.utcnow() + db.session.commit() + + return jsonify({'message': 'Session closed', 'session': session.to_dict()}), 200 + + +@sessions_bp.route('//messages', methods=['GET']) +@jwt_required() +def get_session_messages(session_id): + """获取会话消息""" + user_id = get_jwt_identity() + session = Session.query.filter_by(id=session_id, user_id=user_id).first() + + if not session: + return jsonify({'error': 'Session not found'}), 404 + + messages = session.messages.order_by('created_at').all() + + return jsonify({ + 'messages': [m.to_dict() for m in messages] + }), 200 diff --git a/app/routes/stats.py b/app/routes/stats.py new file mode 100644 index 0000000..387efc8 --- /dev/null +++ b/app/routes/stats.py @@ -0,0 +1,67 @@ +""" +统计路由 +""" +from flask import Blueprint, jsonify +from flask_jwt_extended import jwt_required +from app.models import db, User, Session, Agent, Gateway, Message +from sqlalchemy import func + +stats_bp = Blueprint('stats', __name__) + + +@stats_bp.route('/', methods=['GET']) +@jwt_required() +def get_stats(): + """获取系统统计信息""" + stats = { + 'users': User.query.count(), + 'sessions': Session.query.filter_by(status='active').count(), + 'agents': Agent.query.filter_by(status='online').count(), + 'gateways': Gateway.query.filter_by(status='online').count(), + 'messages': Message.query.count(), + } + + return jsonify({'stats': stats}), 200 + + +@stats_bp.route('/sessions', methods=['GET']) +@jwt_required() +def get_session_stats(): + """获取会话统计""" + stats = { + 'total': Session.query.count(), + 'active': Session.query.filter_by(status='active').count(), + 'closed': Session.query.filter_by(status='closed').count(), + 'paused': Session.query.filter_by(status='paused').count(), + } + + return jsonify({'stats': stats}), 200 + + +@stats_bp.route('/messages', methods=['GET']) +@jwt_required() +def get_message_stats(): + """获取消息统计""" + stats = { + 'total': Message.query.count(), + 'sent': Message.query.filter_by(status='sent').count(), + 'delivered': Message.query.filter_by(status='delivered').count(), + 'read': Message.query.filter_by(status='read').count(), + 'failed': Message.query.filter_by(status='failed').count(), + } + + return jsonify({'stats': stats}), 200 + + +@stats_bp.route('/agents', methods=['GET']) +@jwt_required() +def get_agent_stats(): + """获取 Agent 统计""" + stats = { + 'total': Agent.query.count(), + 'online': Agent.query.filter_by(status='online').count(), + 'offline': Agent.query.filter_by(status='offline').count(), + 'busy': Agent.query.filter_by(status='busy').count(), + } + + return jsonify({'stats': stats}), 200 diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..b60bf1b --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1,3 @@ +""" +服务模块 +""" diff --git a/app/socketio/__init__.py b/app/socketio/__init__.py new file mode 100644 index 0000000..487f7d1 --- /dev/null +++ b/app/socketio/__init__.py @@ -0,0 +1,3 @@ +""" +Socket.IO 模块 +""" diff --git a/app/socketio/events.py b/app/socketio/events.py new file mode 100644 index 0000000..244cb80 --- /dev/null +++ b/app/socketio/events.py @@ -0,0 +1,54 @@ +""" +Socket.IO 事件定义 +""" +# 认证事件 +AUTH_EVENTS = [ + 'auth', # C→S: 认证请求 + 'authenticated', # S→C: 认证成功 + 'auth_error', # S→C: 认证失败 +] + +# 心跳事件 +HEARTBEAT_EVENTS = [ + 'ping', # C→S: 心跳请求 + 'pong', # S→C: 心跳响应 + 'heartbeat_timeout', # S→C: 心跳超时 +] + +# 会话事件 +SESSION_EVENTS = [ + 'session.create', # C→S: 创建会话 + 'session.created', # S→C: 会话已创建 + 'session.join', # C→S: 加入会话 + 'session.joined', # S→C: 已加入会话 + 'session.leave', # C→S: 离开会话 + 'session.left', # S→C: 已离开会话 + 'session.closed', # S→C: 会话被关闭 + 'session.assigned', # S→C: Agent 分配通知 +] + +# 消息事件 +MESSAGE_EVENTS = [ + 'message.send', # C→S: 发送消息 + 'message', # S→C: 收到消息 + 'message.ack', # C→S: 消息确认 + 'message.acked', # S→C: 确认已收到 + 'message.read', # C→S: 消息已读 + 'message.stream', # S→C: 流式消息 + 'typing', # C→S: 正在输入 +] + +# 错误事件 +ERROR_EVENTS = [ + 'error', # S→C: 通用错误 + 'session_error', # S→C: 会话错误 + 'message_error', # S→C: 消息错误 +] + +ALL_EVENTS = ( + AUTH_EVENTS + + HEARTBEAT_EVENTS + + SESSION_EVENTS + + MESSAGE_EVENTS + + ERROR_EVENTS +) diff --git a/app/socketio/handlers.py b/app/socketio/handlers.py new file mode 100644 index 0000000..29db808 --- /dev/null +++ b/app/socketio/handlers.py @@ -0,0 +1,182 @@ +""" +Socket.IO 事件处理 +""" +from flask_socketio import emit, join_room, leave_room +from flask_jwt_extended import decode_token +from datetime import datetime +import json +from app.extensions import redis_client + +# 连接管理器 +class ConnectionManager: + def __init__(self): + self.user_sockets = {} # user_id -> socket_id + self.agent_sockets = {} # agent_id -> socket_id + self.socket_sessions = {} # socket_id -> session_id + +connection_manager = ConnectionManager() + + +def register_handlers(socketio): + """注册 Socket.IO 事件处理器""" + + @socketio.on('connect') + def handle_connect(): + """客户端连接""" + print('Client connected') + # 发送认证请求 + emit('auth', {'message': 'Please authenticate'}) + + @socketio.on('disconnect') + def handle_disconnect(): + """客户端断开连接""" + from flask import request + sid = request.sid + print(f'Client disconnected: {sid}') + + # 清理连接 + if sid in connection_manager.socket_sessions: + del connection_manager.socket_sessions[sid] + + @socketio.on('auth') + def handle_auth(data): + """处理认证""" + from flask import request + sid = request.sid + + token = data.get('token') + if not token: + emit('auth_error', {'code': 'MISSING_TOKEN', 'message': 'Token is required'}) + return + + try: + # 验证 JWT Token + decoded = decode_token(token) + user_id = decoded['sub'] + + # 保存连接信息 + connection_manager.user_sockets[user_id] = sid + + # 存储到 Redis + redis_client.hset(f'socket:{sid}', mapping={ + 'user_id': user_id, + 'connected_at': datetime.utcnow().isoformat() + }) + redis_client.expire(f'socket:{sid}', 86400) + + emit('authenticated', { + 'user_id': user_id, + 'socket_id': sid + }) + except Exception as e: + emit('auth_error', {'code': 'INVALID_TOKEN', 'message': str(e)}) + + @socketio.on('ping') + def handle_ping(data): + """心跳响应""" + emit('pong', {'timestamp': datetime.utcnow().timestamp()}) + + @socketio.on('session.create') + def handle_session_create(data): + """创建会话""" + from app.models import db, Session, Agent + + user_id = data.get('user_id') + title = data.get('title', 'New Session') + agent_id = data.get('agent_id') + + # 如果没有指定 Agent,分配一个 + if not agent_id: + agent = Agent.query.filter_by(status='online').first() + if agent: + agent_id = agent.id + + session = Session( + user_id=user_id, + primary_agent_id=agent_id, + title=title + ) + db.session.add(session) + db.session.commit() + + emit('session.created', { + 'session_id': session.id, + 'agent_id': agent_id + }) + + @socketio.on('session.join') + def handle_session_join(data): + """加入会话""" + from flask import request + from app.models import Session + + sid = request.sid + session_id = data.get('session_id') + + # 加入房间 + join_room(session_id) + connection_manager.socket_sessions[sid] = session_id + + session = Session.query.get(session_id) + if session: + emit('session.joined', { + 'session_id': session_id, + 'participants': [session.user_id, session.primary_agent_id] + }) + + @socketio.on('message.send') + def handle_message_send(data): + """发送消息""" + from flask import request + from app.models import db, Message, Session + + sid = request.sid + session_id = data.get('session_id') + content = data.get('content') + sender_type = data.get('sender_type', 'user') + sender_id = data.get('sender_id') + + # 创建消息 + message = Message( + session_id=session_id, + sender_type=sender_type, + sender_id=sender_id, + content=content, + status='sent' + ) + db.session.add(message) + + # 更新会话 + session = Session.query.get(session_id) + if session: + session.message_count += 1 + session.last_active_at = datetime.utcnow() + + db.session.commit() + + # 广播消息到房间 + emit('message', { + 'message_id': message.id, + 'session_id': session_id, + 'sender': {'type': sender_type, 'id': sender_id}, + 'content': content, + 'timestamp': datetime.utcnow().isoformat() + }, room=session_id) + + @socketio.on('message.ack') + def handle_message_ack(data): + """消息确认""" + from app.models import db, Message + + message_id = data.get('message_id') + status = data.get('status') + + message = Message.query.get(message_id) + if message: + message.ack_status = status + if status == 'acknowledged': + message.status = 'delivered' + message.delivered_at = datetime.utcnow() + db.session.commit() + + emit('message.acked', {'message_id': message_id, 'status': status}) diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..548e274 --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1,3 @@ +""" +工具模块 +""" diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..ca7e44e --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,80 @@ +version: '3.8' + +services: + pit-router: + build: . + ports: + - "9000:9000" + environment: + - FLASK_ENV=production + - SECRET_KEY=${SECRET_KEY} + - JWT_SECRET_KEY=${JWT_SECRET_KEY} + - DATABASE_URL=postgresql://postgres:${DB_PASSWORD}@postgres:5432/pit_router + - REDIS_URL=redis://redis:6379/0 + volumes: + - pit-data:/app/data + - pit-logs:/app/logs + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + restart: unless-stopped + healthcheck: + test: ["CMD", "python", "-c", "import requests; requests.get('http://localhost:9000/health')"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + deploy: + resources: + limits: + cpus: '2.0' + memory: 2G + reservations: + cpus: '0.5' + memory: 512M + + postgres: + image: postgres:15-alpine + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=${DB_PASSWORD} + - POSTGRES_DB=pit_router + volumes: + - postgres-data:/var/lib/postgresql/data + restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres -d pit_router"] + interval: 10s + timeout: 5s + retries: 5 + + redis: + image: redis:7-alpine + volumes: + - redis-data:/data + restart: unless-stopped + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + + nginx: + image: nginx:alpine + ports: + - "80:80" + - "443:443" + volumes: + - ./nginx.conf:/etc/nginx/nginx.conf:ro + - ./ssl:/etc/nginx/ssl:ro + depends_on: + - pit-router + restart: unless-stopped + +volumes: + pit-data: + pit-logs: + postgres-data: + redis-data: diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..6ee2869 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,7 @@ +pytest==8.0.2 +pytest-flask==1.3.0 +pytest-cov==4.1.0 +black==24.2.0 +flake8==7.0.0 +mypy==1.8.0 +ipython==8.22.1 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..458967e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +alembic==1.13.1 +Flask==3.0.2 +Flask-CORS==4.0.0 +Flask-JWT-Extended==4.6.0 +Flask-Limiter==3.5.0 +Flask-Login==0.6.3 +Flask-Migrate==4.0.5 +Flask-SocketIO==5.3.6 +Flask-SQLAlchemy==3.1.1 +gunicorn==21.2.0 +psycopg2-binary==2.9.9 +PyJWT==2.8.0 +python-dotenv==1.0.0 +redis==5.0.1 +SQLAlchemy==2.0.27 +Werkzeug==3.0.1 +bcrypt==4.1.2 +marshmallow==3.21.0 +marshmallow-sqlalchemy==1.0.0 +python-dateutil==2.9.0 diff --git a/run.py b/run.py new file mode 100644 index 0000000..983e47a --- /dev/null +++ b/run.py @@ -0,0 +1,10 @@ +""" +PIT Router 启动入口 +""" +from app import create_app +from app.extensions import socketio + +app = create_app('development') + +if __name__ == '__main__': + socketio.run(app, host='0.0.0.0', port=9000, debug=True)