refactor: improve type coercion functions and enhance type hints across multiple files

This commit is contained in:
Abhimanyu Saharan
2026-02-09 17:43:42 +05:30
parent f5d592f61a
commit dddd1e9a7a
13 changed files with 217 additions and 64 deletions

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio
import json
from collections import deque
from collections.abc import AsyncIterator, Sequence
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime, timezone
@@ -15,7 +14,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, desc, or_
from sqlmodel import col, select
from sqlmodel.sql.expression import Select
from sse_starlette.sse import EventSourceResponse
from app.api.deps import (
@@ -67,7 +65,10 @@ from app.services.task_dependencies import (
)
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Sequence
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select, SelectOfScalar
from app.core.auth import AuthContext
from app.models.users import User
@@ -157,6 +158,16 @@ def _parse_since(value: str | None) -> datetime | None:
return parsed
def _coerce_task_items(items: Sequence[object]) -> list[Task]:
tasks: list[Task] = []
for item in items:
if not isinstance(item, Task):
msg = "Expected Task items from paginated query"
raise TypeError(msg)
tasks.append(item)
return tasks
async def _lead_was_mentioned(
session: AsyncSession,
task: Task,
@@ -266,7 +277,7 @@ async def _fetch_task_events(
if not task_ids:
return []
statement = cast(
Select[tuple[ActivityEvent, Task | None]],
"Select[tuple[ActivityEvent, Task | None]]",
select(ActivityEvent, Task)
.outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id))
.where(col(ActivityEvent.task_id).in_(task_ids))
@@ -512,7 +523,7 @@ def _task_list_statement(
status_filter: str | None,
assigned_agent_id: UUID | None,
unassigned: bool | None,
) -> object:
) -> SelectOfScalar[Task]:
statement = select(Task).where(Task.board_id == board_id)
statuses = _status_values(status_filter)
if statuses:
@@ -717,7 +728,7 @@ async def list_tasks(
)
async def _transform(items: Sequence[object]) -> Sequence[object]:
tasks = cast(Sequence[Task], items)
tasks = _coerce_task_items(items)
return await _task_read_page(
session=session,
board_id=board.id,
@@ -735,8 +746,8 @@ async def create_task(
auth: AuthContext = ADMIN_AUTH_DEP,
) -> TaskRead:
"""Create a task and initialize dependency rows."""
data = payload.model_dump()
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or [])
data = payload.model_dump(exclude={"depends_on_task_ids"})
depends_on_task_ids = list(payload.depends_on_task_ids)
task = Task.model_validate(data)
task.board_id = board.id
@@ -828,10 +839,14 @@ async def update_task(
previous_status = task.status
previous_assigned = task.assigned_agent_id
updates = payload.model_dump(exclude_unset=True)
comment = cast(str | None, updates.pop("comment", None))
depends_on_task_ids = cast(
list[UUID] | None, updates.pop("depends_on_task_ids", None),
comment = payload.comment if "comment" in payload.model_fields_set else None
depends_on_task_ids = (
payload.depends_on_task_ids
if "depends_on_task_ids" in payload.model_fields_set
else None
)
updates.pop("comment", None)
updates.pop("depends_on_task_ids", None)
update = _TaskUpdateInput(
task=task,
actor=actor,
@@ -960,7 +975,7 @@ async def _comment_targets(
task: Task,
message: str,
actor: ActorContext,
) -> tuple[dict[UUID, Agent], list[str]]:
) -> tuple[dict[UUID, Agent], set[str]]:
mention_names = extract_mentions(message)
targets: dict[UUID, Agent] = {}
if mention_names and task.board_id:
@@ -985,7 +1000,7 @@ class _TaskCommentNotifyRequest:
actor: ActorContext
message: str
targets: dict[UUID, Agent]
mention_names: list[str]
mention_names: set[str]
async def _notify_task_comment_targets(
@@ -1048,6 +1063,18 @@ class _TaskUpdateInput:
depends_on_task_ids: list[UUID] | None
def _required_status_value(value: object) -> str:
if isinstance(value, str):
return value
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
def _optional_assigned_agent_id(value: object) -> UUID | None:
if value is None or isinstance(value, UUID):
return value
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
async def _task_dep_ids(
session: AsyncSession,
*,
@@ -1182,7 +1209,7 @@ async def _lead_apply_assignment(
) -> None:
if "assigned_agent_id" not in update.updates:
return
assigned_id = cast(UUID | None, update.updates["assigned_agent_id"])
assigned_id = _optional_assigned_agent_id(update.updates["assigned_agent_id"])
if not assigned_id:
update.task.assigned_agent_id = None
return
@@ -1214,7 +1241,7 @@ def _lead_apply_status(update: _TaskUpdateInput) -> None:
"in review."
),
)
target_status = cast(str, update.updates["status"])
target_status = _required_status_value(update.updates["status"])
if target_status not in {"done", "inbox"}:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -1332,7 +1359,7 @@ async def _apply_non_lead_agent_task_rules(
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if "status" in update.updates:
status_value = cast(str, update.updates["status"])
status_value = _required_status_value(update.updates["status"])
if status_value != "inbox":
dep_ids = await _task_dep_ids(
session,
@@ -1390,7 +1417,9 @@ async def _apply_admin_task_rules(
board_id=update.board_id,
dep_ids=effective_deps,
)
target_status = cast(str, update.updates.get("status", update.task.status))
target_status = _required_status_value(
update.updates.get("status", update.task.status),
)
if blocked_ids and not (update.task.status == "done" and target_status == "done"):
update.task.status = "inbox"
update.task.assigned_agent_id = None
@@ -1399,14 +1428,16 @@ async def _apply_admin_task_rules(
update.updates["assigned_agent_id"] = None
if "status" in update.updates:
status_value = cast(str, update.updates["status"])
status_value = _required_status_value(update.updates["status"])
if status_value == "inbox":
update.task.assigned_agent_id = None
update.task.in_progress_at = None
elif status_value == "in_progress":
update.task.in_progress_at = utcnow()
assigned_agent_id = cast(UUID | None, update.updates.get("assigned_agent_id"))
assigned_agent_id = _optional_assigned_agent_id(
update.updates.get("assigned_agent_id"),
)
if assigned_agent_id:
agent = await Agent.objects.by_id(assigned_agent_id).first(session)
if agent is None:
@@ -1530,7 +1561,8 @@ async def _finalize_updated_task(
setattr(update.task, key, value)
update.task.updated_at = utcnow()
if "status" in update.updates and cast(str, update.updates["status"]) == "review":
status_raw = update.updates.get("status")
if status_raw == "review":
comment_text = (update.comment or "").strip()
if not comment_text and not await has_valid_recent_comment(
session,