From 009218158b9ad6232ee9815e8ebd2ceb89ac48e8 Mon Sep 17 00:00:00 2001 From: Abhimanyu Saharan Date: Sun, 15 Feb 2026 05:50:03 +0000 Subject: [PATCH] feat(github): add mission-control approval check-run gate --- backend/.env.example | 4 + backend/app/api/approvals.py | 39 ++ backend/app/api/tasks.py | 23 + backend/app/core/config.py | 9 +- backend/app/services/github/__init__.py | 1 + backend/app/services/github/client.py | 195 ++++++++ .../github/mission_control_approval_check.py | 376 ++++++++++++++++ .../test_mission_control_approval_check.py | 416 ++++++++++++++++++ 8 files changed, 1062 insertions(+), 1 deletion(-) create mode 100644 backend/app/services/github/__init__.py create mode 100644 backend/app/services/github/client.py create mode 100644 backend/app/services/github/mission_control_approval_check.py create mode 100644 backend/tests/test_mission_control_approval_check.py diff --git a/backend/.env.example b/backend/.env.example index 6f4b190f..c48f477c 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -19,6 +19,10 @@ CLERK_VERIFY_IAT=true CLERK_LEEWAY=10.0 # Database DB_AUTO_MIGRATE=false +# GitHub integration (for Check Runs / required-check enforcement) +# Used by mission-control/approval check updater. +GH_TOKEN= + # Webhook queue / worker WEBHOOK_REDIS_URL=redis://localhost:6379/0 WEBHOOK_QUEUE_NAME=webhook-dispatch diff --git a/backend/app/api/approvals.py b/backend/app/api/approvals.py index 723444c0..67559502 100644 --- a/backend/app/api/approvals.py +++ b/backend/app/api/approvals.py @@ -38,6 +38,10 @@ from app.services.approval_task_links import ( replace_approval_task_links, task_counts_for_board, ) +from app.services.github.mission_control_approval_check import ( + github_approval_check_enabled, + sync_github_approval_check_for_task_ids, +) from app.services.openclaw.gateway_dispatch import GatewayDispatchService if TYPE_CHECKING: @@ -426,6 +430,20 @@ async def create_approval( await session.commit() await session.refresh(approval) title_by_id = await _task_titles_by_id(session, task_ids=set(task_ids)) + + if github_approval_check_enabled() and task_ids: + try: + await sync_github_approval_check_for_task_ids( + session, + board_id=board.id, + task_ids=list(task_ids), + ) + except Exception: + logger.exception( + "approval.github_check_sync_failed", + extra={"board_id": str(board.id), "task_ids": [str(tid) for tid in task_ids]}, + ) + return _approval_to_read( approval, task_ids=task_ids, @@ -481,5 +499,26 @@ async def update_approval( approval.id, approval.status, ) + if github_approval_check_enabled(): + try: + task_ids_by_approval = await load_task_ids_by_approval( + session, + approval_ids=[approval.id], + ) + approval_task_ids = task_ids_by_approval.get(approval.id) or [] + if not approval_task_ids and approval.task_id is not None: + approval_task_ids = [approval.task_id] + if approval_task_ids: + await sync_github_approval_check_for_task_ids( + session, + board_id=board.id, + task_ids=list(approval_task_ids), + ) + except Exception: + logger.exception( + "approval.github_check_sync_failed", + extra={"board_id": str(board.id), "approval_id": str(approval.id)}, + ) + reads = await _approval_reads(session, [approval]) return reads[0] diff --git a/backend/app/api/tasks.py b/backend/app/api/tasks.py index 79c2d1f3..3d3a274d 100644 --- a/backend/app/api/tasks.py +++ b/backend/app/api/tasks.py @@ -23,6 +23,7 @@ from app.api.deps import ( require_admin_auth, require_admin_or_agent, ) +from app.core.logging import get_logger from app.core.time import utcnow from app.db import crud from app.db.pagination import paginate @@ -56,6 +57,10 @@ from app.services.approval_task_links import ( load_task_ids_by_approval, pending_approval_conflicts_by_task, ) +from app.services.github.mission_control_approval_check import ( + github_approval_check_enabled, + sync_github_approval_check_for_pr_url, +) from app.services.mentions import extract_mentions, matches_agent_mention from app.services.openclaw.gateway_dispatch import GatewayDispatchService from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig @@ -76,6 +81,8 @@ from app.services.task_dependencies import ( validate_dependency_update, ) +logger = get_logger(__name__) + if TYPE_CHECKING: from collections.abc import AsyncIterator, Sequence @@ -2402,6 +2409,22 @@ async def _finalize_updated_task( await _record_task_update_activity(session, update=update) await _notify_task_update_assignment_changes(session, update=update) + # Sync GitHub approval gate check when a task's PR link changes. + if github_approval_check_enabled() and update.custom_field_values_set: + pr_url = update.custom_field_values.get("github_pr_url") + if isinstance(pr_url, str) and pr_url.strip(): + try: + await sync_github_approval_check_for_pr_url( + session, + board_id=update.board_id, + pr_url=pr_url.strip(), + ) + except Exception: + logger.exception( + "task.github_check_sync_failed", + extra={"board_id": str(update.board_id), "task_id": str(update.task.id)}, + ) + return await _task_read_response( session, task=update.task, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 91a5603d..cd0fe2c3 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -5,7 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Self -from pydantic import Field, model_validator +from pydantic import AliasChoices, Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from app.core.auth_mode import AuthMode @@ -50,6 +50,13 @@ class Settings(BaseSettings): cors_origins: str = "" base_url: str = "" + # GitHub integration + # Token used for GitHub REST API calls (checks/status updates). Supports GH_TOKEN or GITHUB_TOKEN. + github_token: str = Field( + default="", + validation_alias=AliasChoices("GH_TOKEN", "GITHUB_TOKEN"), + ) + # Database lifecycle db_auto_migrate: bool = False diff --git a/backend/app/services/github/__init__.py b/backend/app/services/github/__init__.py new file mode 100644 index 00000000..b4130170 --- /dev/null +++ b/backend/app/services/github/__init__.py @@ -0,0 +1 @@ +"""GitHub integration services (checks, statuses, PR metadata).""" diff --git a/backend/app/services/github/client.py b/backend/app/services/github/client.py new file mode 100644 index 00000000..babfd9d3 --- /dev/null +++ b/backend/app/services/github/client.py @@ -0,0 +1,195 @@ +"""Minimal GitHub REST client used for merge-policy enforcement. + +This module is intentionally small and purpose-built for: +- PR metadata lookup (head SHA) +- Check Runs upsert (create or update by name) + +It uses a repo-scoped token (PAT or GitHub App token) provided via settings. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +import httpx + +from app.core.config import settings +from app.core.logging import get_logger + +logger = get_logger(__name__) + +GITHUB_API_BASE_URL = "https://api.github.com" +GITHUB_API_VERSION = "2022-11-28" + + +@dataclass(frozen=True) +class ParsedPullRequest: + owner: str + repo: str + number: int + url: str + + +def parse_pull_request_url(url: str) -> ParsedPullRequest | None: + """Parse a GitHub PR URL: https://github.com///pull/.""" + raw = (url or "").strip() + if not raw: + return None + if raw.startswith("http://"): + # normalize; we only accept github.com URLs + raw = "https://" + raw.removeprefix("http://") + if not raw.startswith("https://github.com/"): + return None + path = raw.removeprefix("https://github.com/") + parts = [p for p in path.split("/") if p] + if len(parts) < 4: + return None + owner, repo, kind, num = parts[0], parts[1], parts[2], parts[3] + if kind != "pull": + return None + try: + number = int(num) + except ValueError: + return None + if number <= 0: + return None + return ParsedPullRequest(owner=owner, repo=repo, number=number, url=url) + + +class GitHubClientError(RuntimeError): + pass + + +def _auth_headers() -> dict[str, str]: + token = (settings.github_token or "").strip() + if not token: + raise GitHubClientError("GitHub token is not configured (GH_TOKEN/GITHUB_TOKEN).") + return { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": GITHUB_API_VERSION, + } + + +async def get_pull_request_head_sha(pr: ParsedPullRequest) -> str: + """Return head SHA for a PR.""" + url = f"{GITHUB_API_BASE_URL}/repos/{pr.owner}/{pr.repo}/pulls/{pr.number}" + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get(url, headers=_auth_headers()) + if resp.status_code >= 400: + raise GitHubClientError(f"GitHub PR lookup failed: {resp.status_code} {resp.text}") + data = resp.json() + head = data.get("head") + if not isinstance(head, dict) or not isinstance(head.get("sha"), str): + raise GitHubClientError("GitHub PR response missing head.sha") + return head["sha"] + + +async def _find_check_run_id(*, owner: str, repo: str, ref: str, check_name: str) -> int | None: + # Docs: GET /repos/{owner}/{repo}/commits/{ref}/check-runs + url = f"{GITHUB_API_BASE_URL}/repos/{owner}/{repo}/commits/{ref}/check-runs" + params = {"check_name": check_name, "per_page": 100} + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get(url, headers=_auth_headers(), params=params) + if resp.status_code >= 400: + raise GitHubClientError( + f"GitHub check-runs lookup failed: {resp.status_code} {resp.text}", + ) + payload = resp.json() + runs = payload.get("check_runs") + if not isinstance(runs, list): + return None + for run in runs: + if not isinstance(run, dict): + continue + if run.get("name") != check_name: + continue + run_id = run.get("id") + if isinstance(run_id, int): + return run_id + return None + + +CheckStatus = Literal["queued", "in_progress", "completed"] +CheckConclusion = Literal[ + "success", + "failure", + "neutral", + "cancelled", + "skipped", + "timed_out", + "action_required", +] + + +async def upsert_check_run( + *, + owner: str, + repo: str, + head_sha: str, + check_name: str, + status: CheckStatus, + conclusion: CheckConclusion | None, + title: str, + summary: str, + details_url: str | None = None, +) -> None: + """Create or update a check run on a commit SHA. + + If a check run with the same name exists on the ref, we patch it. + Otherwise, we create a new one. + """ + + payload: dict[str, Any] = { + "name": check_name, + "head_sha": head_sha, + "status": status, + "output": { + "title": title, + "summary": summary, + }, + } + if details_url: + payload["details_url"] = details_url + if status == "completed": + if conclusion is None: + raise ValueError("conclusion is required when status=completed") + payload["conclusion"] = conclusion + + run_id = await _find_check_run_id(owner=owner, repo=repo, ref=head_sha, check_name=check_name) + if run_id is None: + url = f"{GITHUB_API_BASE_URL}/repos/{owner}/{repo}/check-runs" + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.post(url, headers={**_auth_headers(), "Accept": "application/vnd.github+json"}, json=payload) + if resp.status_code >= 400: + raise GitHubClientError( + f"GitHub check-run create failed: {resp.status_code} {resp.text}", + ) + logger.info( + "github.check_run.created", + extra={"owner": owner, "repo": repo, "sha": head_sha, "check": check_name}, + ) + return + + url = f"{GITHUB_API_BASE_URL}/repos/{owner}/{repo}/check-runs/{run_id}" + # PATCH payload should not include head_sha/name for updates? Safe to include minimal fields. + patch_payload = { + "status": status, + "output": payload["output"], + } + if details_url: + patch_payload["details_url"] = details_url + if status == "completed": + patch_payload["conclusion"] = conclusion + + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.patch(url, headers=_auth_headers(), json=patch_payload) + if resp.status_code >= 400: + raise GitHubClientError( + f"GitHub check-run update failed: {resp.status_code} {resp.text}", + ) + logger.info( + "github.check_run.updated", + extra={"owner": owner, "repo": repo, "sha": head_sha, "check": check_name, "id": run_id}, + ) diff --git a/backend/app/services/github/mission_control_approval_check.py b/backend/app/services/github/mission_control_approval_check.py new file mode 100644 index 00000000..13a8c7ae --- /dev/null +++ b/backend/app/services/github/mission_control_approval_check.py @@ -0,0 +1,376 @@ +"""Mission Control approval gate → GitHub required check. + +This module maintains a GitHub Check Run (recommended) named: +- `mission-control/approval` + +The check is intended to be added to GitHub ruleset required checks so PRs +cannot merge unless the corresponding Mission Control task has an approved +in-app approval. + +Mapping: +- PR → Task: by `custom_field_values.github_pr_url` exact match. +- Task → Approval: any linked Approval rows with status in {pending, approved, rejected}. + +Triggers (implemented via API hooks): +- approval created / resolved +- task github_pr_url updated + +A periodic reconciliation job should call the sync functions as a safety net. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal +from uuid import UUID + +from sqlmodel import col, select + +from app.core.config import settings +from app.core.logging import get_logger +from app.models.approval_task_links import ApprovalTaskLink +from app.models.approvals import Approval +from app.models.boards import Board +from app.models.task_custom_fields import TaskCustomFieldDefinition, TaskCustomFieldValue +from app.models.tasks import Task +from app.services.github.client import ( + GitHubClientError, + get_pull_request_head_sha, + parse_pull_request_url, + upsert_check_run, +) + +if False: # pragma: no cover + from sqlmodel.ext.asyncio.session import AsyncSession + +logger = get_logger(__name__) + +CHECK_NAME = "mission-control/approval" + +# Default action types that qualify as a "merge gate" approval. +# (Action types are free-form today; keep this conservative but configurable later.) +REQUIRED_ACTION_TYPES = {"mark_done", "mark_task_done"} + + +CheckOutcome = Literal["success", "pending", "rejected", "missing", "error", "multiple"] + + +@dataclass(frozen=True) +class ApprovalGateEvaluation: + outcome: CheckOutcome + task_ids: tuple[UUID, ...] = () + summary: str = "" + + +async def _board_org_id(session: AsyncSession, *, board_id: UUID) -> UUID | None: + return ( + await session.exec( + select(col(Board.organization_id)).where(col(Board.id) == board_id), + ) + ).first() + + +async def _tasks_for_pr_url( + session: AsyncSession, + *, + board_id: UUID, + pr_url: str, +) -> list[Task]: + org_id = await _board_org_id(session, board_id=board_id) + if org_id is None: + return [] + + statement = ( + select(Task) + .join(TaskCustomFieldValue, col(TaskCustomFieldValue.task_id) == col(Task.id)) + .join( + TaskCustomFieldDefinition, + col(TaskCustomFieldDefinition.id) + == col(TaskCustomFieldValue.task_custom_field_definition_id), + ) + .where(col(Task.board_id) == board_id) + .where(col(TaskCustomFieldDefinition.organization_id) == org_id) + .where(col(TaskCustomFieldDefinition.field_key) == "github_pr_url") + .where(col(TaskCustomFieldValue.value) == pr_url) + .order_by(col(Task.created_at).asc()) + ) + rows = list(await session.exec(statement)) + return [row for row in rows if isinstance(row, Task)] + + +async def _approval_rows_for_task( + session: AsyncSession, + *, + board_id: UUID, + task_id: UUID, +) -> list[Approval]: + # Linked approvals (new style) + linked_stmt = ( + select(Approval) + .join(ApprovalTaskLink, col(ApprovalTaskLink.approval_id) == col(Approval.id)) + .where(col(Approval.board_id) == board_id) + .where(col(ApprovalTaskLink.task_id) == task_id) + .order_by(col(Approval.created_at).asc()) + ) + linked = list(await session.exec(linked_stmt)) + + # Legacy approvals (Approval.task_id) not linked via ApprovalTaskLink + legacy_stmt = ( + select(Approval) + .where(col(Approval.board_id) == board_id) + .where(col(Approval.task_id) == task_id) + .order_by(col(Approval.created_at).asc()) + ) + legacy = list(await session.exec(legacy_stmt)) + + # Merge unique by id + by_id: dict[UUID, Approval] = {} + for approval in [*linked, *legacy]: + if isinstance(approval, Approval): + by_id.setdefault(approval.id, approval) + return list(by_id.values()) + + +def _qualifies_for_gate(approval: Approval) -> bool: + # If action types evolve, we can broaden this; for now keep it anchored. + return approval.action_type in REQUIRED_ACTION_TYPES + + +async def evaluate_approval_gate_for_pr_url( + session: AsyncSession, + *, + board_id: UUID, + pr_url: str, +) -> ApprovalGateEvaluation: + tasks = await _tasks_for_pr_url(session, board_id=board_id, pr_url=pr_url) + if not tasks: + return ApprovalGateEvaluation( + outcome="missing", + task_ids=(), + summary=( + "No Mission Control task is linked to this PR. Set the task custom field " + "`github_pr_url` to this PR URL." + ), + ) + if len(tasks) > 1: + return ApprovalGateEvaluation( + outcome="multiple", + task_ids=tuple(task.id for task in tasks), + summary=( + "Multiple Mission Control tasks are linked to this PR URL. " + "Ensure exactly one task has `github_pr_url` set to this PR." + ), + ) + + task = tasks[0] + approvals = await _approval_rows_for_task(session, board_id=board_id, task_id=task.id) + gate_approvals = [a for a in approvals if _qualifies_for_gate(a)] + + if not gate_approvals: + return ApprovalGateEvaluation( + outcome="missing", + task_ids=(task.id,), + summary=( + "No qualifying approval found for this task. Create an approval request " + f"(action_type in {sorted(REQUIRED_ACTION_TYPES)})." + ), + ) + + statuses = [str(a.status) for a in gate_approvals] + if any(s == "approved" for s in statuses): + return ApprovalGateEvaluation( + outcome="success", + task_ids=(task.id,), + summary="Approval is approved. Merge is permitted.", + ) + if any(s == "rejected" for s in statuses): + return ApprovalGateEvaluation( + outcome="rejected", + task_ids=(task.id,), + summary="Approval was rejected. Merge is blocked until a new approval is granted.", + ) + if any(s == "pending" for s in statuses): + return ApprovalGateEvaluation( + outcome="pending", + task_ids=(task.id,), + summary="Approval is pending. Merge is blocked until approved.", + ) + + return ApprovalGateEvaluation( + outcome="error", + task_ids=(task.id,), + summary=f"Unexpected approval statuses: {sorted(set(statuses))}", + ) + + +async def sync_github_approval_check_for_pr_url( + session: AsyncSession, + *, + board_id: UUID, + pr_url: str, +) -> None: + """Upsert the GitHub check run for a PR URL based on Mission Control approval state.""" + + parsed = parse_pull_request_url(pr_url) + if parsed is None: + logger.warning( + "github.approval_check.invalid_pr_url", + extra={"board_id": str(board_id), "pr_url": pr_url}, + ) + return + + try: + evaluation = await evaluate_approval_gate_for_pr_url( + session, + board_id=board_id, + pr_url=pr_url, + ) + + head_sha = await get_pull_request_head_sha(parsed) + + title = "Mission Control approval gate" + summary_lines = [ + f"PR: {parsed.url}", + f"Board: {board_id}", + ] + if evaluation.task_ids: + summary_lines.append("Task(s): " + ", ".join(str(tid) for tid in evaluation.task_ids)) + summary_lines.append("") + summary_lines.append(evaluation.summary) + + if evaluation.outcome == "success": + await upsert_check_run( + owner=parsed.owner, + repo=parsed.repo, + head_sha=head_sha, + check_name=CHECK_NAME, + status="completed", + conclusion="success", + title=title, + summary="\n".join(summary_lines), + ) + return + + if evaluation.outcome == "pending": + # Keep as in_progress to clearly signal it's waiting. + await upsert_check_run( + owner=parsed.owner, + repo=parsed.repo, + head_sha=head_sha, + check_name=CHECK_NAME, + status="in_progress", + conclusion=None, + title=title, + summary="\n".join(summary_lines), + ) + return + + # failure-like outcomes + await upsert_check_run( + owner=parsed.owner, + repo=parsed.repo, + head_sha=head_sha, + check_name=CHECK_NAME, + status="completed", + conclusion="failure", + title=title, + summary="\n".join(summary_lines), + ) + + except GitHubClientError as exc: + logger.warning( + "github.approval_check.github_error", + extra={"board_id": str(board_id), "pr_url": pr_url, "error": str(exc)}, + ) + except Exception as exc: + logger.exception( + "github.approval_check.unexpected", + extra={"board_id": str(board_id), "pr_url": pr_url, "error": str(exc)}, + ) + + +async def sync_github_approval_check_for_task_ids( + session: AsyncSession, + *, + board_id: UUID, + task_ids: list[UUID], +) -> None: + """Sync approval checks for any tasks that have github_pr_url set. + + Used by approval hooks (one approval can link multiple tasks). + """ + + if not task_ids: + return + + # Load custom-field values for these tasks and find github_pr_url. + # We reuse the same join approach but filter by task ids. + org_id = await _board_org_id(session, board_id=board_id) + if org_id is None: + return + + stmt = ( + select(col(TaskCustomFieldValue.task_id), col(TaskCustomFieldValue.value)) + .join( + TaskCustomFieldDefinition, + col(TaskCustomFieldDefinition.id) + == col(TaskCustomFieldValue.task_custom_field_definition_id), + ) + .where(col(TaskCustomFieldDefinition.organization_id) == org_id) + .where(col(TaskCustomFieldDefinition.field_key) == "github_pr_url") + .where(col(TaskCustomFieldValue.task_id).in_(task_ids)) + ) + rows = list(await session.exec(stmt)) + + pr_urls: set[str] = set() + for _task_id, value in rows: + if isinstance(value, str) and value.strip(): + pr_urls.add(value.strip()) + + for pr_url in sorted(pr_urls): + await sync_github_approval_check_for_pr_url(session, board_id=board_id, pr_url=pr_url) + + +async def reconcile_github_approval_checks_for_board( + session: AsyncSession, + *, + board_id: UUID, +) -> int: + """Periodic reconciliation safety net. + + Returns number of distinct PR URLs processed. + + Intended to be run by a cron/worker periodically. + """ + + org_id = await _board_org_id(session, board_id=board_id) + if org_id is None: + return 0 + + stmt = ( + select(col(TaskCustomFieldValue.value)) + .join( + TaskCustomFieldDefinition, + col(TaskCustomFieldDefinition.id) + == col(TaskCustomFieldValue.task_custom_field_definition_id), + ) + .join(Task, col(Task.id) == col(TaskCustomFieldValue.task_id)) + .where(col(Task.board_id) == board_id) + .where(col(TaskCustomFieldDefinition.organization_id) == org_id) + .where(col(TaskCustomFieldDefinition.field_key) == "github_pr_url") + ) + rows = list(await session.exec(stmt)) + + pr_urls: set[str] = set() + for (value,) in rows: + if isinstance(value, str) and value.strip(): + pr_urls.add(value.strip()) + + for pr_url in sorted(pr_urls): + await sync_github_approval_check_for_pr_url(session, board_id=board_id, pr_url=pr_url) + + return len(pr_urls) + + +def github_approval_check_enabled() -> bool: + return bool((settings.github_token or "").strip()) diff --git a/backend/tests/test_mission_control_approval_check.py b/backend/tests/test_mission_control_approval_check.py new file mode 100644 index 00000000..649c5c3a --- /dev/null +++ b/backend/tests/test_mission_control_approval_check.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +from uuid import uuid4 + +import pytest +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlmodel import SQLModel, col, select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.approval_task_links import ApprovalTaskLink +from app.models.approvals import Approval +from app.models.boards import Board +from app.models.gateways import Gateway +from app.models.organizations import Organization +from app.models.task_custom_fields import TaskCustomFieldDefinition, TaskCustomFieldValue +from app.models.tasks import Task +from app.services.github.mission_control_approval_check import ( + REQUIRED_ACTION_TYPES, + evaluate_approval_gate_for_pr_url, +) + + +async def _make_engine() -> AsyncEngine: + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.connect() as conn, conn.begin(): + await conn.run_sync(SQLModel.metadata.create_all) + return engine + + +async def _make_session(engine: AsyncEngine) -> AsyncSession: + return AsyncSession(engine, expire_on_commit=False) + + +@pytest.mark.asyncio +async def test_approval_gate_no_task_linked_is_missing() -> None: + engine = await _make_engine() + try: + async with await _make_session(engine) as session: + org_id = uuid4() + board_id = uuid4() + gateway_id = uuid4() + + session.add(Organization(id=org_id, name="org")) + session.add( + Gateway( + id=gateway_id, + organization_id=org_id, + name="gateway", + url="https://gateway.local", + workspace_root="/tmp/workspace", + ) + ) + session.add( + Board( + id=board_id, + organization_id=org_id, + name="board", + slug="board", + gateway_id=gateway_id, + ) + ) + session.add( + TaskCustomFieldDefinition( + organization_id=org_id, + field_key="github_pr_url", + label="GitHub PR URL", + field_type="url", + ) + ) + await session.commit() + + out = await evaluate_approval_gate_for_pr_url( + session, + board_id=board_id, + pr_url="https://github.com/acme/repo/pull/1", + ) + assert out.outcome == "missing" + finally: + await engine.dispose() + + +@pytest.mark.asyncio +async def test_approval_gate_multiple_tasks_is_multiple() -> None: + engine = await _make_engine() + try: + async with await _make_session(engine) as session: + org_id = uuid4() + board_id = uuid4() + gateway_id = uuid4() + task_id_1 = uuid4() + task_id_2 = uuid4() + + session.add(Organization(id=org_id, name="org")) + session.add( + Gateway( + id=gateway_id, + organization_id=org_id, + name="gateway", + url="https://gateway.local", + workspace_root="/tmp/workspace", + ) + ) + session.add( + Board( + id=board_id, + organization_id=org_id, + name="board", + slug="board", + gateway_id=gateway_id, + ) + ) + field = TaskCustomFieldDefinition( + organization_id=org_id, + field_key="github_pr_url", + label="GitHub PR URL", + field_type="url", + ) + session.add(field) + session.add(Task(id=task_id_1, board_id=board_id, title="t1", description="", status="inbox")) + session.add(Task(id=task_id_2, board_id=board_id, title="t2", description="", status="inbox")) + await session.commit() + + session.add( + TaskCustomFieldValue( + task_id=task_id_1, + task_custom_field_definition_id=field.id, + value="https://github.com/acme/repo/pull/2", + ) + ) + session.add( + TaskCustomFieldValue( + task_id=task_id_2, + task_custom_field_definition_id=field.id, + value="https://github.com/acme/repo/pull/2", + ) + ) + await session.commit() + + out = await evaluate_approval_gate_for_pr_url( + session, + board_id=board_id, + pr_url="https://github.com/acme/repo/pull/2", + ) + assert out.outcome == "multiple" + assert len(out.task_ids) == 2 + finally: + await engine.dispose() + + +@pytest.mark.asyncio +async def test_approval_gate_pending_is_pending() -> None: + engine = await _make_engine() + try: + async with await _make_session(engine) as session: + org_id = uuid4() + board_id = uuid4() + gateway_id = uuid4() + task_id = uuid4() + + session.add(Organization(id=org_id, name="org")) + session.add( + Gateway( + id=gateway_id, + organization_id=org_id, + name="gateway", + url="https://gateway.local", + workspace_root="/tmp/workspace", + ) + ) + session.add( + Board( + id=board_id, + organization_id=org_id, + name="board", + slug="board", + gateway_id=gateway_id, + ) + ) + field = TaskCustomFieldDefinition( + organization_id=org_id, + field_key="github_pr_url", + label="GitHub PR URL", + field_type="url", + ) + session.add(field) + session.add(Task(id=task_id, board_id=board_id, title="t", description="", status="inbox")) + await session.commit() + + session.add( + TaskCustomFieldValue( + task_id=task_id, + task_custom_field_definition_id=field.id, + value="https://github.com/acme/repo/pull/3", + ) + ) + approval = Approval( + board_id=board_id, + task_id=task_id, + action_type=sorted(REQUIRED_ACTION_TYPES)[0], + confidence=90, + status="pending", + ) + session.add(approval) + await session.commit() + + out = await evaluate_approval_gate_for_pr_url( + session, + board_id=board_id, + pr_url="https://github.com/acme/repo/pull/3", + ) + assert out.outcome == "pending" + finally: + await engine.dispose() + + +@pytest.mark.asyncio +async def test_approval_gate_approved_is_success() -> None: + engine = await _make_engine() + try: + async with await _make_session(engine) as session: + org_id = uuid4() + board_id = uuid4() + gateway_id = uuid4() + task_id = uuid4() + + session.add(Organization(id=org_id, name="org")) + session.add( + Gateway( + id=gateway_id, + organization_id=org_id, + name="gateway", + url="https://gateway.local", + workspace_root="/tmp/workspace", + ) + ) + session.add( + Board( + id=board_id, + organization_id=org_id, + name="board", + slug="board", + gateway_id=gateway_id, + ) + ) + field = TaskCustomFieldDefinition( + organization_id=org_id, + field_key="github_pr_url", + label="GitHub PR URL", + field_type="url", + ) + session.add(field) + session.add(Task(id=task_id, board_id=board_id, title="t", description="", status="review")) + await session.commit() + + session.add( + TaskCustomFieldValue( + task_id=task_id, + task_custom_field_definition_id=field.id, + value="https://github.com/acme/repo/pull/4", + ) + ) + approval = Approval( + board_id=board_id, + task_id=None, + action_type=sorted(REQUIRED_ACTION_TYPES)[0], + confidence=90, + status="approved", + ) + session.add(approval) + await session.commit() + + session.add(ApprovalTaskLink(approval_id=approval.id, task_id=task_id)) + await session.commit() + + out = await evaluate_approval_gate_for_pr_url( + session, + board_id=board_id, + pr_url="https://github.com/acme/repo/pull/4", + ) + assert out.outcome == "success" + finally: + await engine.dispose() + + +@pytest.mark.asyncio +async def test_approval_gate_rejected_is_rejected() -> None: + engine = await _make_engine() + try: + async with await _make_session(engine) as session: + org_id = uuid4() + board_id = uuid4() + gateway_id = uuid4() + task_id = uuid4() + + session.add(Organization(id=org_id, name="org")) + session.add( + Gateway( + id=gateway_id, + organization_id=org_id, + name="gateway", + url="https://gateway.local", + workspace_root="/tmp/workspace", + ) + ) + session.add( + Board( + id=board_id, + organization_id=org_id, + name="board", + slug="board", + gateway_id=gateway_id, + ) + ) + field = TaskCustomFieldDefinition( + organization_id=org_id, + field_key="github_pr_url", + label="GitHub PR URL", + field_type="url", + ) + session.add(field) + session.add(Task(id=task_id, board_id=board_id, title="t", description="", status="review")) + await session.commit() + + session.add( + TaskCustomFieldValue( + task_id=task_id, + task_custom_field_definition_id=field.id, + value="https://github.com/acme/repo/pull/5", + ) + ) + approval = Approval( + board_id=board_id, + task_id=task_id, + action_type=sorted(REQUIRED_ACTION_TYPES)[0], + confidence=90, + status="rejected", + ) + session.add(approval) + await session.commit() + + out = await evaluate_approval_gate_for_pr_url( + session, + board_id=board_id, + pr_url="https://github.com/acme/repo/pull/5", + ) + assert out.outcome == "rejected" + finally: + await engine.dispose() + + +@pytest.mark.asyncio +async def test_approval_gate_non_qualifying_action_type_is_missing() -> None: + engine = await _make_engine() + try: + async with await _make_session(engine) as session: + org_id = uuid4() + board_id = uuid4() + gateway_id = uuid4() + task_id = uuid4() + + session.add(Organization(id=org_id, name="org")) + session.add( + Gateway( + id=gateway_id, + organization_id=org_id, + name="gateway", + url="https://gateway.local", + workspace_root="/tmp/workspace", + ) + ) + session.add( + Board( + id=board_id, + organization_id=org_id, + name="board", + slug="board", + gateway_id=gateway_id, + ) + ) + field = TaskCustomFieldDefinition( + organization_id=org_id, + field_key="github_pr_url", + label="GitHub PR URL", + field_type="url", + ) + session.add(field) + session.add(Task(id=task_id, board_id=board_id, title="t", description="", status="review")) + await session.commit() + + session.add( + TaskCustomFieldValue( + task_id=task_id, + task_custom_field_definition_id=field.id, + value="https://github.com/acme/repo/pull/6", + ) + ) + # approval exists but wrong action_type + session.add( + Approval( + board_id=board_id, + task_id=task_id, + action_type="some_other_action", + confidence=50, + status="approved", + ) + ) + await session.commit() + + out = await evaluate_approval_gate_for_pr_url( + session, + board_id=board_id, + pr_url="https://github.com/acme/repo/pull/6", + ) + assert out.outcome == "missing" + finally: + await engine.dispose()