feat: add validation for minimum length on various fields and update type definitions
This commit is contained in:
@@ -5,7 +5,8 @@ from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, Request, status
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.agent_tokens import verify_agent_token
|
||||
from app.db.session import get_session
|
||||
@@ -20,8 +21,10 @@ class AgentAuthContext:
|
||||
agent: Agent
|
||||
|
||||
|
||||
def _find_agent_for_token(session: Session, token: str) -> Agent | None:
|
||||
agents = list(session.exec(select(Agent).where(col(Agent.agent_token_hash).is_not(None))))
|
||||
async def _find_agent_for_token(session: AsyncSession, token: str) -> Agent | None:
|
||||
agents = list(
|
||||
await session.exec(select(Agent).where(col(Agent.agent_token_hash).is_not(None)))
|
||||
)
|
||||
for agent in agents:
|
||||
if agent.agent_token_hash and verify_agent_token(token, agent.agent_token_hash):
|
||||
return agent
|
||||
@@ -48,11 +51,11 @@ def _resolve_agent_token(
|
||||
return None
|
||||
|
||||
|
||||
def get_agent_auth_context(
|
||||
async def get_agent_auth_context(
|
||||
request: Request,
|
||||
agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
session: Session = Depends(get_session),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AgentAuthContext:
|
||||
resolved = _resolve_agent_token(agent_token, authorization, accept_authorization=True)
|
||||
if not resolved:
|
||||
@@ -63,7 +66,7 @@ def get_agent_auth_context(
|
||||
bool(authorization),
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
agent = _find_agent_for_token(session, resolved)
|
||||
agent = await _find_agent_for_token(session, resolved)
|
||||
if agent is None:
|
||||
logger.warning(
|
||||
"agent auth invalid token path=%s token_prefix=%s",
|
||||
@@ -74,11 +77,11 @@ def get_agent_auth_context(
|
||||
return AgentAuthContext(actor_type="agent", agent=agent)
|
||||
|
||||
|
||||
def get_agent_auth_context_optional(
|
||||
async def get_agent_auth_context_optional(
|
||||
request: Request,
|
||||
agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
session: Session = Depends(get_session),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AgentAuthContext | None:
|
||||
resolved = _resolve_agent_token(
|
||||
agent_token,
|
||||
@@ -94,7 +97,7 @@ def get_agent_auth_context_optional(
|
||||
bool(authorization),
|
||||
)
|
||||
return None
|
||||
agent = _find_agent_for_token(session, resolved)
|
||||
agent = await _find_agent_for_token(session, resolved)
|
||||
if agent is None:
|
||||
logger.warning(
|
||||
"agent auth optional invalid token path=%s token_prefix=%s",
|
||||
|
||||
@@ -9,9 +9,10 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi_clerk_auth import ClerkConfig, ClerkHTTPBearer
|
||||
from fastapi_clerk_auth import HTTPAuthorizationCredentials as ClerkCredentials
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db import crud
|
||||
from app.db.session import get_session
|
||||
from app.models.users import User
|
||||
|
||||
@@ -44,7 +45,9 @@ def _resolve_clerk_auth(
|
||||
request: Request, fallback: ClerkCredentials | None
|
||||
) -> ClerkCredentials | None:
|
||||
auth_data = getattr(request.state, "clerk_auth", None)
|
||||
return auth_data or fallback
|
||||
if isinstance(auth_data, ClerkCredentials):
|
||||
return auth_data
|
||||
return fallback
|
||||
|
||||
|
||||
def _parse_subject(auth_data: ClerkCredentials | None) -> str | None:
|
||||
@@ -57,7 +60,7 @@ def _parse_subject(auth_data: ClerkCredentials | None) -> str | None:
|
||||
async def get_auth_context(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
session: Session = Depends(get_session),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AuthContext:
|
||||
if credentials is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
@@ -79,17 +82,21 @@ async def get_auth_context(
|
||||
if not clerk_user_id:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
user = session.exec(select(User).where(User.clerk_user_id == clerk_user_id)).first()
|
||||
if user is None:
|
||||
claims = auth_data.decoded if auth_data and auth_data.decoded else {}
|
||||
user = User(
|
||||
clerk_user_id=clerk_user_id,
|
||||
email=claims.get("email"),
|
||||
name=claims.get("name"),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
claims: dict[str, object] = {}
|
||||
if auth_data and auth_data.decoded:
|
||||
claims = auth_data.decoded
|
||||
email_obj = claims.get("email")
|
||||
name_obj = claims.get("name")
|
||||
defaults: dict[str, object | None] = {
|
||||
"email": email_obj if isinstance(email_obj, str) else None,
|
||||
"name": name_obj if isinstance(name_obj, str) else None,
|
||||
}
|
||||
user, _created = await crud.get_or_create(
|
||||
session,
|
||||
User,
|
||||
clerk_user_id=clerk_user_id,
|
||||
defaults=defaults,
|
||||
)
|
||||
|
||||
return AuthContext(
|
||||
actor_type="user",
|
||||
@@ -100,7 +107,7 @@ async def get_auth_context(
|
||||
async def get_auth_context_optional(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
session: Session = Depends(get_session),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AuthContext | None:
|
||||
if request.headers.get("X-Agent-Token"):
|
||||
return None
|
||||
@@ -124,17 +131,21 @@ async def get_auth_context_optional(
|
||||
if not clerk_user_id:
|
||||
return None
|
||||
|
||||
user = session.exec(select(User).where(User.clerk_user_id == clerk_user_id)).first()
|
||||
if user is None:
|
||||
claims = auth_data.decoded if auth_data and auth_data.decoded else {}
|
||||
user = User(
|
||||
clerk_user_id=clerk_user_id,
|
||||
email=claims.get("email"),
|
||||
name=claims.get("name"),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
claims: dict[str, object] = {}
|
||||
if auth_data and auth_data.decoded:
|
||||
claims = auth_data.decoded
|
||||
email_obj = claims.get("email")
|
||||
name_obj = claims.get("name")
|
||||
defaults: dict[str, object | None] = {
|
||||
"email": email_obj if isinstance(email_obj, str) else None,
|
||||
"name": name_obj if isinstance(name_obj, str) else None,
|
||||
}
|
||||
user, _created = await crud.get_or_create(
|
||||
session,
|
||||
User,
|
||||
clerk_user_id=clerk_user_id,
|
||||
defaults=defaults,
|
||||
)
|
||||
|
||||
return AuthContext(
|
||||
actor_type="user",
|
||||
|
||||
@@ -20,7 +20,7 @@ def _trace(self: logging.Logger, message: str, *args: Any, **kwargs: Any) -> Non
|
||||
self._log(TRACE_LEVEL, message, args, **kwargs)
|
||||
|
||||
|
||||
logging.Logger.trace = _trace # type: ignore[attr-defined]
|
||||
setattr(logging.Logger, "trace", _trace)
|
||||
|
||||
_STANDARD_LOG_RECORD_ATTRS = {
|
||||
"args",
|
||||
|
||||
11
backend/app/core/time.py
Normal file
11
backend/app/core/time.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Return a naive UTC datetime without using deprecated datetime.utcnow()."""
|
||||
|
||||
# Keep naive UTC values for compatibility with existing DB schema/queries.
|
||||
return datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
Reference in New Issue
Block a user