refactor: improve type coercion functions and enhance type hints across multiple files

This commit is contained in:
Abhimanyu Saharan
2026-02-09 17:43:42 +05:30
parent f5d592f61a
commit dddd1e9a7a
13 changed files with 217 additions and 64 deletions

View File

@@ -5,9 +5,8 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from collections import deque from collections import deque
from collections.abc import Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
@@ -35,7 +34,7 @@ from app.services.organizations import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator, Sequence
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -43,6 +42,7 @@ router = APIRouter(prefix="/activity", tags=["activity"])
SSE_SEEN_MAX = 2000 SSE_SEEN_MAX = 2000
STREAM_POLL_SECONDS = 2 STREAM_POLL_SECONDS = 2
TASK_COMMENT_ROW_LEN = 4
SESSION_DEP = Depends(get_session) SESSION_DEP = Depends(get_session)
ACTOR_DEP = Depends(require_admin_or_agent) ACTOR_DEP = Depends(require_admin_or_agent)
ORG_MEMBER_DEP = Depends(require_org_member) ORG_MEMBER_DEP = Depends(require_org_member)
@@ -100,6 +100,26 @@ def _feed_item(
) )
def _coerce_task_comment_rows(
items: Sequence[Any],
) -> list[tuple[ActivityEvent, Task, Board, Agent | None]]:
rows: list[tuple[ActivityEvent, Task, Board, Agent | None]] = []
for item in items:
if (
isinstance(item, tuple)
and len(item) == TASK_COMMENT_ROW_LEN
and isinstance(item[0], ActivityEvent)
and isinstance(item[1], Task)
and isinstance(item[2], Board)
and (isinstance(item[3], Agent) or item[3] is None)
):
rows.append((item[0], item[1], item[2], item[3]))
continue
msg = "Expected (ActivityEvent, Task, Board, Agent | None) rows"
raise TypeError(msg)
return rows
async def _fetch_task_comment_events( async def _fetch_task_comment_events(
session: AsyncSession, session: AsyncSession,
since: datetime, since: datetime,
@@ -118,10 +138,7 @@ async def _fetch_task_comment_events(
) )
if board_id is not None: if board_id is not None:
statement = statement.where(col(Task.board_id) == board_id) statement = statement.where(col(Task.board_id) == board_id)
return cast( return _coerce_task_comment_rows(list(await session.exec(statement)))
Sequence[tuple[ActivityEvent, Task, Board, Agent | None]],
list(await session.exec(statement)),
)
@router.get("", response_model=DefaultLimitOffsetPage[ActivityEventRead]) @router.get("", response_model=DefaultLimitOffsetPage[ActivityEventRead])
@@ -179,7 +196,7 @@ async def list_task_comment_feed(
statement = statement.where(col(Task.id).is_(None)) statement = statement.where(col(Task.id).is_(None))
def _transform(items: Sequence[Any]) -> Sequence[Any]: def _transform(items: Sequence[Any]) -> Sequence[Any]:
rows = cast(Sequence[tuple[ActivityEvent, Task, Board, Agent | None]], items) rows = _coerce_task_comment_rows(items)
return [ return [
_feed_item(event, task, board, agent) _feed_item(event, task, board, agent)
for event, task, board, agent in rows for event, task, board, agent in rows

View File

@@ -3,9 +3,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from collections.abc import Sequence from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import SQLModel, col, select from sqlmodel import SQLModel, col, select
@@ -75,6 +73,9 @@ from app.services.task_dependencies import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
@@ -97,6 +98,16 @@ IS_CHAT_QUERY = Query(default=None)
APPROVAL_STATUS_QUERY = Query(default=None, alias="status") APPROVAL_STATUS_QUERY = Query(default=None, alias="status")
def _coerce_agent_items(items: Sequence[Any]) -> list[Agent]:
agents: list[Agent] = []
for item in items:
if not isinstance(item, Agent):
msg = "Expected Agent items from paginated query"
raise TypeError(msg)
agents.append(item)
return agents
def _gateway_agent_id(agent: Agent) -> str: def _gateway_agent_id(agent: Agent) -> str:
session_key = agent.openclaw_session_id or "" session_key = agent.openclaw_session_id or ""
if session_key.startswith(_AGENT_SESSION_PREFIX): if session_key.startswith(_AGENT_SESSION_PREFIX):
@@ -248,7 +259,7 @@ async def list_agents(
statement = statement.order_by(col(Agent.created_at).desc()) statement = statement.order_by(col(Agent.created_at).desc())
def _transform(items: Sequence[Any]) -> Sequence[Any]: def _transform(items: Sequence[Any]) -> Sequence[Any]:
agents = cast(Sequence[Agent], items) agents = _coerce_agent_items(items)
return [ return [
agents_api.to_agent_read( agents_api.to_agent_read(
agents_api.with_computed_status(agent), agents_api.with_computed_status(agent),
@@ -275,7 +286,7 @@ async def list_tasks(
unassigned=filters.unassigned, unassigned=filters.unassigned,
board=board, board=board,
session=session, session=session,
actor=_actor(agent_ctx), _actor=_actor(agent_ctx),
) )
@@ -290,8 +301,8 @@ async def create_task(
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead: if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
data = payload.model_dump() data = payload.model_dump(exclude={"depends_on_task_ids"})
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or []) depends_on_task_ids = list(payload.depends_on_task_ids)
task = Task.model_validate(data) task = Task.model_validate(data)
task.board_id = board.id task.board_id = board.id
@@ -456,7 +467,7 @@ async def list_board_memory(
is_chat=is_chat, is_chat=is_chat,
board=board, board=board,
session=session, session=session,
actor=_actor(agent_ctx), _actor=_actor(agent_ctx),
) )
@@ -493,7 +504,7 @@ async def list_approvals(
status_filter=status_filter, status_filter=status_filter,
board=board, board=board,
session=session, session=session,
actor=_actor(agent_ctx), _actor=_actor(agent_ctx),
) )
@@ -510,7 +521,7 @@ async def create_approval(
payload=payload, payload=payload,
board=board, board=board,
session=session, session=session,
actor=_actor(agent_ctx), _actor=_actor(agent_ctx),
) )

View File

@@ -5,10 +5,9 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import re import re
from collections.abc import AsyncIterator, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
@@ -64,8 +63,11 @@ from app.services.organizations import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator, Sequence
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
from app.models.users import User from app.models.users import User
@@ -232,6 +234,16 @@ def to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
return _to_agent_read(agent, main_session_keys) return _to_agent_read(agent, main_session_keys)
def _coerce_agent_items(items: Sequence[Any]) -> list[Agent]:
agents: list[Agent] = []
for item in items:
if not isinstance(item, Agent):
msg = "Expected Agent items from paginated query"
raise TypeError(msg)
agents.append(item)
return agents
async def _find_gateway_for_main_session( async def _find_gateway_for_main_session(
session: AsyncSession, session_key: str | None, session: AsyncSession, session_key: str | None,
) -> Gateway | None: ) -> Gateway | None:
@@ -777,7 +789,7 @@ async def _provision_updated_agent(
) from exc ) from exc
def _heartbeat_lookup_statement(payload: AgentHeartbeatCreate) -> object: def _heartbeat_lookup_statement(payload: AgentHeartbeatCreate) -> SelectOfScalar[Agent]:
statement = Agent.objects.filter_by(name=payload.name).statement statement = Agent.objects.filter_by(name=payload.name).statement
if payload.board_id is not None: if payload.board_id is not None:
statement = statement.where(Agent.board_id == payload.board_id) statement = statement.where(Agent.board_id == payload.board_id)
@@ -973,7 +985,7 @@ async def list_agents(
statement = statement.order_by(col(Agent.created_at).desc()) statement = statement.order_by(col(Agent.created_at).desc())
def _transform(items: Sequence[Any]) -> Sequence[Any]: def _transform(items: Sequence[Any]) -> Sequence[Any]:
agents = cast(Sequence[Agent], items) agents = _coerce_agent_items(items)
return [ return [
_to_agent_read(_with_computed_status(agent), main_session_keys) _to_agent_read(_with_computed_status(agent), main_session_keys)
for agent in agents for agent in agents

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
@@ -235,9 +235,7 @@ def _update_agent_heartbeat(
) -> None: ) -> None:
raw = agent.heartbeat_config raw = agent.heartbeat_config
heartbeat: dict[str, Any] = ( heartbeat: dict[str, Any] = (
cast(dict[str, Any], dict(raw)) dict(raw) if isinstance(raw, dict) else DEFAULT_HEARTBEAT_CONFIG.copy()
if isinstance(raw, dict)
else cast(dict[str, Any], DEFAULT_HEARTBEAT_CONFIG.copy())
) )
heartbeat["every"] = payload.every heartbeat["every"] = payload.every
if payload.target is not None: if payload.target is not None:

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
@@ -56,6 +57,20 @@ def _query_to_resolve_input(params: GatewayResolveQuery) -> GatewayResolveQuery:
RESOLVE_INPUT_DEP = Depends(_query_to_resolve_input) RESOLVE_INPUT_DEP = Depends(_query_to_resolve_input)
def _as_object_list(value: object) -> list[object]:
if value is None:
return []
if isinstance(value, list):
return value
if isinstance(value, (tuple, set)):
return list(value)
if isinstance(value, (str, bytes, dict)):
return []
if isinstance(value, Iterable):
return list(value)
return []
async def _resolve_gateway( async def _resolve_gateway(
session: AsyncSession, session: AsyncSession,
params: GatewayResolveQuery, params: GatewayResolveQuery,
@@ -138,9 +153,9 @@ async def gateways_status(
try: try:
sessions = await openclaw_call("sessions.list", config=config) sessions = await openclaw_call("sessions.list", config=config)
if isinstance(sessions, dict): if isinstance(sessions, dict):
sessions_list = list(sessions.get("sessions") or []) sessions_list = _as_object_list(sessions.get("sessions"))
else: else:
sessions_list = list(sessions or []) sessions_list = _as_object_list(sessions)
main_session_entry: object | None = None main_session_entry: object | None = None
main_session_error: str | None = None main_session_error: str | None = None
if main_session: if main_session:
@@ -190,9 +205,9 @@ async def list_gateway_sessions(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc ) from exc
if isinstance(sessions, dict): if isinstance(sessions, dict):
sessions_list = list(sessions.get("sessions") or []) sessions_list = _as_object_list(sessions.get("sessions"))
else: else:
sessions_list = list(sessions or []) sessions_list = _as_object_list(sessions)
main_session_entry: object | None = None main_session_entry: object | None = None
if main_session: if main_session:
@@ -215,9 +230,9 @@ async def list_gateway_sessions(
async def _list_sessions(config: GatewayClientConfig) -> list[dict[str, object]]: async def _list_sessions(config: GatewayClientConfig) -> list[dict[str, object]]:
sessions = await openclaw_call("sessions.list", config=config) sessions = await openclaw_call("sessions.list", config=config)
if isinstance(sessions, dict): if isinstance(sessions, dict):
raw_items = sessions.get("sessions") or [] raw_items = _as_object_list(sessions.get("sessions"))
else: else:
raw_items = sessions or [] raw_items = _as_object_list(sessions)
return [ return [
item item
for item in raw_items for item in raw_items
@@ -311,7 +326,7 @@ async def get_session_history(
) from exc ) from exc
if isinstance(history, dict) and isinstance(history.get("messages"), list): if isinstance(history, dict) and isinstance(history.get("messages"), list):
return GatewaySessionHistoryResponse(history=history["messages"]) return GatewaySessionHistoryResponse(history=history["messages"])
return GatewaySessionHistoryResponse(history=list(history or [])) return GatewaySessionHistoryResponse(history=_as_object_list(history))
@router.post("/sessions/{session_id}/message", response_model=OkResponse) @router.post("/sessions/{session_id}/message", response_model=OkResponse)

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from collections import deque from collections import deque
from collections.abc import AsyncIterator, Sequence
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -15,7 +14,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, desc, or_ from sqlalchemy import asc, desc, or_
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.sql.expression import Select
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from app.api.deps import ( from app.api.deps import (
@@ -67,7 +65,10 @@ from app.services.task_dependencies import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator, Sequence
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select, SelectOfScalar
from app.core.auth import AuthContext from app.core.auth import AuthContext
from app.models.users import User from app.models.users import User
@@ -157,6 +158,16 @@ def _parse_since(value: str | None) -> datetime | None:
return parsed return parsed
def _coerce_task_items(items: Sequence[object]) -> list[Task]:
tasks: list[Task] = []
for item in items:
if not isinstance(item, Task):
msg = "Expected Task items from paginated query"
raise TypeError(msg)
tasks.append(item)
return tasks
async def _lead_was_mentioned( async def _lead_was_mentioned(
session: AsyncSession, session: AsyncSession,
task: Task, task: Task,
@@ -266,7 +277,7 @@ async def _fetch_task_events(
if not task_ids: if not task_ids:
return [] return []
statement = cast( statement = cast(
Select[tuple[ActivityEvent, Task | None]], "Select[tuple[ActivityEvent, Task | None]]",
select(ActivityEvent, Task) select(ActivityEvent, Task)
.outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id)) .outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id))
.where(col(ActivityEvent.task_id).in_(task_ids)) .where(col(ActivityEvent.task_id).in_(task_ids))
@@ -512,7 +523,7 @@ def _task_list_statement(
status_filter: str | None, status_filter: str | None,
assigned_agent_id: UUID | None, assigned_agent_id: UUID | None,
unassigned: bool | None, unassigned: bool | None,
) -> object: ) -> SelectOfScalar[Task]:
statement = select(Task).where(Task.board_id == board_id) statement = select(Task).where(Task.board_id == board_id)
statuses = _status_values(status_filter) statuses = _status_values(status_filter)
if statuses: if statuses:
@@ -717,7 +728,7 @@ async def list_tasks(
) )
async def _transform(items: Sequence[object]) -> Sequence[object]: async def _transform(items: Sequence[object]) -> Sequence[object]:
tasks = cast(Sequence[Task], items) tasks = _coerce_task_items(items)
return await _task_read_page( return await _task_read_page(
session=session, session=session,
board_id=board.id, board_id=board.id,
@@ -735,8 +746,8 @@ async def create_task(
auth: AuthContext = ADMIN_AUTH_DEP, auth: AuthContext = ADMIN_AUTH_DEP,
) -> TaskRead: ) -> TaskRead:
"""Create a task and initialize dependency rows.""" """Create a task and initialize dependency rows."""
data = payload.model_dump() data = payload.model_dump(exclude={"depends_on_task_ids"})
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or []) depends_on_task_ids = list(payload.depends_on_task_ids)
task = Task.model_validate(data) task = Task.model_validate(data)
task.board_id = board.id task.board_id = board.id
@@ -828,10 +839,14 @@ async def update_task(
previous_status = task.status previous_status = task.status
previous_assigned = task.assigned_agent_id previous_assigned = task.assigned_agent_id
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)
comment = cast(str | None, updates.pop("comment", None)) comment = payload.comment if "comment" in payload.model_fields_set else None
depends_on_task_ids = cast( depends_on_task_ids = (
list[UUID] | None, updates.pop("depends_on_task_ids", None), payload.depends_on_task_ids
if "depends_on_task_ids" in payload.model_fields_set
else None
) )
updates.pop("comment", None)
updates.pop("depends_on_task_ids", None)
update = _TaskUpdateInput( update = _TaskUpdateInput(
task=task, task=task,
actor=actor, actor=actor,
@@ -960,7 +975,7 @@ async def _comment_targets(
task: Task, task: Task,
message: str, message: str,
actor: ActorContext, actor: ActorContext,
) -> tuple[dict[UUID, Agent], list[str]]: ) -> tuple[dict[UUID, Agent], set[str]]:
mention_names = extract_mentions(message) mention_names = extract_mentions(message)
targets: dict[UUID, Agent] = {} targets: dict[UUID, Agent] = {}
if mention_names and task.board_id: if mention_names and task.board_id:
@@ -985,7 +1000,7 @@ class _TaskCommentNotifyRequest:
actor: ActorContext actor: ActorContext
message: str message: str
targets: dict[UUID, Agent] targets: dict[UUID, Agent]
mention_names: list[str] mention_names: set[str]
async def _notify_task_comment_targets( async def _notify_task_comment_targets(
@@ -1048,6 +1063,18 @@ class _TaskUpdateInput:
depends_on_task_ids: list[UUID] | None depends_on_task_ids: list[UUID] | None
def _required_status_value(value: object) -> str:
if isinstance(value, str):
return value
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
def _optional_assigned_agent_id(value: object) -> UUID | None:
if value is None or isinstance(value, UUID):
return value
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
async def _task_dep_ids( async def _task_dep_ids(
session: AsyncSession, session: AsyncSession,
*, *,
@@ -1182,7 +1209,7 @@ async def _lead_apply_assignment(
) -> None: ) -> None:
if "assigned_agent_id" not in update.updates: if "assigned_agent_id" not in update.updates:
return return
assigned_id = cast(UUID | None, update.updates["assigned_agent_id"]) assigned_id = _optional_assigned_agent_id(update.updates["assigned_agent_id"])
if not assigned_id: if not assigned_id:
update.task.assigned_agent_id = None update.task.assigned_agent_id = None
return return
@@ -1214,7 +1241,7 @@ def _lead_apply_status(update: _TaskUpdateInput) -> None:
"in review." "in review."
), ),
) )
target_status = cast(str, update.updates["status"]) target_status = _required_status_value(update.updates["status"])
if target_status not in {"done", "inbox"}: if target_status not in {"done", "inbox"}:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@@ -1332,7 +1359,7 @@ async def _apply_non_lead_agent_task_rules(
): ):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if "status" in update.updates: if "status" in update.updates:
status_value = cast(str, update.updates["status"]) status_value = _required_status_value(update.updates["status"])
if status_value != "inbox": if status_value != "inbox":
dep_ids = await _task_dep_ids( dep_ids = await _task_dep_ids(
session, session,
@@ -1390,7 +1417,9 @@ async def _apply_admin_task_rules(
board_id=update.board_id, board_id=update.board_id,
dep_ids=effective_deps, dep_ids=effective_deps,
) )
target_status = cast(str, update.updates.get("status", update.task.status)) target_status = _required_status_value(
update.updates.get("status", update.task.status),
)
if blocked_ids and not (update.task.status == "done" and target_status == "done"): if blocked_ids and not (update.task.status == "done" and target_status == "done"):
update.task.status = "inbox" update.task.status = "inbox"
update.task.assigned_agent_id = None update.task.assigned_agent_id = None
@@ -1399,14 +1428,16 @@ async def _apply_admin_task_rules(
update.updates["assigned_agent_id"] = None update.updates["assigned_agent_id"] = None
if "status" in update.updates: if "status" in update.updates:
status_value = cast(str, update.updates["status"]) status_value = _required_status_value(update.updates["status"])
if status_value == "inbox": if status_value == "inbox":
update.task.assigned_agent_id = None update.task.assigned_agent_id = None
update.task.in_progress_at = None update.task.in_progress_at = None
elif status_value == "in_progress": elif status_value == "in_progress":
update.task.in_progress_at = utcnow() update.task.in_progress_at = utcnow()
assigned_agent_id = cast(UUID | None, update.updates.get("assigned_agent_id")) assigned_agent_id = _optional_assigned_agent_id(
update.updates.get("assigned_agent_id"),
)
if assigned_agent_id: if assigned_agent_id:
agent = await Agent.objects.by_id(assigned_agent_id).first(session) agent = await Agent.objects.by_id(assigned_agent_id).first(session)
if agent is None: if agent is None:
@@ -1530,7 +1561,8 @@ async def _finalize_updated_task(
setattr(update.task, key, value) setattr(update.task, key, value)
update.task.updated_at = utcnow() update.task.updated_at = utcnow()
if "status" in update.updates and cast(str, update.updates["status"]) == "review": status_raw = update.updates.get("status")
if status_raw == "review":
comment_text = (update.comment or "").strip() comment_text = (update.comment or "").strip()
if not comment_text and not await has_valid_recent_comment( if not comment_text and not await has_valid_recent_comment(
session, session,

View File

@@ -8,19 +8,69 @@ import os
import sys import sys
import time import time
from datetime import datetime, timezone from datetime import datetime, timezone
from types import TracebackType
from typing import Any from typing import Any
from app.core.config import settings from app.core.config import settings
from app.core.version import APP_NAME, APP_VERSION from app.core.version import APP_NAME, APP_VERSION
TRACE_LEVEL = 5 TRACE_LEVEL = 5
EXC_INFO_TUPLE_SIZE = 3
logging.addLevelName(TRACE_LEVEL, "TRACE") logging.addLevelName(TRACE_LEVEL, "TRACE")
def _coerce_exc_info(
value: object,
) -> (
bool
| tuple[type[BaseException], BaseException, TracebackType | None]
| tuple[None, None, None]
| BaseException
| None
):
if value is None:
return None
if isinstance(value, bool | BaseException):
return value
if not isinstance(value, tuple) or len(value) != EXC_INFO_TUPLE_SIZE:
return None
first, second, third = value
if first is None and second is None and third is None:
return (None, None, None)
if (
isinstance(first, type)
and issubclass(first, BaseException)
and isinstance(second, BaseException)
and (isinstance(third, TracebackType) or third is None)
):
return (first, second, third)
return None
def _coerce_extra(value: object) -> dict[str, object] | None:
if not isinstance(value, dict):
return None
return {str(key): item for key, item in value.items()}
def _trace(self: logging.Logger, message: str, *args: object, **kwargs: object) -> None: def _trace(self: logging.Logger, message: str, *args: object, **kwargs: object) -> None:
"""Log a TRACE-level message when the logger is TRACE-enabled.""" """Log a TRACE-level message when the logger is TRACE-enabled."""
if self.isEnabledFor(TRACE_LEVEL): if self.isEnabledFor(TRACE_LEVEL):
self._log(TRACE_LEVEL, message, args, **kwargs) exc_info = _coerce_exc_info(kwargs.get("exc_info"))
stack_info_raw = kwargs.get("stack_info")
stack_info = stack_info_raw if isinstance(stack_info_raw, bool) else False
stacklevel_raw = kwargs.get("stacklevel")
stacklevel = stacklevel_raw if isinstance(stacklevel_raw, int) else 1
extra = _coerce_extra(kwargs.get("extra"))
self.log(
TRACE_LEVEL,
message,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
)
logging.Logger.trace = _trace # type: ignore[attr-defined] logging.Logger.trace = _trace # type: ignore[attr-defined]

View File

@@ -3,13 +3,16 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, TypeVar from typing import TYPE_CHECKING, Generic, TypeVar
from sqlalchemy import false from sqlalchemy import false
from sqlmodel import SQLModel, col from sqlmodel import SQLModel, col
from app.db.queryset import QuerySet, qs from app.db.queryset import QuerySet, qs
if TYPE_CHECKING:
from collections.abc import Iterable
ModelT = TypeVar("ModelT", bound=SQLModel) ModelT = TypeVar("ModelT", bound=SQLModel)
@@ -49,7 +52,7 @@ class ModelManager(Generic[ModelT]):
def by_ids( def by_ids(
self, self,
obj_ids: list[object] | tuple[object, ...] | set[object], obj_ids: Iterable[object],
) -> QuerySet[ModelT]: ) -> QuerySet[ModelT]:
"""Return queryset filtered by a set/list/tuple of identifiers.""" """Return queryset filtered by a set/list/tuple of identifiers."""
return self.by_field_in(self.id_field, obj_ids) return self.by_field_in(self.id_field, obj_ids)
@@ -61,7 +64,7 @@ class ModelManager(Generic[ModelT]):
def by_field_in( def by_field_in(
self, self,
field_name: str, field_name: str,
values: list[object] | tuple[object, ...] | set[object], values: Iterable[object],
) -> QuerySet[ModelT]: ) -> QuerySet[ModelT]:
"""Return queryset filtered by `field IN values` semantics.""" """Return queryset filtered by `field IN values` semantics."""
seq = tuple(values) seq = tuple(values)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Generic, TypeVar from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from sqlmodel import select from sqlmodel import select
@@ -22,7 +22,11 @@ class QuerySet(Generic[ModelT]):
def filter(self, *criteria: object) -> QuerySet[ModelT]: def filter(self, *criteria: object) -> QuerySet[ModelT]:
"""Return a new queryset with additional SQL criteria.""" """Return a new queryset with additional SQL criteria."""
return replace(self, statement=self.statement.where(*criteria)) statement = cast(
"SelectOfScalar[ModelT]",
cast(Any, self.statement).where(*criteria),
)
return replace(self, statement=statement)
def where(self, *criteria: object) -> QuerySet[ModelT]: def where(self, *criteria: object) -> QuerySet[ModelT]:
"""Alias for `filter` to mirror SQLAlchemy naming.""" """Alias for `filter` to mirror SQLAlchemy naming."""
@@ -35,7 +39,11 @@ class QuerySet(Generic[ModelT]):
def order_by(self, *ordering: object) -> QuerySet[ModelT]: def order_by(self, *ordering: object) -> QuerySet[ModelT]:
"""Return a new queryset with ordering clauses applied.""" """Return a new queryset with ordering clauses applied."""
return replace(self, statement=self.statement.order_by(*ordering)) statement = cast(
"SelectOfScalar[ModelT]",
cast(Any, self.statement).order_by(*ordering),
)
return replace(self, statement=statement)
def limit(self, value: int) -> QuerySet[ModelT]: def limit(self, value: int) -> QuerySet[ModelT]:
"""Return a new queryset with a SQL row limit.""" """Return a new queryset with a SQL row limit."""

View File

@@ -8,7 +8,7 @@ import re
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
from uuid import uuid4 from uuid import uuid4
from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape
@@ -452,7 +452,7 @@ async def _gateway_agent_files_index(
continue continue
name = item.get("name") name = item.get("name")
if isinstance(name, str) and name: if isinstance(name, str) and name:
index[name] = cast(dict[str, Any], item) index[name] = dict(item)
return index return index
except OpenClawGatewayError: except OpenClawGatewayError:
pass pass

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlalchemy import case, func from sqlalchemy import case, func
@@ -21,18 +22,21 @@ from app.schemas.view_models import (
BoardGroupTaskSummary, BoardGroupTaskSummary,
) )
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
_STATUS_ORDER = {"in_progress": 0, "review": 1, "inbox": 2, "done": 3} _STATUS_ORDER = {"in_progress": 0, "review": 1, "inbox": 2, "done": 3}
_PRIORITY_ORDER = {"high": 0, "medium": 1, "low": 2} _PRIORITY_ORDER = {"high": 0, "medium": 1, "low": 2}
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession) _RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession)
def _status_weight_expr() -> object: def _status_weight_expr() -> ColumnElement[int]:
"""Return a SQL expression that sorts task statuses by configured order.""" """Return a SQL expression that sorts task statuses by configured order."""
whens = [(col(Task.status) == key, weight) for key, weight in _STATUS_ORDER.items()] whens = [(col(Task.status) == key, weight) for key, weight in _STATUS_ORDER.items()]
return case(*whens, else_=99) return case(*whens, else_=99)
def _priority_weight_expr() -> object: def _priority_weight_expr() -> ColumnElement[int]:
"""Return a SQL expression that sorts task priorities by configured order.""" """Return a SQL expression that sorts task priorities by configured order."""
whens = [ whens = [
(col(Task.priority) == key, weight) (col(Task.priority) == key, weight)

View File

@@ -148,7 +148,7 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
select(func.count(col(Approval.id))) select(func.count(col(Approval.id)))
.where(col(Approval.board_id) == board.id) .where(col(Approval.board_id) == board.id)
.where(col(Approval.status) == "pending"), .where(col(Approval.status) == "pending"),
), )
).one(), ).one(),
) )

View File

@@ -167,6 +167,9 @@ class _GatewayBackoff:
self._delay_s = min(self._delay_s * 2.0, self._max_delay_s) self._delay_s = min(self._delay_s * 2.0, self._max_delay_s)
continue continue
self.reset() self.reset()
if value is None:
msg = "Gateway retry produced no value without an error"
raise RuntimeError(msg)
return value return value