feat: Phase 1 - 核心功能实现

This commit is contained in:
2026-03-14 19:41:36 +08:00
parent 3e9f632501
commit bf245ee8cb
21 changed files with 1468 additions and 0 deletions

39
.env.example Normal file
View File

@@ -0,0 +1,39 @@
# PIT Router 配置文件
# Flask 配置
FLASK_ENV=development
FLASK_APP=run.py
SECRET_KEY=your-secret-key-here-change-in-production
JWT_SECRET_KEY=your-jwt-secret-key-here-change-in-production
# 数据库配置
DATABASE_URL=postgresql://user:password@localhost:5432/pit_router
# 开发环境可使用 SQLite
# DATABASE_URL=sqlite:///app.db
# Redis 配置
REDIS_URL=redis://localhost:6379/0
# 服务器配置
HOST=0.0.0.0
PORT=9000
DEBUG=True
# JWT 配置
JWT_ACCESS_TOKEN_EXPIRES=86400
JWT_REFRESH_TOKEN_EXPIRES=604800
# WebSocket 配置
SOCKETIO_PING_INTERVAL=25000
SOCKETIO_PING_TIMEOUT=10000
# Agent 调度配置
SCHEDULER_STRATEGY=weighted_round_robin
SCHEDULER_TIMEOUT=30
# 安全配置
RATE_LIMIT=100
CORS_ORIGINS=*
# 日志配置
LOG_LEVEL=INFO

117
app/__init__.py Normal file
View File

@@ -0,0 +1,117 @@
"""
PIT Router Flask 应用工厂
"""
from flask import Flask
from flask_socketio import SocketIO
import logging
from app.config import config
from app.extensions import (
db, migrate, jwt, login_manager, cors, limiter, init_redis
)
# Socket.IO 实例
socketio = SocketIO(cors_allowed_origins="*", async_mode='threading')
def create_app(config_name='default'):
"""创建 Flask 应用实例"""
app = Flask(__name__)
app.config.from_object(config[config_name])
# 初始化扩展
_init_extensions(app)
# 注册蓝图
_register_blueprints(app)
# 注册 Socket.IO 事件
_register_socketio_events()
# 配置日志
_configure_logging(app)
# 注册错误处理
_register_error_handlers(app)
# 健康检查端点
@app.route('/health')
def health_check():
return {'status': 'ok', 'service': 'pit-router'}
return app
def _init_extensions(app):
"""初始化 Flask 扩展"""
db.init_app(app)
migrate.init_app(app, db)
jwt.init_app(app)
login_manager.init_app(app)
cors.init_app(app)
limiter.init_app(app)
socketio.init_app(app)
# 初始化 Redis
init_redis(app)
# 创建数据库表
with app.app_context():
db.create_all()
def _register_blueprints(app):
"""注册蓝图"""
from app.routes.auth import auth_bp
from app.routes.sessions import sessions_bp
from app.routes.agents import agents_bp
from app.routes.gateways import gateways_bp
from app.routes.messages import messages_bp
from app.routes.stats import stats_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(sessions_bp, url_prefix='/api/sessions')
app.register_blueprint(agents_bp, url_prefix='/api/agents')
app.register_blueprint(gateways_bp, url_prefix='/api/gateways')
app.register_blueprint(messages_bp, url_prefix='/api/messages')
app.register_blueprint(stats_bp, url_prefix='/api/stats')
def _register_socketio_events():
"""注册 Socket.IO 事件处理器"""
from app.socketio.handlers import register_handlers
register_handlers(socketio)
def _configure_logging(app):
"""配置日志"""
log_level = app.config.get('LOG_LEVEL', 'INFO')
logging.basicConfig(
level=getattr(logging, log_level),
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s'
)
def _register_error_handlers(app):
"""注册错误处理器"""
from flask import jsonify
@app.errorhandler(400)
def bad_request(error):
return jsonify({'error': 'Bad Request', 'message': str(error)}), 400
@app.errorhandler(401)
def unauthorized(error):
return jsonify({'error': 'Unauthorized', 'message': str(error)}), 401
@app.errorhandler(403)
def forbidden(error):
return jsonify({'error': 'Forbidden', 'message': str(error)}), 403
@app.errorhandler(404)
def not_found(error):
return jsonify({'error': 'Not Found', 'message': str(error)}), 404
@app.errorhandler(500)
def internal_error(error):
return jsonify({'error': 'Internal Server Error', 'message': str(error)}), 500

72
app/config.py Normal file
View File

@@ -0,0 +1,72 @@
"""
PIT Router 配置管理
"""
import os
from datetime import timedelta
from dotenv import load_dotenv
load_dotenv()
class Config:
"""基础配置"""
SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production')
JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY', 'dev-jwt-secret-key-change-in-production')
# 数据库
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL', 'sqlite:///pit_router.db')
SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_ECHO = False
# Redis
REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/0')
# JWT
JWT_ACCESS_TOKEN_EXPIRES = timedelta(
seconds=int(os.environ.get('JWT_ACCESS_TOKEN_EXPIRES', 86400))
)
JWT_REFRESH_TOKEN_EXPIRES = timedelta(
seconds=int(os.environ.get('JWT_REFRESH_TOKEN_EXPIRES', 604800))
)
# Socket.IO
SOCKETIO_PING_INTERVAL = int(os.environ.get('SOCKETIO_PING_INTERVAL', 25000))
SOCKETIO_PING_TIMEOUT = int(os.environ.get('SOCKETIO_PING_TIMEOUT', 10000))
# Agent 调度
SCHEDULER_STRATEGY = os.environ.get('SCHEDULER_STRATEGY', 'weighted_round_robin')
SCHEDULER_TIMEOUT = int(os.environ.get('SCHEDULER_TIMEOUT', 30))
# 安全
RATE_LIMIT = int(os.environ.get('RATE_LIMIT', 100))
CORS_ORIGINS = os.environ.get('CORS_ORIGINS', '*')
class DevelopmentConfig(Config):
"""开发环境配置"""
DEBUG = True
SQLALCHEMY_ECHO = True
class ProductionConfig(Config):
"""生产环境配置"""
DEBUG = False
# 生产环境强制使用 PostgreSQL
if not os.environ.get('DATABASE_URL'):
raise ValueError("DATABASE_URL must be set in production")
class TestingConfig(Config):
"""测试环境配置"""
TESTING = True
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
WTF_CSRF_ENABLED = False
config = {
'development': DevelopmentConfig,
'production': ProductionConfig,
'testing': TestingConfig,
'default': DevelopmentConfig
}

32
app/extensions.py Normal file
View File

@@ -0,0 +1,32 @@
"""
Flask 扩展初始化
"""
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager
from flask_login import LoginManager
from flask_cors import CORS
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
import redis
db = SQLAlchemy()
migrate = Migrate()
jwt = JWTManager()
login_manager = LoginManager()
cors = CORS()
limiter = Limiter(
key_func=get_remote_address,
default_limits=["100 per minute"]
)
# Redis 客户端
redis_client = None
def init_redis(app):
"""初始化 Redis 客户端"""
global redis_client
redis_url = app.config.get('REDIS_URL', 'redis://localhost:6379/0')
redis_client = redis.from_url(redis_url)
return redis_client

251
app/models/__init__.py Normal file
View File

@@ -0,0 +1,251 @@
"""
PIT Router 数据模型
"""
from datetime import datetime
from typing import Optional, List
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import String, DateTime, Integer, Text, JSON, ForeignKey, Boolean
from sqlalchemy.orm import Mapped, mapped_column, relationship
import uuid
db = SQLAlchemy()
def generate_uuid() -> str:
"""生成 UUID"""
return str(uuid.uuid4())
class User(db.Model):
"""用户模型"""
__tablename__ = 'users'
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
username: Mapped[str] = mapped_column(String(80), unique=True, nullable=False)
password_hash: Mapped[str] = mapped_column(String(256), nullable=False)
email: Mapped[str] = mapped_column(String(120), unique=True, nullable=False)
nickname: Mapped[Optional[str]] = mapped_column(String(80), nullable=True)
role: Mapped[str] = mapped_column(String(20), default='user')
status: Mapped[str] = mapped_column(String(20), default='active')
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# 关联
sessions = relationship('Session', back_populates='user', cascade='all, delete-orphan')
def __repr__(self):
return f'<User {self.username}>'
def to_dict(self):
return {
'id': self.id,
'username': self.username,
'email': self.email,
'nickname': self.nickname,
'role': self.role,
'status': self.status,
'created_at': self.created_at.isoformat() if self.created_at else None,
'last_login_at': self.last_login_at.isoformat() if self.last_login_at else None,
}
class Gateway(db.Model):
"""Gateway 模型"""
__tablename__ = 'gateways'
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
name: Mapped[str] = mapped_column(String(80), unique=True, nullable=False)
url: Mapped[str] = mapped_column(String(256), nullable=False)
token_hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
status: Mapped[str] = mapped_column(String(20), default='offline')
agent_count: Mapped[int] = mapped_column(Integer, default=0)
connection_limit: Mapped[int] = mapped_column(Integer, default=10)
heartbeat_interval: Mapped[int] = mapped_column(Integer, default=60)
allowed_ips: Mapped[Optional[str]] = mapped_column(JSON, nullable=True)
last_heartbeat: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
# 关联
agents = relationship('Agent', back_populates='gateway', cascade='all, delete-orphan')
def __repr__(self):
return f'<Gateway {self.name}>'
def to_dict(self):
return {
'id': self.id,
'name': self.name,
'url': self.url,
'status': self.status,
'agent_count': self.agent_count,
'connection_limit': self.connection_limit,
'heartbeat_interval': self.heartbeat_interval,
'last_heartbeat': self.last_heartbeat.isoformat() if self.last_heartbeat else None,
'created_at': self.created_at.isoformat() if self.created_at else None,
}
class Agent(db.Model):
"""Agent 模型"""
__tablename__ = 'agents'
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
name: Mapped[str] = mapped_column(String(80), nullable=False)
display_name: Mapped[Optional[str]] = mapped_column(String(80), nullable=True)
gateway_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey('gateways.id'), nullable=True)
socket_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
model: Mapped[Optional[str]] = mapped_column(String(80), nullable=True)
capabilities: Mapped[Optional[str]] = mapped_column(JSON, nullable=True)
status: Mapped[str] = mapped_column(String(20), default='offline')
priority: Mapped[int] = mapped_column(Integer, default=5)
weight: Mapped[int] = mapped_column(Integer, default=10)
connection_limit: Mapped[int] = mapped_column(Integer, default=5)
current_sessions: Mapped[int] = mapped_column(Integer, default=0)
last_heartbeat: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
# 关联
gateway = relationship('Gateway', back_populates='agents')
sessions = relationship('Session', back_populates='agent')
def __repr__(self):
return f'<Agent {self.name}>'
def to_dict(self):
return {
'id': self.id,
'name': self.name,
'display_name': self.display_name,
'gateway_id': self.gateway_id,
'socket_id': self.socket_id,
'model': self.model,
'capabilities': self.capabilities,
'status': self.status,
'priority': self.priority,
'weight': self.weight,
'connection_limit': self.connection_limit,
'current_sessions': self.current_sessions,
'last_heartbeat': self.last_heartbeat.isoformat() if self.last_heartbeat else None,
'created_at': self.created_at.isoformat() if self.created_at else None,
}
class Session(db.Model):
"""会话模型"""
__tablename__ = 'sessions'
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
user_id: Mapped[str] = mapped_column(String(36), ForeignKey('users.id'), nullable=False)
primary_agent_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey('agents.id'), nullable=True)
participating_agent_ids: Mapped[Optional[str]] = mapped_column(JSON, nullable=True)
user_socket_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
title: Mapped[Optional[str]] = mapped_column(String(200), nullable=True)
channel_type: Mapped[str] = mapped_column(String(20), default='web')
status: Mapped[str] = mapped_column(String(20), default='active')
message_count: Mapped[int] = mapped_column(Integer, default=0)
unread_count: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
last_active_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# 关联
user = relationship('User', back_populates='sessions')
agent = relationship('Agent', back_populates='sessions')
messages = relationship('Message', back_populates='session', cascade='all, delete-orphan')
def __repr__(self):
return f'<Session {self.id}>'
def to_dict(self):
return {
'id': self.id,
'user_id': self.user_id,
'primary_agent_id': self.primary_agent_id,
'participating_agent_ids': self.participating_agent_ids,
'user_socket_id': self.user_socket_id,
'title': self.title,
'channel_type': self.channel_type,
'status': self.status,
'message_count': self.message_count,
'unread_count': self.unread_count,
'created_at': self.created_at.isoformat() if self.created_at else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
'last_active_at': self.last_active_at.isoformat() if self.last_active_at else None,
}
class Message(db.Model):
"""消息模型"""
__tablename__ = 'messages'
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
session_id: Mapped[str] = mapped_column(String(36), ForeignKey('sessions.id'), nullable=False)
sender_type: Mapped[str] = mapped_column(String(20), nullable=False)
sender_id: Mapped[str] = mapped_column(String(36), nullable=False)
message_type: Mapped[str] = mapped_column(String(20), default='text')
content: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
content_type: Mapped[str] = mapped_column(String(20), default='markdown')
reply_to: Mapped[Optional[str]] = mapped_column(String(36), nullable=True)
status: Mapped[str] = mapped_column(String(20), default='sent')
ack_status: Mapped[str] = mapped_column(String(20), default='pending')
retry_count: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
delivered_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# 关联
session = relationship('Session', back_populates='messages')
def __repr__(self):
return f'<Message {self.id}>'
def to_dict(self):
return {
'id': self.id,
'session_id': self.session_id,
'sender_type': self.sender_type,
'sender_id': self.sender_id,
'message_type': self.message_type,
'content': self.content,
'content_type': self.content_type,
'reply_to': self.reply_to,
'status': self.status,
'ack_status': self.ack_status,
'retry_count': self.retry_count,
'created_at': self.created_at.isoformat() if self.created_at else None,
'delivered_at': self.delivered_at.isoformat() if self.delivered_at else None,
}
class Connection(db.Model):
"""连接模型"""
__tablename__ = 'connections'
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
socket_id: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
connection_type: Mapped[str] = mapped_column(String(20), nullable=False)
entity_id: Mapped[str] = mapped_column(String(36), nullable=False)
entity_type: Mapped[str] = mapped_column(String(20), nullable=False)
ip_address: Mapped[Optional[str]] = mapped_column(String(45), nullable=True)
user_agent: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
status: Mapped[str] = mapped_column(String(20), default='connected')
auth_token: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
connected_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
last_activity: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
disconnected_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
def __repr__(self):
return f'<Connection {self.socket_id}>'
def to_dict(self):
return {
'id': self.id,
'socket_id': self.socket_id,
'connection_type': self.connection_type,
'entity_id': self.entity_id,
'entity_type': self.entity_type,
'ip_address': self.ip_address,
'user_agent': self.user_agent,
'status': self.status,
'connected_at': self.connected_at.isoformat() if self.connected_at else None,
'last_activity': self.last_activity.isoformat() if self.last_activity else None,
}

3
app/routes/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
路由模块
"""

91
app/routes/agents.py Normal file
View File

@@ -0,0 +1,91 @@
"""
Agent 路由
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required
from datetime import datetime
from app.models import db, Agent
agents_bp = Blueprint('agents', __name__)
@agents_bp.route('/', methods=['GET'])
@jwt_required()
def get_agents():
"""获取 Agent 列表"""
agents = Agent.query.all()
return jsonify({'agents': [a.to_dict() for a in agents]}), 200
@agents_bp.route('/<agent_id>', methods=['GET'])
@jwt_required()
def get_agent(agent_id):
"""获取 Agent 详情"""
agent = Agent.query.get(agent_id)
if not agent:
return jsonify({'error': 'Agent not found'}), 404
return jsonify({'agent': agent.to_dict()}), 200
@agents_bp.route('/<agent_id>/status', methods=['GET'])
@jwt_required()
def get_agent_status(agent_id):
"""获取 Agent 实时状态"""
agent = Agent.query.get(agent_id)
if not agent:
return jsonify({'error': 'Agent not found'}), 404
return jsonify({
'agent_id': agent.id,
'status': agent.status,
'current_sessions': agent.current_sessions,
'last_heartbeat': agent.last_heartbeat.isoformat() if agent.last_heartbeat else None
}), 200
@agents_bp.route('/<agent_id>/config', methods=['PUT'])
@jwt_required()
def update_agent_config(agent_id):
"""更新 Agent 配置"""
agent = Agent.query.get(agent_id)
if not agent:
return jsonify({'error': 'Agent not found'}), 404
data = request.get_json()
if 'name' in data:
agent.name = data['name']
if 'display_name' in data:
agent.display_name = data['display_name']
if 'capabilities' in data:
agent.capabilities = data['capabilities']
if 'priority' in data:
agent.priority = data['priority']
if 'weight' in data:
agent.weight = data['weight']
db.session.commit()
return jsonify({'agent': agent.to_dict()}), 200
@agents_bp.route('/<agent_id>/heartbeat', methods=['POST'])
def agent_heartbeat(agent_id):
"""Agent 心跳上报"""
agent = Agent.query.get(agent_id)
if not agent:
return jsonify({'error': 'Agent not found'}), 404
agent.last_heartbeat = datetime.utcnow()
agent.status = 'online'
db.session.commit()
return jsonify({'status': 'ok'}), 200
@agents_bp.route('/available', methods=['GET'])
@jwt_required()
def get_available_agents():
"""获取可用 Agent 列表"""
agents = Agent.query.filter_by(status='online').all()
return jsonify({'agents': [a.to_dict() for a in agents]}), 200

142
app/routes/auth.py Normal file
View File

@@ -0,0 +1,142 @@
"""
认证路由
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import (
create_access_token, create_refresh_token,
jwt_required, get_jwt_identity
)
from datetime import datetime
import bcrypt
from app.models import db, User
auth_bp = Blueprint('auth', __name__)
@auth_bp.route('/register', methods=['POST'])
def register():
"""注册新用户"""
data = request.get_json()
username = data.get('username')
email = data.get('email')
password = data.get('password')
nickname = data.get('nickname')
if not all([username, email, password]):
return jsonify({'error': 'Missing required fields'}), 400
# 检查用户是否已存在
if User.query.filter_by(username=username).first():
return jsonify({'error': 'Username already exists'}), 400
if User.query.filter_by(email=email).first():
return jsonify({'error': 'Email already exists'}), 400
# 密码哈希
password_hash = bcrypt.hashpw(
password.encode('utf-8'),
bcrypt.gensalt()
).decode('utf-8')
# 创建用户
user = User(
username=username,
email=email,
password_hash=password_hash,
nickname=nickname or username
)
db.session.add(user)
db.session.commit()
return jsonify({
'message': 'User registered successfully',
'user': user.to_dict()
}), 201
@auth_bp.route('/login', methods=['POST'])
def login():
"""用户登录"""
data = request.get_json()
username = data.get('username')
password = data.get('password')
if not all([username, password]):
return jsonify({'error': 'Missing username or password'}), 400
# 查找用户
user = User.query.filter_by(username=username).first()
if not user:
return jsonify({'error': 'Invalid username or password'}), 401
# 验证密码
if not bcrypt.checkpw(password.encode('utf-8'), user.password_hash.encode('utf-8')):
return jsonify({'error': 'Invalid username or password'}), 401
# 检查用户状态
if user.status != 'active':
return jsonify({'error': 'Account is disabled'}), 403
# 更新最后登录时间
user.last_login_at = datetime.utcnow()
db.session.commit()
# 生成 Token
access_token = create_access_token(identity=user.id)
refresh_token = create_refresh_token(identity=user.id)
return jsonify({
'access_token': access_token,
'refresh_token': refresh_token,
'user': user.to_dict()
}), 200
@auth_bp.route('/refresh', methods=['POST'])
@jwt_required(refresh=True)
def refresh():
"""刷新 Token"""
identity = get_jwt_identity()
access_token = create_access_token(identity=identity)
return jsonify({
'access_token': access_token
}), 200
@auth_bp.route('/me', methods=['GET'])
@jwt_required()
def get_current_user():
"""获取当前用户信息"""
user_id = get_jwt_identity()
user = User.query.get(user_id)
if not user:
return jsonify({'error': 'User not found'}), 404
return jsonify({'user': user.to_dict()}), 200
@auth_bp.route('/logout', methods=['POST'])
@jwt_required()
def logout():
"""用户登出"""
return jsonify({'message': 'Logged out successfully'}), 200
@auth_bp.route('/verify', methods=['POST'])
@jwt_required()
def verify_token():
"""验证 Token 有效性"""
user_id = get_jwt_identity()
user = User.query.get(user_id)
if not user:
return jsonify({'valid': False}), 401
return jsonify({
'valid': True,
'user': user.to_dict()
}), 200

99
app/routes/gateways.py Normal file
View File

@@ -0,0 +1,99 @@
"""
Gateway 路由
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required
from datetime import datetime
import bcrypt
from app.models import db, Gateway
gateways_bp = Blueprint('gateways', __name__)
@gateway_bp.route('/', methods=['GET'])
@jwt_required()
def get_gateways():
"""获取 Gateway 列表"""
gateways = Gateway.query.all()
return jsonify({'gateways': [g.to_dict() for g in gateways]}), 200
@gateway_bp.route('/', methods=['POST'])
@jwt_required()
def register_gateway():
"""注册 Gateway"""
data = request.get_json()
name = data.get('name')
url = data.get('url')
token = data.get('token')
if not all([name, url]):
return jsonify({'error': 'Missing required fields'}), 400
if Gateway.query.filter_by(name=name).first():
return jsonify({'error': 'Gateway name already exists'}), 400
# Token 哈希
token_hash = None
if token:
token_hash = bcrypt.hashpw(
token.encode('utf-8'),
bcrypt.gensalt()
).decode('utf-8')
gateway = Gateway(
name=name,
url=url,
token_hash=token_hash,
status='offline'
)
db.session.add(gateway)
db.session.commit()
return jsonify({'gateway': gateway.to_dict()}), 201
@gateway_bp.route('/<gateway_id>', methods=['DELETE'])
@jwt_required()
def delete_gateway(gateway_id):
"""注销 Gateway"""
gateway = Gateway.query.get(gateway_id)
if not gateway:
return jsonify({'error': 'Gateway not found'}), 404
db.session.delete(gateway)
db.session.commit()
return jsonify({'message': 'Gateway deleted'}), 200
@gateway_bp.route('/<gateway_id>/status', methods=['GET'])
@jwt_required()
def get_gateway_status(gateway_id):
"""获取 Gateway 状态"""
gateway = Gateway.query.get(gateway_id)
if not gateway:
return jsonify({'error': 'Gateway not found'}), 404
return jsonify({
'gateway_id': gateway.id,
'status': gateway.status,
'agent_count': gateway.agent_count,
'last_heartbeat': gateway.last_heartbeat.isoformat() if gateway.last_heartbeat else None
}), 200
@gateway_bp.route('/<gateway_id>/heartbeat', methods=['POST'])
def gateway_heartbeat(gateway_id):
"""Gateway 心跳上报"""
gateway = Gateway.query.get(gateway_id)
if not gateway:
return jsonify({'error': 'Gateway not found'}), 404
gateway.last_heartbeat = datetime.utcnow()
gateway.status = 'online'
db.session.commit()
return jsonify({'status': 'ok'}), 200

91
app/routes/messages.py Normal file
View File

@@ -0,0 +1,91 @@
"""
消息路由
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from datetime import datetime
from app.models import db, Message, Session
messages_bp = Blueprint('messages', __name__)
@messages_bp.route('/', methods=['POST'])
@jwt_required()
def send_message():
"""发送消息 (HTTP 方式)"""
user_id = get_jwt_identity()
data = request.get_json()
session_id = data.get('session_id')
content = data.get('content')
message_type = data.get('type', 'text')
reply_to = data.get('reply_to')
if not all([session_id, content]):
return jsonify({'error': 'Missing required fields'}), 400
# 验证会话
session = Session.query.filter_by(id=session_id, user_id=user_id).first()
if not session:
return jsonify({'error': 'Session not found'}), 404
# 创建消息
message = Message(
session_id=session_id,
sender_type='user',
sender_id=user_id,
message_type=message_type,
content=content,
reply_to=reply_to,
status='sent',
ack_status='pending'
)
db.session.add(message)
# 更新会话
session.message_count += 1
session.last_active_at = datetime.utcnow()
db.session.commit()
return jsonify({'message': message.to_dict()}), 201
@messages_bp.route('/<message_id>', methods=['GET'])
@jwt_required()
def get_message(message_id):
"""获取单条消息"""
message = Message.query.get(message_id)
if not message:
return jsonify({'error': 'Message not found'}), 404
return jsonify({'message': message.to_dict()}), 200
@messages_bp.route('/<message_id>/ack', methods=['PUT'])
def acknowledge_message(message_id):
"""确认消息已送达"""
message = Message.query.get(message_id)
if not message:
return jsonify({'error': 'Message not found'}), 404
message.ack_status = 'acknowledged'
message.delivered_at = datetime.utcnow()
db.session.commit()
return jsonify({'message': message.to_dict()}), 200
@messages_bp.route('/<message_id>/read', methods=['PUT'])
@jwt_required()
def mark_message_read(message_id):
"""标记消息已读"""
message = Message.query.get(message_id)
if not message:
return jsonify({'error': 'Message not found'}), 404
message.status = 'read'
db.session.commit()
return jsonify({'message': message.to_dict()}), 200

102
app/routes/sessions.py Normal file
View File

@@ -0,0 +1,102 @@
"""
会话路由
"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from datetime import datetime
from app.models import db, Session, Agent
sessions_bp = Blueprint('sessions', __name__)
@sessions_bp.route('/', methods=['GET'])
@jwt_required()
def get_sessions():
"""获取会话列表"""
user_id = get_jwt_identity()
sessions = Session.query.filter_by(user_id=user_id).all()
return jsonify({
'sessions': [s.to_dict() for s in sessions]
}), 200
@sessions_bp.route('/', methods=['POST'])
@jwt_required()
def create_session():
"""创建会话"""
user_id = get_jwt_identity()
data = request.get_json()
title = data.get('title', 'New Session')
agent_id = data.get('agent_id')
priority = data.get('priority', 5)
# 如果没有指定 Agent分配一个
if not agent_id:
agent = Agent.query.filter_by(status='online').order_by(
Agent.current_sessions.asc()
).first()
if agent:
agent_id = agent.id
session = Session(
user_id=user_id,
primary_agent_id=agent_id,
title=title,
status='active'
)
db.session.add(session)
db.session.commit()
return jsonify({
'session': session.to_dict()
}), 201
@sessions_bp.route('/<session_id>', methods=['GET'])
@jwt_required()
def get_session(session_id):
"""获取会话详情"""
user_id = get_jwt_identity()
session = Session.query.filter_by(id=session_id, user_id=user_id).first()
if not session:
return jsonify({'error': 'Session not found'}), 404
return jsonify({'session': session.to_dict()}), 200
@sessions_bp.route('/<session_id>/close', methods=['PUT'])
@jwt_required()
def close_session(session_id):
"""关闭会话"""
user_id = get_jwt_identity()
session = Session.query.filter_by(id=session_id, user_id=user_id).first()
if not session:
return jsonify({'error': 'Session not found'}), 404
session.status = 'closed'
session.updated_at = datetime.utcnow()
db.session.commit()
return jsonify({'message': 'Session closed', 'session': session.to_dict()}), 200
@sessions_bp.route('/<session_id>/messages', methods=['GET'])
@jwt_required()
def get_session_messages(session_id):
"""获取会话消息"""
user_id = get_jwt_identity()
session = Session.query.filter_by(id=session_id, user_id=user_id).first()
if not session:
return jsonify({'error': 'Session not found'}), 404
messages = session.messages.order_by('created_at').all()
return jsonify({
'messages': [m.to_dict() for m in messages]
}), 200

67
app/routes/stats.py Normal file
View File

@@ -0,0 +1,67 @@
"""
统计路由
"""
from flask import Blueprint, jsonify
from flask_jwt_extended import jwt_required
from app.models import db, User, Session, Agent, Gateway, Message
from sqlalchemy import func
stats_bp = Blueprint('stats', __name__)
@stats_bp.route('/', methods=['GET'])
@jwt_required()
def get_stats():
"""获取系统统计信息"""
stats = {
'users': User.query.count(),
'sessions': Session.query.filter_by(status='active').count(),
'agents': Agent.query.filter_by(status='online').count(),
'gateways': Gateway.query.filter_by(status='online').count(),
'messages': Message.query.count(),
}
return jsonify({'stats': stats}), 200
@stats_bp.route('/sessions', methods=['GET'])
@jwt_required()
def get_session_stats():
"""获取会话统计"""
stats = {
'total': Session.query.count(),
'active': Session.query.filter_by(status='active').count(),
'closed': Session.query.filter_by(status='closed').count(),
'paused': Session.query.filter_by(status='paused').count(),
}
return jsonify({'stats': stats}), 200
@stats_bp.route('/messages', methods=['GET'])
@jwt_required()
def get_message_stats():
"""获取消息统计"""
stats = {
'total': Message.query.count(),
'sent': Message.query.filter_by(status='sent').count(),
'delivered': Message.query.filter_by(status='delivered').count(),
'read': Message.query.filter_by(status='read').count(),
'failed': Message.query.filter_by(status='failed').count(),
}
return jsonify({'stats': stats}), 200
@stats_bp.route('/agents', methods=['GET'])
@jwt_required()
def get_agent_stats():
"""获取 Agent 统计"""
stats = {
'total': Agent.query.count(),
'online': Agent.query.filter_by(status='online').count(),
'offline': Agent.query.filter_by(status='offline').count(),
'busy': Agent.query.filter_by(status='busy').count(),
}
return jsonify({'stats': stats}), 200

3
app/services/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
服务模块
"""

3
app/socketio/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
Socket.IO 模块
"""

54
app/socketio/events.py Normal file
View File

@@ -0,0 +1,54 @@
"""
Socket.IO 事件定义
"""
# 认证事件
AUTH_EVENTS = [
'auth', # C→S: 认证请求
'authenticated', # S→C: 认证成功
'auth_error', # S→C: 认证失败
]
# 心跳事件
HEARTBEAT_EVENTS = [
'ping', # C→S: 心跳请求
'pong', # S→C: 心跳响应
'heartbeat_timeout', # S→C: 心跳超时
]
# 会话事件
SESSION_EVENTS = [
'session.create', # C→S: 创建会话
'session.created', # S→C: 会话已创建
'session.join', # C→S: 加入会话
'session.joined', # S→C: 已加入会话
'session.leave', # C→S: 离开会话
'session.left', # S→C: 已离开会话
'session.closed', # S→C: 会话被关闭
'session.assigned', # S→C: Agent 分配通知
]
# 消息事件
MESSAGE_EVENTS = [
'message.send', # C→S: 发送消息
'message', # S→C: 收到消息
'message.ack', # C→S: 消息确认
'message.acked', # S→C: 确认已收到
'message.read', # C→S: 消息已读
'message.stream', # S→C: 流式消息
'typing', # C→S: 正在输入
]
# 错误事件
ERROR_EVENTS = [
'error', # S→C: 通用错误
'session_error', # S→C: 会话错误
'message_error', # S→C: 消息错误
]
ALL_EVENTS = (
AUTH_EVENTS +
HEARTBEAT_EVENTS +
SESSION_EVENTS +
MESSAGE_EVENTS +
ERROR_EVENTS
)

182
app/socketio/handlers.py Normal file
View File

@@ -0,0 +1,182 @@
"""
Socket.IO 事件处理
"""
from flask_socketio import emit, join_room, leave_room
from flask_jwt_extended import decode_token
from datetime import datetime
import json
from app.extensions import redis_client
# 连接管理器
class ConnectionManager:
def __init__(self):
self.user_sockets = {} # user_id -> socket_id
self.agent_sockets = {} # agent_id -> socket_id
self.socket_sessions = {} # socket_id -> session_id
connection_manager = ConnectionManager()
def register_handlers(socketio):
"""注册 Socket.IO 事件处理器"""
@socketio.on('connect')
def handle_connect():
"""客户端连接"""
print('Client connected')
# 发送认证请求
emit('auth', {'message': 'Please authenticate'})
@socketio.on('disconnect')
def handle_disconnect():
"""客户端断开连接"""
from flask import request
sid = request.sid
print(f'Client disconnected: {sid}')
# 清理连接
if sid in connection_manager.socket_sessions:
del connection_manager.socket_sessions[sid]
@socketio.on('auth')
def handle_auth(data):
"""处理认证"""
from flask import request
sid = request.sid
token = data.get('token')
if not token:
emit('auth_error', {'code': 'MISSING_TOKEN', 'message': 'Token is required'})
return
try:
# 验证 JWT Token
decoded = decode_token(token)
user_id = decoded['sub']
# 保存连接信息
connection_manager.user_sockets[user_id] = sid
# 存储到 Redis
redis_client.hset(f'socket:{sid}', mapping={
'user_id': user_id,
'connected_at': datetime.utcnow().isoformat()
})
redis_client.expire(f'socket:{sid}', 86400)
emit('authenticated', {
'user_id': user_id,
'socket_id': sid
})
except Exception as e:
emit('auth_error', {'code': 'INVALID_TOKEN', 'message': str(e)})
@socketio.on('ping')
def handle_ping(data):
"""心跳响应"""
emit('pong', {'timestamp': datetime.utcnow().timestamp()})
@socketio.on('session.create')
def handle_session_create(data):
"""创建会话"""
from app.models import db, Session, Agent
user_id = data.get('user_id')
title = data.get('title', 'New Session')
agent_id = data.get('agent_id')
# 如果没有指定 Agent分配一个
if not agent_id:
agent = Agent.query.filter_by(status='online').first()
if agent:
agent_id = agent.id
session = Session(
user_id=user_id,
primary_agent_id=agent_id,
title=title
)
db.session.add(session)
db.session.commit()
emit('session.created', {
'session_id': session.id,
'agent_id': agent_id
})
@socketio.on('session.join')
def handle_session_join(data):
"""加入会话"""
from flask import request
from app.models import Session
sid = request.sid
session_id = data.get('session_id')
# 加入房间
join_room(session_id)
connection_manager.socket_sessions[sid] = session_id
session = Session.query.get(session_id)
if session:
emit('session.joined', {
'session_id': session_id,
'participants': [session.user_id, session.primary_agent_id]
})
@socketio.on('message.send')
def handle_message_send(data):
"""发送消息"""
from flask import request
from app.models import db, Message, Session
sid = request.sid
session_id = data.get('session_id')
content = data.get('content')
sender_type = data.get('sender_type', 'user')
sender_id = data.get('sender_id')
# 创建消息
message = Message(
session_id=session_id,
sender_type=sender_type,
sender_id=sender_id,
content=content,
status='sent'
)
db.session.add(message)
# 更新会话
session = Session.query.get(session_id)
if session:
session.message_count += 1
session.last_active_at = datetime.utcnow()
db.session.commit()
# 广播消息到房间
emit('message', {
'message_id': message.id,
'session_id': session_id,
'sender': {'type': sender_type, 'id': sender_id},
'content': content,
'timestamp': datetime.utcnow().isoformat()
}, room=session_id)
@socketio.on('message.ack')
def handle_message_ack(data):
"""消息确认"""
from app.models import db, Message
message_id = data.get('message_id')
status = data.get('status')
message = Message.query.get(message_id)
if message:
message.ack_status = status
if status == 'acknowledged':
message.status = 'delivered'
message.delivered_at = datetime.utcnow()
db.session.commit()
emit('message.acked', {'message_id': message_id, 'status': status})

3
app/utils/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
工具模块
"""

80
docker-compose.yaml Normal file
View File

@@ -0,0 +1,80 @@
version: '3.8'
services:
pit-router:
build: .
ports:
- "9000:9000"
environment:
- FLASK_ENV=production
- SECRET_KEY=${SECRET_KEY}
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
- DATABASE_URL=postgresql://postgres:${DB_PASSWORD}@postgres:5432/pit_router
- REDIS_URL=redis://redis:6379/0
volumes:
- pit-data:/app/data
- pit-logs:/app/logs
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
restart: unless-stopped
healthcheck:
test: ["CMD", "python", "-c", "import requests; requests.get('http://localhost:9000/health')"]
interval: 30s
timeout: 10s
retries: 3
start_period: 40s
deploy:
resources:
limits:
cpus: '2.0'
memory: 2G
reservations:
cpus: '0.5'
memory: 512M
postgres:
image: postgres:15-alpine
environment:
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=${DB_PASSWORD}
- POSTGRES_DB=pit_router
volumes:
- postgres-data:/var/lib/postgresql/data
restart: unless-stopped
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres -d pit_router"]
interval: 10s
timeout: 5s
retries: 5
redis:
image: redis:7-alpine
volumes:
- redis-data:/data
restart: unless-stopped
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
nginx:
image: nginx:alpine
ports:
- "80:80"
- "443:443"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf:ro
- ./ssl:/etc/nginx/ssl:ro
depends_on:
- pit-router
restart: unless-stopped
volumes:
pit-data:
pit-logs:
postgres-data:
redis-data:

7
requirements-dev.txt Normal file
View File

@@ -0,0 +1,7 @@
pytest==8.0.2
pytest-flask==1.3.0
pytest-cov==4.1.0
black==24.2.0
flake8==7.0.0
mypy==1.8.0
ipython==8.22.1

20
requirements.txt Normal file
View File

@@ -0,0 +1,20 @@
alembic==1.13.1
Flask==3.0.2
Flask-CORS==4.0.0
Flask-JWT-Extended==4.6.0
Flask-Limiter==3.5.0
Flask-Login==0.6.3
Flask-Migrate==4.0.5
Flask-SocketIO==5.3.6
Flask-SQLAlchemy==3.1.1
gunicorn==21.2.0
psycopg2-binary==2.9.9
PyJWT==2.8.0
python-dotenv==1.0.0
redis==5.0.1
SQLAlchemy==2.0.27
Werkzeug==3.0.1
bcrypt==4.1.2
marshmallow==3.21.0
marshmallow-sqlalchemy==1.0.0
python-dateutil==2.9.0

10
run.py Normal file
View File

@@ -0,0 +1,10 @@
"""
PIT Router 启动入口
"""
from app import create_app
from app.extensions import socketio
app = create_app('development')
if __name__ == '__main__':
socketio.run(app, host='0.0.0.0', port=9000, debug=True)