feat: Phase 3 - 工具层 + 测试 + 数据库迁移
This commit is contained in:
121
app/utils/security.py
Normal file
121
app/utils/security.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
安全工具
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
import bcrypt
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def generate_token(length: int = 32) -> str:
|
||||
"""生成随机 Token"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""密码哈希"""
|
||||
return bcrypt.hashpw(
|
||||
password.encode('utf-8'),
|
||||
bcrypt.gensalt()
|
||||
).decode('utf-8')
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""验证密码"""
|
||||
return bcrypt.checkpw(
|
||||
password.encode('utf-8'),
|
||||
password_hash.encode('utf-8')
|
||||
)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Token 哈希(用于存储)"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def verify_token_hash(token: str, token_hash: str) -> bool:
|
||||
"""验证 Token 哈希"""
|
||||
return hash_token(token) == token_hash
|
||||
|
||||
|
||||
def generate_api_key() -> str:
|
||||
"""生成 API Key"""
|
||||
return f"pit_{secrets.token_urlsafe(32)}"
|
||||
|
||||
|
||||
def secure_compare(val1: str, val2: str) -> bool:
|
||||
"""安全字符串比较(防时序攻击)"""
|
||||
return hmac.compare_digest(val1.encode(), val2.encode())
|
||||
|
||||
|
||||
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
|
||||
"""脱敏处理"""
|
||||
if len(data) <= visible_chars * 2:
|
||||
return '*' * len(data)
|
||||
|
||||
return data[:visible_chars] + '*' * (len(data) - visible_chars * 2) + data[-visible_chars:]
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""简单内存限流器"""
|
||||
|
||||
def __init__(self):
|
||||
self._storage = {}
|
||||
|
||||
def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool:
|
||||
"""检查是否允许请求"""
|
||||
from time import time
|
||||
|
||||
now = time()
|
||||
window_start = now - window
|
||||
|
||||
# 获取该 key 的请求记录
|
||||
requests = self._storage.get(key, [])
|
||||
|
||||
# 清理过期记录
|
||||
requests = [t for t in requests if t > window_start]
|
||||
|
||||
# 检查是否超过限制
|
||||
if len(requests) >= limit:
|
||||
self._storage[key] = requests
|
||||
return False
|
||||
|
||||
# 记录新请求
|
||||
requests.append(now)
|
||||
self._storage[key] = requests
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class IPWhitelist:
|
||||
"""IP 白名单"""
|
||||
|
||||
def __init__(self, allowed_ips: list = None):
|
||||
self.allowed_ips = set(allowed_ips or [])
|
||||
self.allow_all = '*' in self.allowed_ips
|
||||
|
||||
def is_allowed(self, ip: str) -> bool:
|
||||
"""检查 IP 是否允许"""
|
||||
if self.allow_all:
|
||||
return True
|
||||
|
||||
# 支持 CIDR 格式
|
||||
for allowed in self.allowed_ips:
|
||||
if '/' in allowed:
|
||||
if self._ip_in_cidr(ip, allowed):
|
||||
return True
|
||||
elif ip == allowed:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _ip_in_cidr(self, ip: str, cidr: str) -> bool:
|
||||
"""检查 IP 是否在 CIDR 范围内"""
|
||||
try:
|
||||
import ipaddress
|
||||
network = ipaddress.ip_network(cidr, strict=False)
|
||||
address = ipaddress.ip_address(ip)
|
||||
return address in network
|
||||
except:
|
||||
return False
|
||||
Reference in New Issue
Block a user