fix(api): make /tasks + /task-comments atomic and return full JSON

This commit is contained in:
Abhimanyu Saharan
2026-02-02 14:00:46 +05:30
parent d5f527f311
commit f1fe2127da

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select from sqlmodel import Session, select
from sqlalchemy.exc import IntegrityError
from app.api.utils import log_activity, get_actor_employee_id from app.api.utils import log_activity, get_actor_employee_id
from app.db.session import get_session from app.db.session import get_session
@@ -27,13 +28,15 @@ def list_tasks(project_id: int | None = None, session: Session = Depends(get_ses
def create_task(payload: TaskCreate, session: Session = Depends(get_session), actor_employee_id: int = Depends(get_actor_employee_id)): def create_task(payload: TaskCreate, session: Session = Depends(get_session), actor_employee_id: int = Depends(get_actor_employee_id)):
if payload.created_by_employee_id is None: if payload.created_by_employee_id is None:
payload = TaskCreate(**{**payload.model_dump(), "created_by_employee_id": actor_employee_id}) payload = TaskCreate(**{**payload.model_dump(), "created_by_employee_id": actor_employee_id})
task = Task(**payload.model_dump()) task = Task(**payload.model_dump())
if task.status not in ALLOWED_STATUSES: if task.status not in ALLOWED_STATUSES:
raise HTTPException(status_code=400, detail="Invalid status") raise HTTPException(status_code=400, detail="Invalid status")
task.updated_at = datetime.utcnow() task.updated_at = datetime.utcnow()
session.add(task) session.add(task)
session.commit()
session.refresh(task) try:
session.flush()
log_activity( log_activity(
session, session,
actor_employee_id=actor_employee_id, actor_employee_id=actor_employee_id,
@@ -43,7 +46,13 @@ def create_task(payload: TaskCreate, session: Session = Depends(get_session), ac
payload={"project_id": task.project_id, "title": task.title}, payload={"project_id": task.project_id, "title": task.title},
) )
session.commit() session.commit()
return task except IntegrityError:
session.rollback()
raise HTTPException(status_code=409, detail="Task create violates constraints")
session.refresh(task)
# Explicitly return a serializable payload (guards against empty {} responses)
return Task.model_validate(task)
@router.patch("/tasks/{task_id}", response_model=Task) @router.patch("/tasks/{task_id}", response_model=Task)
@@ -55,16 +64,22 @@ def update_task(task_id: int, payload: TaskUpdate, session: Session = Depends(ge
data = payload.model_dump(exclude_unset=True) data = payload.model_dump(exclude_unset=True)
if "status" in data and data["status"] not in ALLOWED_STATUSES: if "status" in data and data["status"] not in ALLOWED_STATUSES:
raise HTTPException(status_code=400, detail="Invalid status") raise HTTPException(status_code=400, detail="Invalid status")
for k, v in data.items(): for k, v in data.items():
setattr(task, k, v) setattr(task, k, v)
task.updated_at = datetime.utcnow() task.updated_at = datetime.utcnow()
session.add(task) session.add(task)
session.commit()
session.refresh(task) try:
session.flush()
log_activity(session, actor_employee_id=actor_employee_id, entity_type="task", entity_id=task.id, verb="updated", payload=data) log_activity(session, actor_employee_id=actor_employee_id, entity_type="task", entity_id=task.id, verb="updated", payload=data)
session.commit() session.commit()
return task except IntegrityError:
session.rollback()
raise HTTPException(status_code=409, detail="Task update violates constraints")
session.refresh(task)
return Task.model_validate(task)
@router.delete("/tasks/{task_id}") @router.delete("/tasks/{task_id}")
@@ -72,10 +87,16 @@ def delete_task(task_id: int, session: Session = Depends(get_session), actor_emp
task = session.get(Task, task_id) task = session.get(Task, task_id)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
session.delete(task) session.delete(task)
session.commit() try:
session.flush()
log_activity(session, actor_employee_id=actor_employee_id, entity_type="task", entity_id=task_id, verb="deleted") log_activity(session, actor_employee_id=actor_employee_id, entity_type="task", entity_id=task_id, verb="deleted")
session.commit() session.commit()
except IntegrityError:
session.rollback()
raise HTTPException(status_code=409, detail="Task delete violates constraints")
return {"ok": True} return {"ok": True}
@@ -88,16 +109,17 @@ def list_task_comments(task_id: int, session: Session = Depends(get_session)):
def create_task_comment(payload: TaskCommentCreate, session: Session = Depends(get_session), actor_employee_id: int = Depends(get_actor_employee_id)): def create_task_comment(payload: TaskCommentCreate, session: Session = Depends(get_session), actor_employee_id: int = Depends(get_actor_employee_id)):
if payload.author_employee_id is None: if payload.author_employee_id is None:
payload = TaskCommentCreate(**{**payload.model_dump(), "author_employee_id": actor_employee_id}) payload = TaskCommentCreate(**{**payload.model_dump(), "author_employee_id": actor_employee_id})
c = TaskComment(**payload.model_dump())
# Validate reply target (must exist + belong to same task) c = TaskComment(**payload.model_dump())
if c.reply_to_comment_id is not None:
parent = session.get(TaskComment, c.reply_to_comment_id)
if parent is None or parent.task_id != c.task_id:
raise HTTPException(status_code=400, detail="Invalid reply_to_comment_id")
session.add(c) session.add(c)
session.commit()
session.refresh(c) try:
session.flush()
log_activity(session, actor_employee_id=actor_employee_id, entity_type="task", entity_id=c.task_id, verb="commented") log_activity(session, actor_employee_id=actor_employee_id, entity_type="task", entity_id=c.task_id, verb="commented")
session.commit() session.commit()
return c except IntegrityError:
session.rollback()
raise HTTPException(status_code=409, detail="Comment create violates constraints")
session.refresh(c)
return TaskComment.model_validate(c)