refactor: improve type coercion functions and enhance type hints across multiple files
This commit is contained in:
@@ -5,9 +5,8 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
@@ -35,7 +34,7 @@ from app.services.organizations import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -43,6 +42,7 @@ router = APIRouter(prefix="/activity", tags=["activity"])
|
||||
|
||||
SSE_SEEN_MAX = 2000
|
||||
STREAM_POLL_SECONDS = 2
|
||||
TASK_COMMENT_ROW_LEN = 4
|
||||
SESSION_DEP = Depends(get_session)
|
||||
ACTOR_DEP = Depends(require_admin_or_agent)
|
||||
ORG_MEMBER_DEP = Depends(require_org_member)
|
||||
@@ -100,6 +100,26 @@ def _feed_item(
|
||||
)
|
||||
|
||||
|
||||
def _coerce_task_comment_rows(
|
||||
items: Sequence[Any],
|
||||
) -> list[tuple[ActivityEvent, Task, Board, Agent | None]]:
|
||||
rows: list[tuple[ActivityEvent, Task, Board, Agent | None]] = []
|
||||
for item in items:
|
||||
if (
|
||||
isinstance(item, tuple)
|
||||
and len(item) == TASK_COMMENT_ROW_LEN
|
||||
and isinstance(item[0], ActivityEvent)
|
||||
and isinstance(item[1], Task)
|
||||
and isinstance(item[2], Board)
|
||||
and (isinstance(item[3], Agent) or item[3] is None)
|
||||
):
|
||||
rows.append((item[0], item[1], item[2], item[3]))
|
||||
continue
|
||||
msg = "Expected (ActivityEvent, Task, Board, Agent | None) rows"
|
||||
raise TypeError(msg)
|
||||
return rows
|
||||
|
||||
|
||||
async def _fetch_task_comment_events(
|
||||
session: AsyncSession,
|
||||
since: datetime,
|
||||
@@ -118,10 +138,7 @@ async def _fetch_task_comment_events(
|
||||
)
|
||||
if board_id is not None:
|
||||
statement = statement.where(col(Task.board_id) == board_id)
|
||||
return cast(
|
||||
Sequence[tuple[ActivityEvent, Task, Board, Agent | None]],
|
||||
list(await session.exec(statement)),
|
||||
)
|
||||
return _coerce_task_comment_rows(list(await session.exec(statement)))
|
||||
|
||||
|
||||
@router.get("", response_model=DefaultLimitOffsetPage[ActivityEventRead])
|
||||
@@ -179,7 +196,7 @@ async def list_task_comment_feed(
|
||||
statement = statement.where(col(Task.id).is_(None))
|
||||
|
||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||
rows = cast(Sequence[tuple[ActivityEvent, Task, Board, Agent | None]], items)
|
||||
rows = _coerce_task_comment_rows(items)
|
||||
return [
|
||||
_feed_item(event, task, board, agent)
|
||||
for event, task, board, agent in rows
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel import SQLModel, col, select
|
||||
@@ -75,6 +73,9 @@ from app.services.task_dependencies import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.activity_events import ActivityEvent
|
||||
@@ -97,6 +98,16 @@ IS_CHAT_QUERY = Query(default=None)
|
||||
APPROVAL_STATUS_QUERY = Query(default=None, alias="status")
|
||||
|
||||
|
||||
def _coerce_agent_items(items: Sequence[Any]) -> list[Agent]:
|
||||
agents: list[Agent] = []
|
||||
for item in items:
|
||||
if not isinstance(item, Agent):
|
||||
msg = "Expected Agent items from paginated query"
|
||||
raise TypeError(msg)
|
||||
agents.append(item)
|
||||
return agents
|
||||
|
||||
|
||||
def _gateway_agent_id(agent: Agent) -> str:
|
||||
session_key = agent.openclaw_session_id or ""
|
||||
if session_key.startswith(_AGENT_SESSION_PREFIX):
|
||||
@@ -248,7 +259,7 @@ async def list_agents(
|
||||
statement = statement.order_by(col(Agent.created_at).desc())
|
||||
|
||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||
agents = cast(Sequence[Agent], items)
|
||||
agents = _coerce_agent_items(items)
|
||||
return [
|
||||
agents_api.to_agent_read(
|
||||
agents_api.with_computed_status(agent),
|
||||
@@ -275,7 +286,7 @@ async def list_tasks(
|
||||
unassigned=filters.unassigned,
|
||||
board=board,
|
||||
session=session,
|
||||
actor=_actor(agent_ctx),
|
||||
_actor=_actor(agent_ctx),
|
||||
)
|
||||
|
||||
|
||||
@@ -290,8 +301,8 @@ async def create_task(
|
||||
_guard_board_access(agent_ctx, board)
|
||||
if not agent_ctx.agent.is_board_lead:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
data = payload.model_dump()
|
||||
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or [])
|
||||
data = payload.model_dump(exclude={"depends_on_task_ids"})
|
||||
depends_on_task_ids = list(payload.depends_on_task_ids)
|
||||
|
||||
task = Task.model_validate(data)
|
||||
task.board_id = board.id
|
||||
@@ -456,7 +467,7 @@ async def list_board_memory(
|
||||
is_chat=is_chat,
|
||||
board=board,
|
||||
session=session,
|
||||
actor=_actor(agent_ctx),
|
||||
_actor=_actor(agent_ctx),
|
||||
)
|
||||
|
||||
|
||||
@@ -493,7 +504,7 @@ async def list_approvals(
|
||||
status_filter=status_filter,
|
||||
board=board,
|
||||
session=session,
|
||||
actor=_actor(agent_ctx),
|
||||
_actor=_actor(agent_ctx),
|
||||
)
|
||||
|
||||
|
||||
@@ -510,7 +521,7 @@ async def create_approval(
|
||||
payload=payload,
|
||||
board=board,
|
||||
session=session,
|
||||
actor=_actor(agent_ctx),
|
||||
_actor=_actor(agent_ctx),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,10 +5,9 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
@@ -64,8 +63,11 @@ from app.services.organizations import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from app.models.users import User
|
||||
|
||||
@@ -232,6 +234,16 @@ def to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
|
||||
return _to_agent_read(agent, main_session_keys)
|
||||
|
||||
|
||||
def _coerce_agent_items(items: Sequence[Any]) -> list[Agent]:
|
||||
agents: list[Agent] = []
|
||||
for item in items:
|
||||
if not isinstance(item, Agent):
|
||||
msg = "Expected Agent items from paginated query"
|
||||
raise TypeError(msg)
|
||||
agents.append(item)
|
||||
return agents
|
||||
|
||||
|
||||
async def _find_gateway_for_main_session(
|
||||
session: AsyncSession, session_key: str | None,
|
||||
) -> Gateway | None:
|
||||
@@ -777,7 +789,7 @@ async def _provision_updated_agent(
|
||||
) from exc
|
||||
|
||||
|
||||
def _heartbeat_lookup_statement(payload: AgentHeartbeatCreate) -> object:
|
||||
def _heartbeat_lookup_statement(payload: AgentHeartbeatCreate) -> SelectOfScalar[Agent]:
|
||||
statement = Agent.objects.filter_by(name=payload.name).statement
|
||||
if payload.board_id is not None:
|
||||
statement = statement.where(Agent.board_id == payload.board_id)
|
||||
@@ -973,7 +985,7 @@ async def list_agents(
|
||||
statement = statement.order_by(col(Agent.created_at).desc())
|
||||
|
||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||
agents = cast(Sequence[Agent], items)
|
||||
agents = _coerce_agent_items(items)
|
||||
return [
|
||||
_to_agent_read(_with_computed_status(agent), main_session_keys)
|
||||
for agent in agents
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
@@ -235,9 +235,7 @@ def _update_agent_heartbeat(
|
||||
) -> None:
|
||||
raw = agent.heartbeat_config
|
||||
heartbeat: dict[str, Any] = (
|
||||
cast(dict[str, Any], dict(raw))
|
||||
if isinstance(raw, dict)
|
||||
else cast(dict[str, Any], DEFAULT_HEARTBEAT_CONFIG.copy())
|
||||
dict(raw) if isinstance(raw, dict) else DEFAULT_HEARTBEAT_CONFIG.copy()
|
||||
)
|
||||
heartbeat["every"] = payload.every
|
||||
if payload.target is not None:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
@@ -56,6 +57,20 @@ def _query_to_resolve_input(params: GatewayResolveQuery) -> GatewayResolveQuery:
|
||||
RESOLVE_INPUT_DEP = Depends(_query_to_resolve_input)
|
||||
|
||||
|
||||
def _as_object_list(value: object) -> list[object]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
if isinstance(value, (tuple, set)):
|
||||
return list(value)
|
||||
if isinstance(value, (str, bytes, dict)):
|
||||
return []
|
||||
if isinstance(value, Iterable):
|
||||
return list(value)
|
||||
return []
|
||||
|
||||
|
||||
async def _resolve_gateway(
|
||||
session: AsyncSession,
|
||||
params: GatewayResolveQuery,
|
||||
@@ -138,9 +153,9 @@ async def gateways_status(
|
||||
try:
|
||||
sessions = await openclaw_call("sessions.list", config=config)
|
||||
if isinstance(sessions, dict):
|
||||
sessions_list = list(sessions.get("sessions") or [])
|
||||
sessions_list = _as_object_list(sessions.get("sessions"))
|
||||
else:
|
||||
sessions_list = list(sessions or [])
|
||||
sessions_list = _as_object_list(sessions)
|
||||
main_session_entry: object | None = None
|
||||
main_session_error: str | None = None
|
||||
if main_session:
|
||||
@@ -190,9 +205,9 @@ async def list_gateway_sessions(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
||||
) from exc
|
||||
if isinstance(sessions, dict):
|
||||
sessions_list = list(sessions.get("sessions") or [])
|
||||
sessions_list = _as_object_list(sessions.get("sessions"))
|
||||
else:
|
||||
sessions_list = list(sessions or [])
|
||||
sessions_list = _as_object_list(sessions)
|
||||
|
||||
main_session_entry: object | None = None
|
||||
if main_session:
|
||||
@@ -215,9 +230,9 @@ async def list_gateway_sessions(
|
||||
async def _list_sessions(config: GatewayClientConfig) -> list[dict[str, object]]:
|
||||
sessions = await openclaw_call("sessions.list", config=config)
|
||||
if isinstance(sessions, dict):
|
||||
raw_items = sessions.get("sessions") or []
|
||||
raw_items = _as_object_list(sessions.get("sessions"))
|
||||
else:
|
||||
raw_items = sessions or []
|
||||
raw_items = _as_object_list(sessions)
|
||||
return [
|
||||
item
|
||||
for item in raw_items
|
||||
@@ -311,7 +326,7 @@ async def get_session_history(
|
||||
) from exc
|
||||
if isinstance(history, dict) and isinstance(history.get("messages"), list):
|
||||
return GatewaySessionHistoryResponse(history=history["messages"])
|
||||
return GatewaySessionHistoryResponse(history=list(history or []))
|
||||
return GatewaySessionHistoryResponse(history=_as_object_list(history))
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/message", response_model=OkResponse)
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
@@ -15,7 +14,6 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy import asc, desc, or_
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.sql.expression import Select
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.api.deps import (
|
||||
@@ -67,7 +65,10 @@ from app.services.task_dependencies import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
|
||||
from app.core.auth import AuthContext
|
||||
from app.models.users import User
|
||||
@@ -157,6 +158,16 @@ def _parse_since(value: str | None) -> datetime | None:
|
||||
return parsed
|
||||
|
||||
|
||||
def _coerce_task_items(items: Sequence[object]) -> list[Task]:
|
||||
tasks: list[Task] = []
|
||||
for item in items:
|
||||
if not isinstance(item, Task):
|
||||
msg = "Expected Task items from paginated query"
|
||||
raise TypeError(msg)
|
||||
tasks.append(item)
|
||||
return tasks
|
||||
|
||||
|
||||
async def _lead_was_mentioned(
|
||||
session: AsyncSession,
|
||||
task: Task,
|
||||
@@ -266,7 +277,7 @@ async def _fetch_task_events(
|
||||
if not task_ids:
|
||||
return []
|
||||
statement = cast(
|
||||
Select[tuple[ActivityEvent, Task | None]],
|
||||
"Select[tuple[ActivityEvent, Task | None]]",
|
||||
select(ActivityEvent, Task)
|
||||
.outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id))
|
||||
.where(col(ActivityEvent.task_id).in_(task_ids))
|
||||
@@ -512,7 +523,7 @@ def _task_list_statement(
|
||||
status_filter: str | None,
|
||||
assigned_agent_id: UUID | None,
|
||||
unassigned: bool | None,
|
||||
) -> object:
|
||||
) -> SelectOfScalar[Task]:
|
||||
statement = select(Task).where(Task.board_id == board_id)
|
||||
statuses = _status_values(status_filter)
|
||||
if statuses:
|
||||
@@ -717,7 +728,7 @@ async def list_tasks(
|
||||
)
|
||||
|
||||
async def _transform(items: Sequence[object]) -> Sequence[object]:
|
||||
tasks = cast(Sequence[Task], items)
|
||||
tasks = _coerce_task_items(items)
|
||||
return await _task_read_page(
|
||||
session=session,
|
||||
board_id=board.id,
|
||||
@@ -735,8 +746,8 @@ async def create_task(
|
||||
auth: AuthContext = ADMIN_AUTH_DEP,
|
||||
) -> TaskRead:
|
||||
"""Create a task and initialize dependency rows."""
|
||||
data = payload.model_dump()
|
||||
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or [])
|
||||
data = payload.model_dump(exclude={"depends_on_task_ids"})
|
||||
depends_on_task_ids = list(payload.depends_on_task_ids)
|
||||
|
||||
task = Task.model_validate(data)
|
||||
task.board_id = board.id
|
||||
@@ -828,10 +839,14 @@ async def update_task(
|
||||
previous_status = task.status
|
||||
previous_assigned = task.assigned_agent_id
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
comment = cast(str | None, updates.pop("comment", None))
|
||||
depends_on_task_ids = cast(
|
||||
list[UUID] | None, updates.pop("depends_on_task_ids", None),
|
||||
comment = payload.comment if "comment" in payload.model_fields_set else None
|
||||
depends_on_task_ids = (
|
||||
payload.depends_on_task_ids
|
||||
if "depends_on_task_ids" in payload.model_fields_set
|
||||
else None
|
||||
)
|
||||
updates.pop("comment", None)
|
||||
updates.pop("depends_on_task_ids", None)
|
||||
update = _TaskUpdateInput(
|
||||
task=task,
|
||||
actor=actor,
|
||||
@@ -960,7 +975,7 @@ async def _comment_targets(
|
||||
task: Task,
|
||||
message: str,
|
||||
actor: ActorContext,
|
||||
) -> tuple[dict[UUID, Agent], list[str]]:
|
||||
) -> tuple[dict[UUID, Agent], set[str]]:
|
||||
mention_names = extract_mentions(message)
|
||||
targets: dict[UUID, Agent] = {}
|
||||
if mention_names and task.board_id:
|
||||
@@ -985,7 +1000,7 @@ class _TaskCommentNotifyRequest:
|
||||
actor: ActorContext
|
||||
message: str
|
||||
targets: dict[UUID, Agent]
|
||||
mention_names: list[str]
|
||||
mention_names: set[str]
|
||||
|
||||
|
||||
async def _notify_task_comment_targets(
|
||||
@@ -1048,6 +1063,18 @@ class _TaskUpdateInput:
|
||||
depends_on_task_ids: list[UUID] | None
|
||||
|
||||
|
||||
def _required_status_value(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
|
||||
def _optional_assigned_agent_id(value: object) -> UUID | None:
|
||||
if value is None or isinstance(value, UUID):
|
||||
return value
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
|
||||
async def _task_dep_ids(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
@@ -1182,7 +1209,7 @@ async def _lead_apply_assignment(
|
||||
) -> None:
|
||||
if "assigned_agent_id" not in update.updates:
|
||||
return
|
||||
assigned_id = cast(UUID | None, update.updates["assigned_agent_id"])
|
||||
assigned_id = _optional_assigned_agent_id(update.updates["assigned_agent_id"])
|
||||
if not assigned_id:
|
||||
update.task.assigned_agent_id = None
|
||||
return
|
||||
@@ -1214,7 +1241,7 @@ def _lead_apply_status(update: _TaskUpdateInput) -> None:
|
||||
"in review."
|
||||
),
|
||||
)
|
||||
target_status = cast(str, update.updates["status"])
|
||||
target_status = _required_status_value(update.updates["status"])
|
||||
if target_status not in {"done", "inbox"}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -1332,7 +1359,7 @@ async def _apply_non_lead_agent_task_rules(
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
if "status" in update.updates:
|
||||
status_value = cast(str, update.updates["status"])
|
||||
status_value = _required_status_value(update.updates["status"])
|
||||
if status_value != "inbox":
|
||||
dep_ids = await _task_dep_ids(
|
||||
session,
|
||||
@@ -1390,7 +1417,9 @@ async def _apply_admin_task_rules(
|
||||
board_id=update.board_id,
|
||||
dep_ids=effective_deps,
|
||||
)
|
||||
target_status = cast(str, update.updates.get("status", update.task.status))
|
||||
target_status = _required_status_value(
|
||||
update.updates.get("status", update.task.status),
|
||||
)
|
||||
if blocked_ids and not (update.task.status == "done" and target_status == "done"):
|
||||
update.task.status = "inbox"
|
||||
update.task.assigned_agent_id = None
|
||||
@@ -1399,14 +1428,16 @@ async def _apply_admin_task_rules(
|
||||
update.updates["assigned_agent_id"] = None
|
||||
|
||||
if "status" in update.updates:
|
||||
status_value = cast(str, update.updates["status"])
|
||||
status_value = _required_status_value(update.updates["status"])
|
||||
if status_value == "inbox":
|
||||
update.task.assigned_agent_id = None
|
||||
update.task.in_progress_at = None
|
||||
elif status_value == "in_progress":
|
||||
update.task.in_progress_at = utcnow()
|
||||
|
||||
assigned_agent_id = cast(UUID | None, update.updates.get("assigned_agent_id"))
|
||||
assigned_agent_id = _optional_assigned_agent_id(
|
||||
update.updates.get("assigned_agent_id"),
|
||||
)
|
||||
if assigned_agent_id:
|
||||
agent = await Agent.objects.by_id(assigned_agent_id).first(session)
|
||||
if agent is None:
|
||||
@@ -1530,7 +1561,8 @@ async def _finalize_updated_task(
|
||||
setattr(update.task, key, value)
|
||||
update.task.updated_at = utcnow()
|
||||
|
||||
if "status" in update.updates and cast(str, update.updates["status"]) == "review":
|
||||
status_raw = update.updates.get("status")
|
||||
if status_raw == "review":
|
||||
comment_text = (update.comment or "").strip()
|
||||
if not comment_text and not await has_valid_recent_comment(
|
||||
session,
|
||||
|
||||
@@ -8,19 +8,69 @@ import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.version import APP_NAME, APP_VERSION
|
||||
|
||||
TRACE_LEVEL = 5
|
||||
EXC_INFO_TUPLE_SIZE = 3
|
||||
logging.addLevelName(TRACE_LEVEL, "TRACE")
|
||||
|
||||
|
||||
def _coerce_exc_info(
|
||||
value: object,
|
||||
) -> (
|
||||
bool
|
||||
| tuple[type[BaseException], BaseException, TracebackType | None]
|
||||
| tuple[None, None, None]
|
||||
| BaseException
|
||||
| None
|
||||
):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool | BaseException):
|
||||
return value
|
||||
if not isinstance(value, tuple) or len(value) != EXC_INFO_TUPLE_SIZE:
|
||||
return None
|
||||
first, second, third = value
|
||||
if first is None and second is None and third is None:
|
||||
return (None, None, None)
|
||||
if (
|
||||
isinstance(first, type)
|
||||
and issubclass(first, BaseException)
|
||||
and isinstance(second, BaseException)
|
||||
and (isinstance(third, TracebackType) or third is None)
|
||||
):
|
||||
return (first, second, third)
|
||||
return None
|
||||
|
||||
|
||||
def _coerce_extra(value: object) -> dict[str, object] | None:
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
return {str(key): item for key, item in value.items()}
|
||||
|
||||
|
||||
def _trace(self: logging.Logger, message: str, *args: object, **kwargs: object) -> None:
|
||||
"""Log a TRACE-level message when the logger is TRACE-enabled."""
|
||||
if self.isEnabledFor(TRACE_LEVEL):
|
||||
self._log(TRACE_LEVEL, message, args, **kwargs)
|
||||
exc_info = _coerce_exc_info(kwargs.get("exc_info"))
|
||||
stack_info_raw = kwargs.get("stack_info")
|
||||
stack_info = stack_info_raw if isinstance(stack_info_raw, bool) else False
|
||||
stacklevel_raw = kwargs.get("stacklevel")
|
||||
stacklevel = stacklevel_raw if isinstance(stacklevel_raw, int) else 1
|
||||
extra = _coerce_extra(kwargs.get("extra"))
|
||||
self.log(
|
||||
TRACE_LEVEL,
|
||||
message,
|
||||
*args,
|
||||
exc_info=exc_info,
|
||||
stack_info=stack_info,
|
||||
stacklevel=stacklevel,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
logging.Logger.trace = _trace # type: ignore[attr-defined]
|
||||
|
||||
@@ -3,13 +3,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import SQLModel, col
|
||||
|
||||
from app.db.queryset import QuerySet, qs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
ModelT = TypeVar("ModelT", bound=SQLModel)
|
||||
|
||||
|
||||
@@ -49,7 +52,7 @@ class ModelManager(Generic[ModelT]):
|
||||
|
||||
def by_ids(
|
||||
self,
|
||||
obj_ids: list[object] | tuple[object, ...] | set[object],
|
||||
obj_ids: Iterable[object],
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by a set/list/tuple of identifiers."""
|
||||
return self.by_field_in(self.id_field, obj_ids)
|
||||
@@ -61,7 +64,7 @@ class ModelManager(Generic[ModelT]):
|
||||
def by_field_in(
|
||||
self,
|
||||
field_name: str,
|
||||
values: list[object] | tuple[object, ...] | set[object],
|
||||
values: Iterable[object],
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by `field IN values` semantics."""
|
||||
seq = tuple(values)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
@@ -22,7 +22,11 @@ class QuerySet(Generic[ModelT]):
|
||||
|
||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with additional SQL criteria."""
|
||||
return replace(self, statement=self.statement.where(*criteria))
|
||||
statement = cast(
|
||||
"SelectOfScalar[ModelT]",
|
||||
cast(Any, self.statement).where(*criteria),
|
||||
)
|
||||
return replace(self, statement=statement)
|
||||
|
||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Alias for `filter` to mirror SQLAlchemy naming."""
|
||||
@@ -35,7 +39,11 @@ class QuerySet(Generic[ModelT]):
|
||||
|
||||
def order_by(self, *ordering: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with ordering clauses applied."""
|
||||
return replace(self, statement=self.statement.order_by(*ordering))
|
||||
statement = cast(
|
||||
"SelectOfScalar[ModelT]",
|
||||
cast(Any, self.statement).order_by(*ordering),
|
||||
)
|
||||
return replace(self, statement=statement)
|
||||
|
||||
def limit(self, value: int) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with a SQL row limit."""
|
||||
|
||||
@@ -8,7 +8,7 @@ import re
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import uuid4
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape
|
||||
@@ -452,7 +452,7 @@ async def _gateway_agent_files_index(
|
||||
continue
|
||||
name = item.get("name")
|
||||
if isinstance(name, str) and name:
|
||||
index[name] = cast(dict[str, Any], item)
|
||||
index[name] = dict(item)
|
||||
return index
|
||||
except OpenClawGatewayError:
|
||||
pass
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import case, func
|
||||
@@ -21,18 +22,21 @@ from app.schemas.view_models import (
|
||||
BoardGroupTaskSummary,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
_STATUS_ORDER = {"in_progress": 0, "review": 1, "inbox": 2, "done": 3}
|
||||
_PRIORITY_ORDER = {"high": 0, "medium": 1, "low": 2}
|
||||
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession)
|
||||
|
||||
|
||||
def _status_weight_expr() -> object:
|
||||
def _status_weight_expr() -> ColumnElement[int]:
|
||||
"""Return a SQL expression that sorts task statuses by configured order."""
|
||||
whens = [(col(Task.status) == key, weight) for key, weight in _STATUS_ORDER.items()]
|
||||
return case(*whens, else_=99)
|
||||
|
||||
|
||||
def _priority_weight_expr() -> object:
|
||||
def _priority_weight_expr() -> ColumnElement[int]:
|
||||
"""Return a SQL expression that sorts task priorities by configured order."""
|
||||
whens = [
|
||||
(col(Task.priority) == key, weight)
|
||||
|
||||
@@ -148,7 +148,7 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
|
||||
select(func.count(col(Approval.id)))
|
||||
.where(col(Approval.board_id) == board.id)
|
||||
.where(col(Approval.status) == "pending"),
|
||||
),
|
||||
)
|
||||
).one(),
|
||||
)
|
||||
|
||||
|
||||
@@ -167,6 +167,9 @@ class _GatewayBackoff:
|
||||
self._delay_s = min(self._delay_s * 2.0, self._max_delay_s)
|
||||
continue
|
||||
self.reset()
|
||||
if value is None:
|
||||
msg = "Gateway retry produced no value without an error"
|
||||
raise RuntimeError(msg)
|
||||
return value
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user