refactor: improve type coercion functions and enhance type hints across multiple files
This commit is contained in:
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
@@ -15,7 +14,6 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy import asc, desc, or_
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.sql.expression import Select
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.api.deps import (
|
||||
@@ -67,7 +65,10 @@ from app.services.task_dependencies import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
|
||||
from app.core.auth import AuthContext
|
||||
from app.models.users import User
|
||||
@@ -157,6 +158,16 @@ def _parse_since(value: str | None) -> datetime | None:
|
||||
return parsed
|
||||
|
||||
|
||||
def _coerce_task_items(items: Sequence[object]) -> list[Task]:
|
||||
tasks: list[Task] = []
|
||||
for item in items:
|
||||
if not isinstance(item, Task):
|
||||
msg = "Expected Task items from paginated query"
|
||||
raise TypeError(msg)
|
||||
tasks.append(item)
|
||||
return tasks
|
||||
|
||||
|
||||
async def _lead_was_mentioned(
|
||||
session: AsyncSession,
|
||||
task: Task,
|
||||
@@ -266,7 +277,7 @@ async def _fetch_task_events(
|
||||
if not task_ids:
|
||||
return []
|
||||
statement = cast(
|
||||
Select[tuple[ActivityEvent, Task | None]],
|
||||
"Select[tuple[ActivityEvent, Task | None]]",
|
||||
select(ActivityEvent, Task)
|
||||
.outerjoin(Task, col(ActivityEvent.task_id) == col(Task.id))
|
||||
.where(col(ActivityEvent.task_id).in_(task_ids))
|
||||
@@ -512,7 +523,7 @@ def _task_list_statement(
|
||||
status_filter: str | None,
|
||||
assigned_agent_id: UUID | None,
|
||||
unassigned: bool | None,
|
||||
) -> object:
|
||||
) -> SelectOfScalar[Task]:
|
||||
statement = select(Task).where(Task.board_id == board_id)
|
||||
statuses = _status_values(status_filter)
|
||||
if statuses:
|
||||
@@ -717,7 +728,7 @@ async def list_tasks(
|
||||
)
|
||||
|
||||
async def _transform(items: Sequence[object]) -> Sequence[object]:
|
||||
tasks = cast(Sequence[Task], items)
|
||||
tasks = _coerce_task_items(items)
|
||||
return await _task_read_page(
|
||||
session=session,
|
||||
board_id=board.id,
|
||||
@@ -735,8 +746,8 @@ async def create_task(
|
||||
auth: AuthContext = ADMIN_AUTH_DEP,
|
||||
) -> TaskRead:
|
||||
"""Create a task and initialize dependency rows."""
|
||||
data = payload.model_dump()
|
||||
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or [])
|
||||
data = payload.model_dump(exclude={"depends_on_task_ids"})
|
||||
depends_on_task_ids = list(payload.depends_on_task_ids)
|
||||
|
||||
task = Task.model_validate(data)
|
||||
task.board_id = board.id
|
||||
@@ -828,10 +839,14 @@ async def update_task(
|
||||
previous_status = task.status
|
||||
previous_assigned = task.assigned_agent_id
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
comment = cast(str | None, updates.pop("comment", None))
|
||||
depends_on_task_ids = cast(
|
||||
list[UUID] | None, updates.pop("depends_on_task_ids", None),
|
||||
comment = payload.comment if "comment" in payload.model_fields_set else None
|
||||
depends_on_task_ids = (
|
||||
payload.depends_on_task_ids
|
||||
if "depends_on_task_ids" in payload.model_fields_set
|
||||
else None
|
||||
)
|
||||
updates.pop("comment", None)
|
||||
updates.pop("depends_on_task_ids", None)
|
||||
update = _TaskUpdateInput(
|
||||
task=task,
|
||||
actor=actor,
|
||||
@@ -960,7 +975,7 @@ async def _comment_targets(
|
||||
task: Task,
|
||||
message: str,
|
||||
actor: ActorContext,
|
||||
) -> tuple[dict[UUID, Agent], list[str]]:
|
||||
) -> tuple[dict[UUID, Agent], set[str]]:
|
||||
mention_names = extract_mentions(message)
|
||||
targets: dict[UUID, Agent] = {}
|
||||
if mention_names and task.board_id:
|
||||
@@ -985,7 +1000,7 @@ class _TaskCommentNotifyRequest:
|
||||
actor: ActorContext
|
||||
message: str
|
||||
targets: dict[UUID, Agent]
|
||||
mention_names: list[str]
|
||||
mention_names: set[str]
|
||||
|
||||
|
||||
async def _notify_task_comment_targets(
|
||||
@@ -1048,6 +1063,18 @@ class _TaskUpdateInput:
|
||||
depends_on_task_ids: list[UUID] | None
|
||||
|
||||
|
||||
def _required_status_value(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
|
||||
def _optional_assigned_agent_id(value: object) -> UUID | None:
|
||||
if value is None or isinstance(value, UUID):
|
||||
return value
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
|
||||
async def _task_dep_ids(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
@@ -1182,7 +1209,7 @@ async def _lead_apply_assignment(
|
||||
) -> None:
|
||||
if "assigned_agent_id" not in update.updates:
|
||||
return
|
||||
assigned_id = cast(UUID | None, update.updates["assigned_agent_id"])
|
||||
assigned_id = _optional_assigned_agent_id(update.updates["assigned_agent_id"])
|
||||
if not assigned_id:
|
||||
update.task.assigned_agent_id = None
|
||||
return
|
||||
@@ -1214,7 +1241,7 @@ def _lead_apply_status(update: _TaskUpdateInput) -> None:
|
||||
"in review."
|
||||
),
|
||||
)
|
||||
target_status = cast(str, update.updates["status"])
|
||||
target_status = _required_status_value(update.updates["status"])
|
||||
if target_status not in {"done", "inbox"}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -1332,7 +1359,7 @@ async def _apply_non_lead_agent_task_rules(
|
||||
):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
if "status" in update.updates:
|
||||
status_value = cast(str, update.updates["status"])
|
||||
status_value = _required_status_value(update.updates["status"])
|
||||
if status_value != "inbox":
|
||||
dep_ids = await _task_dep_ids(
|
||||
session,
|
||||
@@ -1390,7 +1417,9 @@ async def _apply_admin_task_rules(
|
||||
board_id=update.board_id,
|
||||
dep_ids=effective_deps,
|
||||
)
|
||||
target_status = cast(str, update.updates.get("status", update.task.status))
|
||||
target_status = _required_status_value(
|
||||
update.updates.get("status", update.task.status),
|
||||
)
|
||||
if blocked_ids and not (update.task.status == "done" and target_status == "done"):
|
||||
update.task.status = "inbox"
|
||||
update.task.assigned_agent_id = None
|
||||
@@ -1399,14 +1428,16 @@ async def _apply_admin_task_rules(
|
||||
update.updates["assigned_agent_id"] = None
|
||||
|
||||
if "status" in update.updates:
|
||||
status_value = cast(str, update.updates["status"])
|
||||
status_value = _required_status_value(update.updates["status"])
|
||||
if status_value == "inbox":
|
||||
update.task.assigned_agent_id = None
|
||||
update.task.in_progress_at = None
|
||||
elif status_value == "in_progress":
|
||||
update.task.in_progress_at = utcnow()
|
||||
|
||||
assigned_agent_id = cast(UUID | None, update.updates.get("assigned_agent_id"))
|
||||
assigned_agent_id = _optional_assigned_agent_id(
|
||||
update.updates.get("assigned_agent_id"),
|
||||
)
|
||||
if assigned_agent_id:
|
||||
agent = await Agent.objects.by_id(assigned_agent_id).first(session)
|
||||
if agent is None:
|
||||
@@ -1530,7 +1561,8 @@ async def _finalize_updated_task(
|
||||
setattr(update.task, key, value)
|
||||
update.task.updated_at = utcnow()
|
||||
|
||||
if "status" in update.updates and cast(str, update.updates["status"]) == "review":
|
||||
status_raw = update.updates.get("status")
|
||||
if status_raw == "review":
|
||||
comment_text = (update.comment or "").strip()
|
||||
if not comment_text and not await has_valid_recent_comment(
|
||||
session,
|
||||
|
||||
Reference in New Issue
Block a user