refactor: improve type coercion functions and enhance type hints across multiple files
This commit is contained in:
@@ -5,9 +5,8 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
@@ -35,7 +34,7 @@ from app.services.organizations import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -43,6 +42,7 @@ router = APIRouter(prefix="/activity", tags=["activity"])
|
||||
|
||||
SSE_SEEN_MAX = 2000
|
||||
STREAM_POLL_SECONDS = 2
|
||||
TASK_COMMENT_ROW_LEN = 4
|
||||
SESSION_DEP = Depends(get_session)
|
||||
ACTOR_DEP = Depends(require_admin_or_agent)
|
||||
ORG_MEMBER_DEP = Depends(require_org_member)
|
||||
@@ -100,6 +100,26 @@ def _feed_item(
|
||||
)
|
||||
|
||||
|
||||
def _coerce_task_comment_rows(
|
||||
items: Sequence[Any],
|
||||
) -> list[tuple[ActivityEvent, Task, Board, Agent | None]]:
|
||||
rows: list[tuple[ActivityEvent, Task, Board, Agent | None]] = []
|
||||
for item in items:
|
||||
if (
|
||||
isinstance(item, tuple)
|
||||
and len(item) == TASK_COMMENT_ROW_LEN
|
||||
and isinstance(item[0], ActivityEvent)
|
||||
and isinstance(item[1], Task)
|
||||
and isinstance(item[2], Board)
|
||||
and (isinstance(item[3], Agent) or item[3] is None)
|
||||
):
|
||||
rows.append((item[0], item[1], item[2], item[3]))
|
||||
continue
|
||||
msg = "Expected (ActivityEvent, Task, Board, Agent | None) rows"
|
||||
raise TypeError(msg)
|
||||
return rows
|
||||
|
||||
|
||||
async def _fetch_task_comment_events(
|
||||
session: AsyncSession,
|
||||
since: datetime,
|
||||
@@ -118,10 +138,7 @@ async def _fetch_task_comment_events(
|
||||
)
|
||||
if board_id is not None:
|
||||
statement = statement.where(col(Task.board_id) == board_id)
|
||||
return cast(
|
||||
Sequence[tuple[ActivityEvent, Task, Board, Agent | None]],
|
||||
list(await session.exec(statement)),
|
||||
)
|
||||
return _coerce_task_comment_rows(list(await session.exec(statement)))
|
||||
|
||||
|
||||
@router.get("", response_model=DefaultLimitOffsetPage[ActivityEventRead])
|
||||
@@ -179,7 +196,7 @@ async def list_task_comment_feed(
|
||||
statement = statement.where(col(Task.id).is_(None))
|
||||
|
||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||
rows = cast(Sequence[tuple[ActivityEvent, Task, Board, Agent | None]], items)
|
||||
rows = _coerce_task_comment_rows(items)
|
||||
return [
|
||||
_feed_item(event, task, board, agent)
|
||||
for event, task, board, agent in rows
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel import SQLModel, col, select
|
||||
@@ -75,6 +73,9 @@ from app.services.task_dependencies import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.activity_events import ActivityEvent
|
||||
@@ -97,6 +98,16 @@ IS_CHAT_QUERY = Query(default=None)
|
||||
APPROVAL_STATUS_QUERY = Query(default=None, alias="status")
|
||||
|
||||
|
||||
def _coerce_agent_items(items: Sequence[Any]) -> list[Agent]:
|
||||
agents: list[Agent] = []
|
||||
for item in items:
|
||||
if not isinstance(item, Agent):
|
||||
msg = "Expected Agent items from paginated query"
|
||||
raise TypeError(msg)
|
||||
agents.append(item)
|
||||
return agents
|
||||
|
||||
|
||||
def _gateway_agent_id(agent: Agent) -> str:
|
||||
session_key = agent.openclaw_session_id or ""
|
||||
if session_key.startswith(_AGENT_SESSION_PREFIX):
|
||||
@@ -248,7 +259,7 @@ async def list_agents(
|
||||
statement = statement.order_by(col(Agent.created_at).desc())
|
||||
|
||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||
agents = cast(Sequence[Agent], items)
|
||||
agents = _coerce_agent_items(items)
|
||||
return [
|
||||
agents_api.to_agent_read(
|
||||
agents_api.with_computed_status(agent),
|
||||
@@ -275,7 +286,7 @@ async def list_tasks(
|
||||
unassigned=filters.unassigned,
|
||||
board=board,
|
||||
session=session,
|
||||
actor=_actor(agent_ctx),
|
||||
_actor=_actor(agent_ctx),
|
||||
)
|
||||
|
||||
|
||||
@@ -290,8 +301,8 @@ async def create_task(
|
||||
_guard_board_access(agent_ctx, board)
|
||||
if not agent_ctx.agent.is_board_lead:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
data = payload.model_dump()
|
||||
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or [])
|
||||
data = payload.model_dump(exclude={"depends_on_task_ids"})
|
||||
depends_on_task_ids = list(payload.depends_on_task_ids)
|
||||
|
||||
task = Task.model_validate(data)
|
||||
task.board_id = board.id
|
||||
@@ -456,7 +467,7 @@ async def list_board_memory(
|
||||
is_chat=is_chat,
|
||||
board=board,
|
||||
session=session,
|
||||
actor=_actor(agent_ctx),
|
||||
_actor=_actor(agent_ctx),
|
||||
)
|
||||
|
||||
|
||||
@@ -493,7 +504,7 @@ async def list_approvals(
|
||||
status_filter=status_filter,
|
||||
board=board,
|
||||
session=session,
|
||||
actor=_actor(agent_ctx),
|
||||
_actor=_actor(agent_ctx),
|
||||
)
|
||||
|
||||
|
||||
@@ -510,7 +521,7 @@ async def create_approval(
|
||||
payload=payload,
|
||||
board=board,
|
||||
session=session,
|
||||
actor=_actor(agent_ctx),
|
||||
_actor=_actor(agent_ctx),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,10 +5,9 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
@@ -64,8 +63,11 @@ from app.services.organizations import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from app.models.users import User
|
||||
|
||||
@@ -232,6 +234,16 @@ def to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
|
||||
return _to_agent_read(agent, main_session_keys)
|
||||
|
||||
|
||||
def _coerce_agent_items(items: Sequence[Any]) -> list[Agent]:
|
||||
agents: list[Agent] = []
|
||||
for item in items:
|
||||
if not isinstance(item, Agent):
|
||||
msg = "Expected Agent items from paginated query"
|
||||
raise TypeError(msg)
|
||||
agents.append(item)
|
||||
return agents
|
||||
|
||||
|
||||
async def _find_gateway_for_main_session(
|
||||
session: AsyncSession, session_key: str | None,
|
||||
) -> Gateway | None:
|
||||
@@ -777,7 +789,7 @@ async def _provision_updated_agent(
|
||||
) from exc
|
||||
|
||||
|
||||
def _heartbeat_lookup_statement(payload: AgentHeartbeatCreate) -> object:
|
||||
def _heartbeat_lookup_statement(payload: AgentHeartbeatCreate) -> SelectOfScalar[Agent]:
|
||||
statement = Agent.objects.filter_by(name=payload.name).statement
|
||||
if payload.board_id is not None:
|
||||
statement = statement.where(Agent.board_id == payload.board_id)
|
||||
@@ -973,7 +985,7 @@ async def list_agents(
|
||||
statement = statement.order_by(col(Agent.created_at).desc())
|
||||
|
||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||
agents = cast(Sequence[Agent], items)
|
||||
agents = _coerce_agent_items(items)
|
||||
return [
|
||||
_to_agent_read(_with_computed_status(agent), main_session_keys)
|
||||
for agent in agents
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
@@ -235,9 +235,7 @@ def _update_agent_heartbeat(
|
||||
) -> None:
|
||||
raw = agent.heartbeat_config
|
||||
heartbeat: dict[str, Any] = (
|
||||
cast(dict[str, Any], dict(raw))
|
||||
if isinstance(raw, dict)
|
||||
else cast(dict[str, Any], DEFAULT_HEARTBEAT_CONFIG.copy())
|
||||
dict(raw) if isinstance(raw, dict) else DEFAULT_HEARTBEAT_CONFIG.copy()
|
||||
)
|
||||
heartbeat["every"] = payload.every
|
||||
if payload.target is not None:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
@@ -56,6 +57,20 @@ def _query_to_resolve_input(params: GatewayResolveQuery) -> GatewayResolveQuery:
|
||||
RESOLVE_INPUT_DEP = Depends(_query_to_resolve_input)
|
||||
|
||||
|
||||
def _as_object_list(value: object) -> list[object]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
if isinstance(value, (tuple, set)):
|
||||
return list(value)
|
||||
if isinstance(value, (str, bytes, dict)):
|
||||
return []
|
||||
if isinstance(value, Iterable):
|
||||
return list(value)
|
||||
return []
|
||||
|
||||
|
||||
async def _resolve_gateway(
|
||||
session: AsyncSession,
|
||||
params: GatewayResolveQuery,
|
||||
@@ -138,9 +153,9 @@ async def gateways_status(
|
||||
try:
|
||||
sessions = await openclaw_call("sessions.list", config=config)
|
||||
if isinstance(sessions, dict):
|
||||
sessions_list = list(sessions.get("sessions") or [])
|
||||
sessions_list = _as_object_list(sessions.get("sessions"))
|
||||
else:
|
||||
sessions_list = list(sessions or [])
|
||||
sessions_list = _as_object_list(sessions)
|
||||
main_session_entry: object | None = None
|
||||
main_session_error: str | None = None
|
||||
if main_session:
|
||||
@@ -190,9 +205,9 @@ async def list_gateway_sessions(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
||||
) from exc
|
||||
if isinstance(sessions, dict):
|
||||
sessions_list = list(sessions.get("sessions") or [])
|
||||
sessions_list = _as_object_list(sessions.get("sessions"))
|
||||
else:
|
||||
sessions_list = list(sessions or [])
|
||||
sessions_list = _as_object_list(sessions)
|
||||
|
||||
main_session_entry: object | None = None
|
||||
if main_session:
|
||||
@@ -215,9 +230,9 @@ async def list_gateway_sessions(
|
||||
async def _list_sessions(config: GatewayClientConfig) -> list[dict[str, object]]:
|
||||
sessions = await openclaw_call("sessions.list", config=config)
|
||||
if isinstance(sessions, dict):
|
||||
raw_items = sessions.get("sessions") or []
|
||||
raw_items = _as_object_list(sessions.get("sessions"))
|
||||
else:
|
||||
raw_items = sessions or []
|
||||
raw_items = _as_object_list(sessions)
|
||||
return [
|
||||
item
|
||||
for item in raw_items
|
||||
@@ -311,7 +326,7 @@ async def get_session_history(
|
||||
) from exc
|
||||
if isinstance(history, dict) and isinstance(history.get("messages"), list):
|
||||
return GatewaySessionHistoryResponse(history=history["messages"])
|
||||
return GatewaySessionHistoryResponse(history=list(history or []))
|
||||
return GatewaySessionHistoryResponse(history=_as_object_list(history))
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/message", response_model=OkResponse)
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
@@ -15,7 +14,6 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy import asc, desc, or_
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.sql.expression import Select
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.api.deps import (
|
||||
@@ -67,7 +65,10 @@ from app.services.task_dependencies import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
|
||||
from app.core.auth import AuthContext
|
||||
from app.models.users import User
|
||||
@@ -157,6 +158,16 @@ def _parse_since(value: str | None) -> datetime | None:
|
||||
return parsed
|
||||
|
||||
|
||||
def _coerce_task_items(items: Sequence[object]) -> list[Task]:
|
||||
tasks: list[Task] = []
|
||||
for item in items:
|
||||
if not isinstance(item, Task):
|
||||
msg = "Expected Task items from paginated query"
|
||||
raise TypeError(msg)
|
||||
tasks.append(item)
|
||||
return tasks
|
||||
|
||||
|
||||
async def _lead_was_mentioned(
|
||||
session: AsyncSession,
|
||||
task: Task,
|
||||
@@ -266,7 +277,7 @@ async def _fetch_task_events(
|
||||
if not task_ids:
|
||||
return []
|
||||
statement = cast(
|
||||
Select[tuple[ActivityEvent, Task | None]],
|
||||
"Select[tuple[ActivityEvent, Task | None]]",
|
||||
select(ActivityEvent, Task)
|
||||
.outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id))
|
||||
.where(col(ActivityEvent.task_id).in_(task_ids))
|
||||
@@ -512,7 +523,7 @@ def _task_list_statement(
|
||||
status_filter: str | None,
|
||||
assigned_agent_id: UUID | None,
|
||||
unassigned: bool | None,
|
||||
) -> object:
|
||||
) -> SelectOfScalar[Task]:
|
||||
statement = select(Task).where(Task.board_id == board_id)
|
||||
statuses = _status_values(status_filter)
|
||||
if statuses:
|
||||
@@ -717,7 +728,7 @@ async def list_tasks(
|
||||
)
|
||||
|
||||
async def _transform(items: Sequence[object]) -> Sequence[object]:
|
||||
tasks = cast(Sequence[Task], items)
|
||||
tasks = _coerce_task_items(items)
|
||||
return await _task_read_page(
|
||||
session=session,
|
||||
board_id=board.id,
|
||||
@@ -735,8 +746,8 @@ async def create_task(
|
||||
auth: AuthContext = ADMIN_AUTH_DEP,
|
||||
) -> TaskRead:
|
||||
"""Create a task and initialize dependency rows."""
|
||||
data = payload.model_dump()
|
||||
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or [])
|
||||
data = payload.model_dump(exclude={"depends_on_task_ids"})
|
||||
depends_on_task_ids = list(payload.depends_on_task_ids)
|
||||
|
||||
task = Task.model_validate(data)
|
||||
task.board_id = board.id
|
||||
@@ -828,10 +839,14 @@ async def update_task(
|
||||
previous_status = task.status
|
||||
previous_assigned = task.assigned_agent_id
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
comment = cast(str | None, updates.pop("comment", None))
|
||||
depends_on_task_ids = cast(
|
||||
list[UUID] | None, updates.pop("depends_on_task_ids", None),
|
||||
comment = payload.comment if "comment" in payload.model_fields_set else None
|
||||
depends_on_task_ids = (
|
||||
payload.depends_on_task_ids
|
||||
if "depends_on_task_ids" in payload.model_fields_set
|
||||
else None
|
||||
)
|
||||
updates.pop("comment", None)
|
||||
updates.pop("depends_on_task_ids", None)
|
||||
update = _TaskUpdateInput(
|
||||
task=task,
|
||||
actor=actor,
|
||||
@@ -960,7 +975,7 @@ async def _comment_targets(
|
||||
task: Task,
|
||||
message: str,
|
||||
actor: ActorContext,
|
||||
) -> tuple[dict[UUID, Agent], list[str]]:
|
||||
) -> tuple[dict[UUID, Agent], set[str]]:
|
||||
mention_names = extract_mentions(message)
|
||||
targets: dict[UUID, Agent] = {}
|
||||
if mention_names and task.board_id:
|
||||
@@ -985,7 +1000,7 @@ class _TaskCommentNotifyRequest:
|
||||
actor: ActorContext
|
||||
message: str
|
||||
targets: dict[UUID, Agent]
|
||||
mention_names: list[str]
|
||||
mention_names: set[str]
|
||||
|
||||
|
||||
async def _notify_task_comment_targets(
|
||||
@@ -1048,6 +1063,18 @@ class _TaskUpdateInput:
|
||||
depends_on_task_ids: list[UUID] | None
|
||||
|
||||
|
||||
def _required_status_value(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
|
||||
def _optional_assigned_agent_id(value: object) -> UUID | None:
|
||||
if value is None or isinstance(value, UUID):
|
||||
return value
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
|
||||
async def _task_dep_ids(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
@@ -1182,7 +1209,7 @@ async def _lead_apply_assignment(
|
||||
) -> None:
|
||||
if "assigned_agent_id" not in update.updates:
|
||||
return
|
||||
assigned_id = cast(UUID | None, update.updates["assigned_agent_id"])
|
||||
assigned_id = _optional_assigned_agent_id(update.updates["assigned_agent_id"])
|
||||
if not assigned_id:
|
||||
update.task.assigned_agent_id = None
|
||||
return
|
||||
@@ -1214,7 +1241,7 @@ def _lead_apply_status(update: _TaskUpdateInput) -> None:
|
||||
"in review."
|
||||
),
|
||||
)
|
||||
target_status = cast(str, update.updates["status"])
|
||||
target_status = _required_status_value(update.updates["status"])
|
||||
if target_status not in {"done", "inbox"}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -1332,7 +1359,7 @@ async def _apply_non_lead_agent_task_rules(
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
if "status" in update.updates:
|
||||
status_value = cast(str, update.updates["status"])
|
||||
status_value = _required_status_value(update.updates["status"])
|
||||
if status_value != "inbox":
|
||||
dep_ids = await _task_dep_ids(
|
||||
session,
|
||||
@@ -1390,7 +1417,9 @@ async def _apply_admin_task_rules(
|
||||
board_id=update.board_id,
|
||||
dep_ids=effective_deps,
|
||||
)
|
||||
target_status = cast(str, update.updates.get("status", update.task.status))
|
||||
target_status = _required_status_value(
|
||||
update.updates.get("status", update.task.status),
|
||||
)
|
||||
if blocked_ids and not (update.task.status == "done" and target_status == "done"):
|
||||
update.task.status = "inbox"
|
||||
update.task.assigned_agent_id = None
|
||||
@@ -1399,14 +1428,16 @@ async def _apply_admin_task_rules(
|
||||
update.updates["assigned_agent_id"] = None
|
||||
|
||||
if "status" in update.updates:
|
||||
status_value = cast(str, update.updates["status"])
|
||||
status_value = _required_status_value(update.updates["status"])
|
||||
if status_value == "inbox":
|
||||
update.task.assigned_agent_id = None
|
||||
update.task.in_progress_at = None
|
||||
elif status_value == "in_progress":
|
||||
update.task.in_progress_at = utcnow()
|
||||
|
||||
assigned_agent_id = cast(UUID | None, update.updates.get("assigned_agent_id"))
|
||||
assigned_agent_id = _optional_assigned_agent_id(
|
||||
update.updates.get("assigned_agent_id"),
|
||||
)
|
||||
if assigned_agent_id:
|
||||
agent = await Agent.objects.by_id(assigned_agent_id).first(session)
|
||||
if agent is None:
|
||||
@@ -1530,7 +1561,8 @@ async def _finalize_updated_task(
|
||||
setattr(update.task, key, value)
|
||||
update.task.updated_at = utcnow()
|
||||
|
||||
if "status" in update.updates and cast(str, update.updates["status"]) == "review":
|
||||
status_raw = update.updates.get("status")
|
||||
if status_raw == "review":
|
||||
comment_text = (update.comment or "").strip()
|
||||
if not comment_text and not await has_valid_recent_comment(
|
||||
session,
|
||||
|
||||
Reference in New Issue
Block a user