refactor: update module docstrings for clarity and consistency
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Business logic services for backend domain operations."""
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
"""Utilities for recording normalized activity events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.activity_events import ActivityEvent
|
||||
|
||||
@@ -15,6 +20,7 @@ def record_activity(
|
||||
agent_id: UUID | None = None,
|
||||
task_id: UUID | None = None,
|
||||
) -> ActivityEvent:
|
||||
"""Create and attach an activity event row to the current DB session."""
|
||||
event = ActivityEvent(
|
||||
event_type=event_type,
|
||||
message=message,
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
"""Access control helpers for admin-only operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.core.auth import AuthContext
|
||||
if TYPE_CHECKING:
|
||||
from app.core.auth import AuthContext
|
||||
|
||||
|
||||
def require_admin(auth: AuthContext) -> None:
|
||||
"""Raise HTTP 403 unless the authenticated actor is a user admin."""
|
||||
if auth.actor_type != "user" or auth.user is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
@@ -1,21 +1,31 @@
|
||||
"""Gateway-facing agent provisioning and cleanup helpers."""
|
||||
# ruff: noqa: EM101, TRY003
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape
|
||||
|
||||
from app.core.config import settings
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, openclaw_call
|
||||
from app.models.agents import Agent
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.users import User
|
||||
from app.integrations.openclaw_gateway import (
|
||||
OpenClawGatewayError,
|
||||
ensure_session,
|
||||
openclaw_call,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.agents import Agent
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.users import User
|
||||
|
||||
DEFAULT_HEARTBEAT_CONFIG = {"every": "10m", "target": "none"}
|
||||
DEFAULT_IDENTITY_PROFILE = {
|
||||
@@ -35,7 +45,8 @@ EXTRA_IDENTITY_PROFILE_FIELDS = {
|
||||
"verbosity": "identity_verbosity",
|
||||
"output_format": "identity_output_format",
|
||||
"update_cadence": "identity_update_cadence",
|
||||
# Per-agent charter (optional). Used to give agents a "purpose in life" and a distinct vibe.
|
||||
# Per-agent charter (optional).
|
||||
# Used to give agents a "purpose in life" and a distinct vibe.
|
||||
"purpose": "identity_purpose",
|
||||
"personality": "identity_personality",
|
||||
"custom_instructions": "identity_custom_instructions",
|
||||
@@ -54,11 +65,11 @@ DEFAULT_GATEWAY_FILES = frozenset(
|
||||
"BOOT.md",
|
||||
"BOOTSTRAP.md",
|
||||
"MEMORY.md",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# These files are intended to evolve within the agent workspace. Provision them if missing,
|
||||
# but avoid overwriting existing content during updates.
|
||||
# These files are intended to evolve within the agent workspace.
|
||||
# Provision them if missing, but avoid overwriting existing content during updates.
|
||||
#
|
||||
# Examples:
|
||||
# - SELF.md: evolving identity/preferences
|
||||
@@ -68,6 +79,7 @@ PRESERVE_AGENT_EDITABLE_FILES = frozenset({"SELF.md", "USER.md", "MEMORY.md"})
|
||||
|
||||
HEARTBEAT_LEAD_TEMPLATE = "HEARTBEAT_LEAD.md"
|
||||
HEARTBEAT_AGENT_TEMPLATE = "HEARTBEAT_AGENT.md"
|
||||
_SESSION_KEY_PARTS_MIN = 2
|
||||
MAIN_TEMPLATE_MAP = {
|
||||
"AGENTS.md": "MAIN_AGENTS.md",
|
||||
"HEARTBEAT.md": "MAIN_HEARTBEAT.md",
|
||||
@@ -97,13 +109,13 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
|
||||
if not value.startswith("agent:"):
|
||||
return None
|
||||
parts = value.split(":")
|
||||
if len(parts) < 2:
|
||||
if len(parts) < _SESSION_KEY_PARTS_MIN:
|
||||
return None
|
||||
agent_id = parts[1].strip()
|
||||
return agent_id or None
|
||||
|
||||
|
||||
def _extract_agent_id(payload: object) -> str | None:
|
||||
def _extract_agent_id(payload: object) -> str | None: # noqa: C901
|
||||
def _from_list(items: object) -> str | None:
|
||||
if not isinstance(items, list):
|
||||
return None
|
||||
@@ -137,7 +149,7 @@ def _agent_key(agent: Agent) -> str:
|
||||
session_key = agent.openclaw_session_id or ""
|
||||
if session_key.startswith("agent:"):
|
||||
parts = session_key.split(":")
|
||||
if len(parts) >= 2 and parts[1]:
|
||||
if len(parts) >= _SESSION_KEY_PARTS_MIN and parts[1]:
|
||||
return parts[1]
|
||||
return _slugify(agent.name)
|
||||
|
||||
@@ -183,14 +195,14 @@ def _ensure_workspace_file(
|
||||
if not workspace_path or not name:
|
||||
return
|
||||
# Only write to a dedicated, explicitly-configured local directory.
|
||||
# Using `gateway.workspace_root` directly here is unsafe (and CodeQL correctly flags it)
|
||||
# because it is a DB-backed config value.
|
||||
# Using `gateway.workspace_root` directly here is unsafe.
|
||||
# CodeQL correctly flags that value because it is DB-backed config.
|
||||
base_root = (settings.local_agent_workspace_root or "").strip()
|
||||
if not base_root:
|
||||
return
|
||||
base = Path(base_root).expanduser()
|
||||
|
||||
# Derive a stable, safe directory name from the (potentially untrusted) workspace path.
|
||||
# Derive a stable, safe directory name from the untrusted workspace path.
|
||||
# This prevents path traversal and avoids writing to arbitrary locations.
|
||||
digest = hashlib.sha256(workspace_path.encode("utf-8")).hexdigest()[:16]
|
||||
root = base / f"gateway-workspace-{digest}"
|
||||
@@ -345,12 +357,14 @@ async def _supported_gateway_files(config: GatewayClientConfig) -> set[str]:
|
||||
default_id = None
|
||||
if isinstance(agents_payload, dict):
|
||||
agents = list(agents_payload.get("agents") or [])
|
||||
default_id = agents_payload.get("defaultId") or agents_payload.get("default_id")
|
||||
default_id = agents_payload.get("defaultId") or agents_payload.get(
|
||||
"default_id",
|
||||
)
|
||||
agent_id = default_id or (agents[0].get("id") if agents else None)
|
||||
if not agent_id:
|
||||
return set(DEFAULT_GATEWAY_FILES)
|
||||
files_payload = await openclaw_call(
|
||||
"agents.files.list", {"agentId": agent_id}, config=config
|
||||
"agents.files.list", {"agentId": agent_id}, config=config,
|
||||
)
|
||||
if isinstance(files_payload, dict):
|
||||
files = files_payload.get("files") or []
|
||||
@@ -374,10 +388,12 @@ async def _reset_session(session_key: str, config: GatewayClientConfig) -> None:
|
||||
|
||||
|
||||
async def _gateway_agent_files_index(
|
||||
agent_id: str, config: GatewayClientConfig
|
||||
agent_id: str, config: GatewayClientConfig,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
try:
|
||||
payload = await openclaw_call("agents.files.list", {"agentId": agent_id}, config=config)
|
||||
payload = await openclaw_call(
|
||||
"agents.files.list", {"agentId": agent_id}, config=config,
|
||||
)
|
||||
if isinstance(payload, dict):
|
||||
files = payload.get("files") or []
|
||||
index: dict[str, dict[str, Any]] = {}
|
||||
@@ -420,21 +436,25 @@ def _render_agent_files(
|
||||
)
|
||||
heartbeat_path = _templates_root() / heartbeat_template
|
||||
if heartbeat_path.exists():
|
||||
rendered[name] = env.get_template(heartbeat_template).render(**context).strip()
|
||||
rendered[name] = (
|
||||
env.get_template(heartbeat_template).render(**context).strip()
|
||||
)
|
||||
continue
|
||||
override = overrides.get(name)
|
||||
if override:
|
||||
rendered[name] = env.from_string(override).render(**context).strip()
|
||||
continue
|
||||
template_name = (
|
||||
template_overrides[name] if template_overrides and name in template_overrides else name
|
||||
template_overrides[name]
|
||||
if template_overrides and name in template_overrides
|
||||
else name
|
||||
)
|
||||
path = _templates_root() / template_name
|
||||
if path.exists():
|
||||
rendered[name] = env.get_template(template_name).render(**context).strip()
|
||||
continue
|
||||
if name == "MEMORY.md":
|
||||
# Back-compat fallback for existing gateways that don't ship a MEMORY.md template.
|
||||
# Back-compat fallback for gateways that do not ship MEMORY.md.
|
||||
rendered[name] = "# MEMORY\n\nBootstrap pending.\n"
|
||||
continue
|
||||
rendered[name] = ""
|
||||
@@ -487,7 +507,9 @@ async def _patch_gateway_agent_list(
|
||||
else:
|
||||
new_list.append(entry)
|
||||
if not updated:
|
||||
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat})
|
||||
new_list.append(
|
||||
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
|
||||
)
|
||||
|
||||
patch = {"agents": {"list": new_list}}
|
||||
params = {"raw": json.dumps(patch)}
|
||||
@@ -496,7 +518,7 @@ async def _patch_gateway_agent_list(
|
||||
await openclaw_call("config.patch", params, config=config)
|
||||
|
||||
|
||||
async def patch_gateway_agent_heartbeats(
|
||||
async def patch_gateway_agent_heartbeats( # noqa: C901
|
||||
gateway: Gateway,
|
||||
*,
|
||||
entries: list[tuple[str, str, dict[str, Any]]],
|
||||
@@ -521,7 +543,8 @@ async def patch_gateway_agent_heartbeats(
|
||||
raise OpenClawGatewayError("config agents.list is not a list")
|
||||
|
||||
entry_by_id: dict[str, tuple[str, dict[str, Any]]] = {
|
||||
agent_id: (workspace_path, heartbeat) for agent_id, workspace_path, heartbeat in entries
|
||||
agent_id: (workspace_path, heartbeat)
|
||||
for agent_id, workspace_path, heartbeat in entries
|
||||
}
|
||||
|
||||
updated_ids: set[str] = set()
|
||||
@@ -544,7 +567,9 @@ async def patch_gateway_agent_heartbeats(
|
||||
for agent_id, (workspace_path, heartbeat) in entry_by_id.items():
|
||||
if agent_id in updated_ids:
|
||||
continue
|
||||
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat})
|
||||
new_list.append(
|
||||
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
|
||||
)
|
||||
|
||||
patch = {"agents": {"list": new_list}}
|
||||
params = {"raw": json.dumps(patch)}
|
||||
@@ -585,7 +610,9 @@ async def _remove_gateway_agent_list(
|
||||
raise OpenClawGatewayError("config agents.list is not a list")
|
||||
|
||||
new_list = [
|
||||
entry for entry in lst if not (isinstance(entry, dict) and entry.get("id") == agent_id)
|
||||
entry
|
||||
for entry in lst
|
||||
if not (isinstance(entry, dict) and entry.get("id") == agent_id)
|
||||
]
|
||||
if len(new_list) == len(lst):
|
||||
return
|
||||
@@ -616,7 +643,7 @@ async def _get_gateway_agent_entry(
|
||||
return None
|
||||
|
||||
|
||||
async def provision_agent(
|
||||
async def provision_agent( # noqa: C901, PLR0912, PLR0913
|
||||
agent: Agent,
|
||||
board: Board,
|
||||
gateway: Gateway,
|
||||
@@ -627,6 +654,7 @@ async def provision_agent(
|
||||
force_bootstrap: bool = False,
|
||||
reset_session: bool = False,
|
||||
) -> None:
|
||||
"""Provision or update a regular board agent workspace."""
|
||||
if not gateway.url:
|
||||
return
|
||||
if not gateway.workspace_root:
|
||||
@@ -665,11 +693,9 @@ async def provision_agent(
|
||||
content = rendered.get(name)
|
||||
if not content:
|
||||
continue
|
||||
try:
|
||||
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
|
||||
except OSError:
|
||||
with suppress(OSError):
|
||||
# Local workspace may not be writable/available; fall back to gateway API.
|
||||
pass
|
||||
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
|
||||
for name, content in rendered.items():
|
||||
if content == "":
|
||||
continue
|
||||
@@ -694,7 +720,7 @@ async def provision_agent(
|
||||
await _reset_session(session_key, client_config)
|
||||
|
||||
|
||||
async def provision_main_agent(
|
||||
async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
|
||||
agent: Agent,
|
||||
gateway: Gateway,
|
||||
auth_token: str,
|
||||
@@ -704,12 +730,15 @@ async def provision_main_agent(
|
||||
force_bootstrap: bool = False,
|
||||
reset_session: bool = False,
|
||||
) -> None:
|
||||
"""Provision or update the gateway main agent workspace."""
|
||||
if not gateway.url:
|
||||
return
|
||||
if not gateway.main_session_key:
|
||||
raise ValueError("gateway main_session_key is required")
|
||||
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent")
|
||||
await ensure_session(
|
||||
gateway.main_session_key, config=client_config, label="Main Agent",
|
||||
)
|
||||
|
||||
agent_id = await _gateway_default_agent_id(
|
||||
client_config,
|
||||
@@ -763,6 +792,7 @@ async def cleanup_agent(
|
||||
agent: Agent,
|
||||
gateway: Gateway,
|
||||
) -> str | None:
|
||||
"""Remove an agent from gateway config and delete its session."""
|
||||
if not gateway.url:
|
||||
return None
|
||||
if not gateway.workspace_root:
|
||||
|
||||
@@ -1,30 +1,41 @@
|
||||
"""Helpers for ensuring each board has a provisioned lead agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
||||
from app.core.time import utcnow
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||
from app.integrations.openclaw_gateway import (
|
||||
OpenClawGatewayError,
|
||||
ensure_session,
|
||||
send_message,
|
||||
)
|
||||
from app.models.agents import Agent
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.users import User
|
||||
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.users import User
|
||||
|
||||
|
||||
def lead_session_key(board: Board) -> str:
|
||||
"""Return the deterministic main session key for a board lead agent."""
|
||||
return f"agent:lead-{board.id}:main"
|
||||
|
||||
|
||||
def lead_agent_name(_: Board) -> str:
|
||||
"""Return the default display name for board lead agents."""
|
||||
return "Lead Agent"
|
||||
|
||||
|
||||
async def ensure_board_lead_agent(
|
||||
async def ensure_board_lead_agent( # noqa: PLR0913
|
||||
session: AsyncSession,
|
||||
*,
|
||||
board: Board,
|
||||
@@ -35,11 +46,12 @@ async def ensure_board_lead_agent(
|
||||
identity_profile: dict[str, str] | None = None,
|
||||
action: str = "provision",
|
||||
) -> tuple[Agent, bool]:
|
||||
"""Ensure a board has a lead agent; return `(agent, created)`."""
|
||||
existing = (
|
||||
await session.exec(
|
||||
select(Agent)
|
||||
.where(Agent.board_id == board.id)
|
||||
.where(col(Agent.is_board_lead).is_(True))
|
||||
.where(col(Agent.is_board_lead).is_(True)),
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
@@ -66,7 +78,11 @@ async def ensure_board_lead_agent(
|
||||
}
|
||||
if identity_profile:
|
||||
merged_identity_profile.update(
|
||||
{key: value.strip() for key, value in identity_profile.items() if value.strip()}
|
||||
{
|
||||
key: value.strip()
|
||||
for key, value in identity_profile.items()
|
||||
if value.strip()
|
||||
},
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
@@ -89,11 +105,16 @@ async def ensure_board_lead_agent(
|
||||
try:
|
||||
await provision_agent(agent, board, gateway, raw_token, user, action=action)
|
||||
if agent.openclaw_session_id:
|
||||
await ensure_session(agent.openclaw_session_id, config=config, label=agent.name)
|
||||
await ensure_session(
|
||||
agent.openclaw_session_id,
|
||||
config=config,
|
||||
label=agent.name,
|
||||
)
|
||||
await send_message(
|
||||
(
|
||||
f"Hello {agent.name}. Your workspace has been provisioned.\n\n"
|
||||
"Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run it once "
|
||||
"Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run "
|
||||
"it once "
|
||||
"then delete it. Begin heartbeats after startup."
|
||||
),
|
||||
session_key=agent.openclaw_session_id,
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
"""Helpers for assembling denormalized board snapshot response payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.time import utcnow
|
||||
from app.models.agents import Agent
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_memory import BoardMemory
|
||||
from app.models.boards import Board
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.tasks import Task
|
||||
from app.schemas.agents import AgentRead
|
||||
@@ -25,6 +25,13 @@ from app.services.task_dependencies import (
|
||||
dependency_status_by_id,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.boards import Board
|
||||
|
||||
OFFLINE_AFTER = timedelta(minutes=10)
|
||||
|
||||
|
||||
@@ -48,9 +55,15 @@ def _agent_to_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
|
||||
model = AgentRead.model_validate(agent, from_attributes=True)
|
||||
computed_status = _computed_agent_status(agent)
|
||||
is_gateway_main = bool(
|
||||
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys
|
||||
agent.openclaw_session_id
|
||||
and agent.openclaw_session_id in main_session_keys,
|
||||
)
|
||||
return model.model_copy(
|
||||
update={
|
||||
"status": computed_status,
|
||||
"is_gateway_main": is_gateway_main,
|
||||
},
|
||||
)
|
||||
return model.model_copy(update={"status": computed_status, "is_gateway_main": is_gateway_main})
|
||||
|
||||
|
||||
def _memory_to_read(memory: BoardMemory) -> BoardMemoryRead:
|
||||
@@ -72,7 +85,9 @@ def _task_to_card(
|
||||
card = TaskCardRead.model_validate(task, from_attributes=True)
|
||||
approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0))
|
||||
assignee = (
|
||||
agent_name_by_id.get(task.assigned_agent_id) if task.assigned_agent_id is not None else None
|
||||
agent_name_by_id.get(task.assigned_agent_id)
|
||||
if task.assigned_agent_id
|
||||
else None
|
||||
)
|
||||
depends_on_task_ids = deps_by_task_id.get(task.id, [])
|
||||
blocked_by_task_ids = blocked_by_dependency_ids(
|
||||
@@ -89,21 +104,26 @@ def _task_to_card(
|
||||
"depends_on_task_ids": depends_on_task_ids,
|
||||
"blocked_by_task_ids": blocked_by_task_ids,
|
||||
"is_blocked": bool(blocked_by_task_ids),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnapshot:
|
||||
"""Build a board snapshot with tasks, agents, approvals, and chat history."""
|
||||
board_read = BoardRead.model_validate(board, from_attributes=True)
|
||||
|
||||
tasks = list(
|
||||
await Task.objects.filter_by(board_id=board.id)
|
||||
.order_by(col(Task.created_at).desc())
|
||||
.all(session)
|
||||
.all(session),
|
||||
)
|
||||
task_ids = [task.id for task in tasks]
|
||||
|
||||
deps_by_task_id = await dependency_ids_by_task_id(session, board_id=board.id, task_ids=task_ids)
|
||||
deps_by_task_id = await dependency_ids_by_task_id(
|
||||
session,
|
||||
board_id=board.id,
|
||||
task_ids=task_ids,
|
||||
)
|
||||
all_dependency_ids: list[UUID] = []
|
||||
for values in deps_by_task_id.values():
|
||||
all_dependency_ids.extend(values)
|
||||
@@ -127,9 +147,9 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
|
||||
await session.exec(
|
||||
select(func.count(col(Approval.id)))
|
||||
.where(col(Approval.board_id) == board.id)
|
||||
.where(col(Approval.status) == "pending")
|
||||
)
|
||||
).one()
|
||||
.where(col(Approval.status) == "pending"),
|
||||
),
|
||||
).one(),
|
||||
)
|
||||
|
||||
approvals = (
|
||||
@@ -146,12 +166,14 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
|
||||
select(
|
||||
col(Approval.task_id),
|
||||
func.count(col(Approval.id)).label("total"),
|
||||
func.sum(case((col(Approval.status) == "pending", 1), else_=0)).label("pending"),
|
||||
func.sum(
|
||||
case((col(Approval.status) == "pending", 1), else_=0),
|
||||
).label("pending"),
|
||||
)
|
||||
.where(col(Approval.board_id) == board.id)
|
||||
.where(col(Approval.task_id).is_not(None))
|
||||
.group_by(col(Approval.task_id))
|
||||
)
|
||||
.group_by(col(Approval.task_id)),
|
||||
),
|
||||
)
|
||||
for task_id, total, pending in rows:
|
||||
if task_id is None:
|
||||
|
||||
@@ -1,26 +1,33 @@
|
||||
"""Policy helpers for lead-agent approval and planning decisions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Mapping
|
||||
|
||||
CONFIDENCE_THRESHOLD = 80
|
||||
MIN_PLANNING_SIGNALS = 2
|
||||
|
||||
|
||||
def compute_confidence(rubric_scores: Mapping[str, int]) -> int:
|
||||
"""Compute aggregate confidence from rubric score components."""
|
||||
return int(sum(rubric_scores.values()))
|
||||
|
||||
|
||||
def approval_required(*, confidence: int, is_external: bool, is_risky: bool) -> bool:
|
||||
"""Return whether an action must go through explicit approval."""
|
||||
return is_external or is_risky or confidence < CONFIDENCE_THRESHOLD
|
||||
|
||||
|
||||
def infer_planning(signals: Mapping[str, bool]) -> bool:
|
||||
"""Infer planning intent from boolean heuristic signals."""
|
||||
# Require at least two planning signals to avoid spam on general boards.
|
||||
truthy = [key for key, value in signals.items() if value]
|
||||
return len(truthy) >= 2
|
||||
return len(truthy) >= MIN_PLANNING_SIGNALS
|
||||
|
||||
|
||||
def task_fingerprint(title: str, description: str | None, board_id: str) -> str:
|
||||
"""Build a stable hash key for deduplicating similar board tasks."""
|
||||
normalized_title = title.strip().lower()
|
||||
normalized_desc = (description or "").strip().lower()
|
||||
seed = f"{board_id}::{normalized_title}::{normalized_desc}"
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
"""Helpers for extracting and matching `@mention` tokens in text."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.agents import Agent
|
||||
if TYPE_CHECKING:
|
||||
from app.models.agents import Agent
|
||||
|
||||
# Mention tokens are single, space-free words (e.g. "@alex", "@lead").
|
||||
MENTION_PATTERN = re.compile(r"@([A-Za-z][\w-]{0,31})")
|
||||
|
||||
|
||||
def extract_mentions(message: str) -> set[str]:
|
||||
"""Extract normalized mention handles from a message body."""
|
||||
return {match.group(1).lower() for match in MENTION_PATTERN.finditer(message)}
|
||||
|
||||
|
||||
def matches_agent_mention(agent: Agent, mentions: set[str]) -> bool:
|
||||
"""Return whether a mention set targets the provided agent."""
|
||||
if not mentions:
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""Organization membership and board-access service helpers."""
|
||||
# ruff: noqa: D101, D103
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.time import utcnow
|
||||
from app.db import crud
|
||||
@@ -19,7 +19,17 @@ from app.models.organization_invites import OrganizationInvite
|
||||
from app.models.organization_members import OrganizationMember
|
||||
from app.models.organizations import Organization
|
||||
from app.models.users import User
|
||||
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.schemas.organizations import (
|
||||
OrganizationBoardAccessSpec,
|
||||
OrganizationMemberAccessUpdate,
|
||||
)
|
||||
|
||||
DEFAULT_ORG_NAME = "Personal"
|
||||
ADMIN_ROLES = {"owner", "admin"}
|
||||
@@ -63,7 +73,9 @@ async def get_member(
|
||||
).first(session)
|
||||
|
||||
|
||||
async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None:
|
||||
async def get_first_membership(
|
||||
session: AsyncSession, user_id: UUID,
|
||||
) -> OrganizationMember | None:
|
||||
return (
|
||||
await OrganizationMember.objects.filter_by(user_id=user_id)
|
||||
.order_by(col(OrganizationMember.created_at).asc())
|
||||
@@ -79,7 +91,9 @@ async def set_active_organization(
|
||||
) -> OrganizationMember:
|
||||
member = await get_member(session, user_id=user.id, organization_id=organization_id)
|
||||
if member is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
|
||||
)
|
||||
if user.active_organization_id != organization_id:
|
||||
user.active_organization_id = organization_id
|
||||
session.add(user)
|
||||
@@ -154,9 +168,10 @@ async def accept_invite(
|
||||
access_rows = list(
|
||||
await session.exec(
|
||||
select(OrganizationInviteBoardAccess).where(
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
|
||||
)
|
||||
)
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id)
|
||||
== invite.id,
|
||||
),
|
||||
),
|
||||
)
|
||||
for row in access_rows:
|
||||
session.add(
|
||||
@@ -167,7 +182,7 @@ async def accept_invite(
|
||||
can_write=row.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
invite.accepted_by_user_id = user.id
|
||||
@@ -182,7 +197,9 @@ async def accept_invite(
|
||||
return member
|
||||
|
||||
|
||||
async def ensure_member_for_user(session: AsyncSession, user: User) -> OrganizationMember:
|
||||
async def ensure_member_for_user(
|
||||
session: AsyncSession, user: User,
|
||||
) -> OrganizationMember:
|
||||
existing = await get_active_membership(session, user)
|
||||
if existing is not None:
|
||||
return existing
|
||||
@@ -196,7 +213,9 @@ async def ensure_member_for_user(session: AsyncSession, user: User) -> Organizat
|
||||
now = utcnow()
|
||||
member_count = (
|
||||
await session.exec(
|
||||
select(func.count()).where(col(OrganizationMember.organization_id) == org.id)
|
||||
select(func.count()).where(
|
||||
col(OrganizationMember.organization_id) == org.id,
|
||||
),
|
||||
)
|
||||
).one()
|
||||
is_first = int(member_count or 0) == 0
|
||||
@@ -257,30 +276,40 @@ async def require_board_access(
|
||||
board: Board,
|
||||
write: bool,
|
||||
) -> OrganizationMember:
|
||||
member = await get_member(session, user_id=user.id, organization_id=board.organization_id)
|
||||
member = await get_member(
|
||||
session, user_id=user.id, organization_id=board.organization_id,
|
||||
)
|
||||
if member is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
|
||||
)
|
||||
if not await has_board_access(session, member=member, board=board, write=write):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied",
|
||||
)
|
||||
return member
|
||||
|
||||
|
||||
def board_access_filter(member: OrganizationMember, *, write: bool) -> ColumnElement[bool]:
|
||||
def board_access_filter(
|
||||
member: OrganizationMember, *, write: bool,
|
||||
) -> ColumnElement[bool]:
|
||||
if write and member_all_boards_write(member):
|
||||
return col(Board.organization_id) == member.organization_id
|
||||
if not write and member_all_boards_read(member):
|
||||
return col(Board.organization_id) == member.organization_id
|
||||
access_stmt = select(OrganizationBoardAccess.board_id).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||
)
|
||||
if write:
|
||||
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
|
||||
access_stmt = access_stmt.where(
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
else:
|
||||
access_stmt = access_stmt.where(
|
||||
or_(
|
||||
col(OrganizationBoardAccess.can_read).is_(True),
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
),
|
||||
)
|
||||
return col(Board.id).in_(access_stmt)
|
||||
|
||||
@@ -295,21 +324,25 @@ async def list_accessible_board_ids(
|
||||
not write and member_all_boards_read(member)
|
||||
):
|
||||
ids = await session.exec(
|
||||
select(Board.id).where(col(Board.organization_id) == member.organization_id)
|
||||
select(Board.id).where(
|
||||
col(Board.organization_id) == member.organization_id,
|
||||
),
|
||||
)
|
||||
return list(ids)
|
||||
|
||||
access_stmt = select(OrganizationBoardAccess.board_id).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||
)
|
||||
if write:
|
||||
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
|
||||
access_stmt = access_stmt.where(
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
else:
|
||||
access_stmt = access_stmt.where(
|
||||
or_(
|
||||
col(OrganizationBoardAccess.can_read).is_(True),
|
||||
col(OrganizationBoardAccess.can_write).is_(True),
|
||||
)
|
||||
),
|
||||
)
|
||||
board_ids = await session.exec(access_stmt)
|
||||
return list(board_ids)
|
||||
@@ -337,18 +370,17 @@ async def apply_member_access_update(
|
||||
if update.all_boards_read or update.all_boards_write:
|
||||
return
|
||||
|
||||
rows: list[OrganizationBoardAccess] = []
|
||||
for entry in update.board_access:
|
||||
rows.append(
|
||||
OrganizationBoardAccess(
|
||||
organization_member_id=member.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
rows = [
|
||||
OrganizationBoardAccess(
|
||||
organization_member_id=member.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
for entry in update.board_access
|
||||
]
|
||||
session.add_all(rows)
|
||||
|
||||
|
||||
@@ -367,18 +399,17 @@ async def apply_invite_board_access(
|
||||
if invite.all_boards_read or invite.all_boards_write:
|
||||
return
|
||||
now = utcnow()
|
||||
rows: list[OrganizationInviteBoardAccess] = []
|
||||
for entry in entries:
|
||||
rows.append(
|
||||
OrganizationInviteBoardAccess(
|
||||
organization_invite_id=invite.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
rows = [
|
||||
OrganizationInviteBoardAccess(
|
||||
organization_invite_id=invite.id,
|
||||
board_id=entry.board_id,
|
||||
can_read=entry.can_read,
|
||||
can_write=entry.can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
for entry in entries
|
||||
]
|
||||
session.add_all(rows)
|
||||
|
||||
|
||||
@@ -423,9 +454,9 @@ async def apply_invite_to_member(
|
||||
access_rows = list(
|
||||
await session.exec(
|
||||
select(OrganizationInviteBoardAccess).where(
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
|
||||
)
|
||||
)
|
||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
|
||||
),
|
||||
),
|
||||
)
|
||||
for row in access_rows:
|
||||
existing = (
|
||||
@@ -433,7 +464,7 @@ async def apply_invite_to_member(
|
||||
select(OrganizationBoardAccess).where(
|
||||
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||
col(OrganizationBoardAccess.board_id) == row.board_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
).first()
|
||||
can_write = bool(row.can_write)
|
||||
@@ -447,7 +478,7 @@ async def apply_invite_to_member(
|
||||
can_write=can_write,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
existing.can_read = bool(existing.can_read or can_read)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Service helpers for querying and caching souls.directory content."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
@@ -11,33 +13,41 @@ SOULS_DIRECTORY_BASE_URL: Final[str] = "https://souls.directory"
|
||||
SOULS_DIRECTORY_SITEMAP_URL: Final[str] = f"{SOULS_DIRECTORY_BASE_URL}/sitemap.xml"
|
||||
|
||||
_SITEMAP_TTL_SECONDS: Final[int] = 60 * 60
|
||||
_SOUL_URL_MIN_PARTS: Final[int] = 6
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SoulRef:
|
||||
"""Handle/slug reference pair for a soul entry."""
|
||||
|
||||
handle: str
|
||||
slug: str
|
||||
|
||||
@property
|
||||
def page_url(self) -> str:
|
||||
"""Return the canonical page URL for this soul."""
|
||||
return f"{SOULS_DIRECTORY_BASE_URL}/souls/{self.handle}/{self.slug}"
|
||||
|
||||
@property
|
||||
def raw_md_url(self) -> str:
|
||||
"""Return the raw markdown URL for this soul."""
|
||||
return f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{self.handle}/{self.slug}.md"
|
||||
|
||||
|
||||
def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
|
||||
"""Parse sitemap XML and extract valid souls.directory handle/slug refs."""
|
||||
try:
|
||||
root = ET.fromstring(sitemap_xml)
|
||||
# Souls sitemap is fetched from a known trusted host in this service flow.
|
||||
root = ET.fromstring(sitemap_xml) # noqa: S314
|
||||
except ET.ParseError:
|
||||
return []
|
||||
|
||||
# Handle both namespaced and non-namespaced sitemap XML.
|
||||
urls: list[str] = []
|
||||
for loc in root.iter():
|
||||
if loc.tag.endswith("loc") and loc.text:
|
||||
urls.append(loc.text.strip())
|
||||
urls = [
|
||||
loc.text.strip()
|
||||
for loc in root.iter()
|
||||
if loc.tag.endswith("loc") and loc.text
|
||||
]
|
||||
|
||||
refs: list[SoulRef] = []
|
||||
for url in urls:
|
||||
@@ -45,7 +55,7 @@ def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
|
||||
continue
|
||||
# Expected: https://souls.directory/souls/{handle}/{slug}
|
||||
parts = url.split("/")
|
||||
if len(parts) < 6:
|
||||
if len(parts) < _SOUL_URL_MIN_PARTS:
|
||||
continue
|
||||
handle = parts[4].strip()
|
||||
slug = parts[5].strip()
|
||||
@@ -61,7 +71,11 @@ _sitemap_cache: dict[str, object] = {
|
||||
}
|
||||
|
||||
|
||||
async def list_souls_directory_refs(*, client: httpx.AsyncClient | None = None) -> list[SoulRef]:
|
||||
async def list_souls_directory_refs(
|
||||
*,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
) -> list[SoulRef]:
|
||||
"""Return cached sitemap-derived soul refs, refreshing when TTL expires."""
|
||||
now = time.time()
|
||||
loaded_raw = _sitemap_cache.get("loaded_at")
|
||||
loaded_at = loaded_raw if isinstance(loaded_raw, (int, float)) else 0.0
|
||||
@@ -93,11 +107,15 @@ async def fetch_soul_markdown(
|
||||
slug: str,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
) -> str:
|
||||
"""Fetch raw markdown content for a specific handle/slug pair."""
|
||||
normalized_handle = handle.strip().strip("/")
|
||||
normalized_slug = slug.strip().strip("/")
|
||||
if normalized_slug.endswith(".md"):
|
||||
normalized_slug = normalized_slug[: -len(".md")]
|
||||
url = f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{normalized_handle}/{normalized_slug}.md"
|
||||
url = (
|
||||
f"{SOULS_DIRECTORY_BASE_URL}/api/souls/"
|
||||
f"{normalized_handle}/{normalized_slug}.md"
|
||||
)
|
||||
|
||||
owns_client = client is None
|
||||
if client is None:
|
||||
@@ -115,6 +133,7 @@ async def fetch_soul_markdown(
|
||||
|
||||
|
||||
def search_souls(refs: list[SoulRef], *, query: str, limit: int = 20) -> list[SoulRef]:
|
||||
"""Search refs by case-insensitive handle/slug substring with a hard limit."""
|
||||
q = query.strip().lower()
|
||||
if not q:
|
||||
return refs[: max(0, min(limit, len(refs)))]
|
||||
|
||||
Reference in New Issue
Block a user