refactor: update module docstrings for clarity and consistency

This commit is contained in:
Abhimanyu Saharan
2026-02-09 15:49:50 +05:30
parent 78bb08d4a3
commit 7ca1899d9f
99 changed files with 2345 additions and 855 deletions

View File

@@ -0,0 +1 @@
"""Business logic services for backend domain operations."""

View File

@@ -1,8 +1,13 @@
"""Utilities for recording normalized activity events."""
from __future__ import annotations
from uuid import UUID
from typing import TYPE_CHECKING
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.activity_events import ActivityEvent
@@ -15,6 +20,7 @@ def record_activity(
agent_id: UUID | None = None,
task_id: UUID | None = None,
) -> ActivityEvent:
"""Create and attach an activity event row to the current DB session."""
event = ActivityEvent(
event_type=event_type,
message=message,

View File

@@ -1,10 +1,16 @@
"""Access control helpers for admin-only operations."""
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import HTTPException, status
from app.core.auth import AuthContext
if TYPE_CHECKING:
from app.core.auth import AuthContext
def require_admin(auth: AuthContext) -> None:
"""Raise HTTP 403 unless the authenticated actor is a user admin."""
if auth.actor_type != "user" or auth.user is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

View File

@@ -1,21 +1,31 @@
"""Gateway-facing agent provisioning and cleanup helpers."""
# ruff: noqa: EM101, TRY003
from __future__ import annotations
import hashlib
import json
import re
from contextlib import suppress
from pathlib import Path
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4
from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape
from app.core.config import settings
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, openclaw_call
from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
openclaw_call,
)
if TYPE_CHECKING:
from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
DEFAULT_HEARTBEAT_CONFIG = {"every": "10m", "target": "none"}
DEFAULT_IDENTITY_PROFILE = {
@@ -35,7 +45,8 @@ EXTRA_IDENTITY_PROFILE_FIELDS = {
"verbosity": "identity_verbosity",
"output_format": "identity_output_format",
"update_cadence": "identity_update_cadence",
# Per-agent charter (optional). Used to give agents a "purpose in life" and a distinct vibe.
# Per-agent charter (optional).
# Used to give agents a "purpose in life" and a distinct vibe.
"purpose": "identity_purpose",
"personality": "identity_personality",
"custom_instructions": "identity_custom_instructions",
@@ -54,11 +65,11 @@ DEFAULT_GATEWAY_FILES = frozenset(
"BOOT.md",
"BOOTSTRAP.md",
"MEMORY.md",
}
},
)
# These files are intended to evolve within the agent workspace. Provision them if missing,
# but avoid overwriting existing content during updates.
# These files are intended to evolve within the agent workspace.
# Provision them if missing, but avoid overwriting existing content during updates.
#
# Examples:
# - SELF.md: evolving identity/preferences
@@ -68,6 +79,7 @@ PRESERVE_AGENT_EDITABLE_FILES = frozenset({"SELF.md", "USER.md", "MEMORY.md"})
HEARTBEAT_LEAD_TEMPLATE = "HEARTBEAT_LEAD.md"
HEARTBEAT_AGENT_TEMPLATE = "HEARTBEAT_AGENT.md"
_SESSION_KEY_PARTS_MIN = 2
MAIN_TEMPLATE_MAP = {
"AGENTS.md": "MAIN_AGENTS.md",
"HEARTBEAT.md": "MAIN_HEARTBEAT.md",
@@ -97,13 +109,13 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
if not value.startswith("agent:"):
return None
parts = value.split(":")
if len(parts) < 2:
if len(parts) < _SESSION_KEY_PARTS_MIN:
return None
agent_id = parts[1].strip()
return agent_id or None
def _extract_agent_id(payload: object) -> str | None:
def _extract_agent_id(payload: object) -> str | None: # noqa: C901
def _from_list(items: object) -> str | None:
if not isinstance(items, list):
return None
@@ -137,7 +149,7 @@ def _agent_key(agent: Agent) -> str:
session_key = agent.openclaw_session_id or ""
if session_key.startswith("agent:"):
parts = session_key.split(":")
if len(parts) >= 2 and parts[1]:
if len(parts) >= _SESSION_KEY_PARTS_MIN and parts[1]:
return parts[1]
return _slugify(agent.name)
@@ -183,14 +195,14 @@ def _ensure_workspace_file(
if not workspace_path or not name:
return
# Only write to a dedicated, explicitly-configured local directory.
# Using `gateway.workspace_root` directly here is unsafe (and CodeQL correctly flags it)
# because it is a DB-backed config value.
# Using `gateway.workspace_root` directly here is unsafe.
# CodeQL correctly flags that value because it is DB-backed config.
base_root = (settings.local_agent_workspace_root or "").strip()
if not base_root:
return
base = Path(base_root).expanduser()
# Derive a stable, safe directory name from the (potentially untrusted) workspace path.
# Derive a stable, safe directory name from the untrusted workspace path.
# This prevents path traversal and avoids writing to arbitrary locations.
digest = hashlib.sha256(workspace_path.encode("utf-8")).hexdigest()[:16]
root = base / f"gateway-workspace-{digest}"
@@ -345,12 +357,14 @@ async def _supported_gateway_files(config: GatewayClientConfig) -> set[str]:
default_id = None
if isinstance(agents_payload, dict):
agents = list(agents_payload.get("agents") or [])
default_id = agents_payload.get("defaultId") or agents_payload.get("default_id")
default_id = agents_payload.get("defaultId") or agents_payload.get(
"default_id",
)
agent_id = default_id or (agents[0].get("id") if agents else None)
if not agent_id:
return set(DEFAULT_GATEWAY_FILES)
files_payload = await openclaw_call(
"agents.files.list", {"agentId": agent_id}, config=config
"agents.files.list", {"agentId": agent_id}, config=config,
)
if isinstance(files_payload, dict):
files = files_payload.get("files") or []
@@ -374,10 +388,12 @@ async def _reset_session(session_key: str, config: GatewayClientConfig) -> None:
async def _gateway_agent_files_index(
agent_id: str, config: GatewayClientConfig
agent_id: str, config: GatewayClientConfig,
) -> dict[str, dict[str, Any]]:
try:
payload = await openclaw_call("agents.files.list", {"agentId": agent_id}, config=config)
payload = await openclaw_call(
"agents.files.list", {"agentId": agent_id}, config=config,
)
if isinstance(payload, dict):
files = payload.get("files") or []
index: dict[str, dict[str, Any]] = {}
@@ -420,21 +436,25 @@ def _render_agent_files(
)
heartbeat_path = _templates_root() / heartbeat_template
if heartbeat_path.exists():
rendered[name] = env.get_template(heartbeat_template).render(**context).strip()
rendered[name] = (
env.get_template(heartbeat_template).render(**context).strip()
)
continue
override = overrides.get(name)
if override:
rendered[name] = env.from_string(override).render(**context).strip()
continue
template_name = (
template_overrides[name] if template_overrides and name in template_overrides else name
template_overrides[name]
if template_overrides and name in template_overrides
else name
)
path = _templates_root() / template_name
if path.exists():
rendered[name] = env.get_template(template_name).render(**context).strip()
continue
if name == "MEMORY.md":
# Back-compat fallback for existing gateways that don't ship a MEMORY.md template.
# Back-compat fallback for gateways that do not ship MEMORY.md.
rendered[name] = "# MEMORY\n\nBootstrap pending.\n"
continue
rendered[name] = ""
@@ -487,7 +507,9 @@ async def _patch_gateway_agent_list(
else:
new_list.append(entry)
if not updated:
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat})
new_list.append(
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
)
patch = {"agents": {"list": new_list}}
params = {"raw": json.dumps(patch)}
@@ -496,7 +518,7 @@ async def _patch_gateway_agent_list(
await openclaw_call("config.patch", params, config=config)
async def patch_gateway_agent_heartbeats(
async def patch_gateway_agent_heartbeats( # noqa: C901
gateway: Gateway,
*,
entries: list[tuple[str, str, dict[str, Any]]],
@@ -521,7 +543,8 @@ async def patch_gateway_agent_heartbeats(
raise OpenClawGatewayError("config agents.list is not a list")
entry_by_id: dict[str, tuple[str, dict[str, Any]]] = {
agent_id: (workspace_path, heartbeat) for agent_id, workspace_path, heartbeat in entries
agent_id: (workspace_path, heartbeat)
for agent_id, workspace_path, heartbeat in entries
}
updated_ids: set[str] = set()
@@ -544,7 +567,9 @@ async def patch_gateway_agent_heartbeats(
for agent_id, (workspace_path, heartbeat) in entry_by_id.items():
if agent_id in updated_ids:
continue
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat})
new_list.append(
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
)
patch = {"agents": {"list": new_list}}
params = {"raw": json.dumps(patch)}
@@ -585,7 +610,9 @@ async def _remove_gateway_agent_list(
raise OpenClawGatewayError("config agents.list is not a list")
new_list = [
entry for entry in lst if not (isinstance(entry, dict) and entry.get("id") == agent_id)
entry
for entry in lst
if not (isinstance(entry, dict) and entry.get("id") == agent_id)
]
if len(new_list) == len(lst):
return
@@ -616,7 +643,7 @@ async def _get_gateway_agent_entry(
return None
async def provision_agent(
async def provision_agent( # noqa: C901, PLR0912, PLR0913
agent: Agent,
board: Board,
gateway: Gateway,
@@ -627,6 +654,7 @@ async def provision_agent(
force_bootstrap: bool = False,
reset_session: bool = False,
) -> None:
"""Provision or update a regular board agent workspace."""
if not gateway.url:
return
if not gateway.workspace_root:
@@ -665,11 +693,9 @@ async def provision_agent(
content = rendered.get(name)
if not content:
continue
try:
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
except OSError:
with suppress(OSError):
# Local workspace may not be writable/available; fall back to gateway API.
pass
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
for name, content in rendered.items():
if content == "":
continue
@@ -694,7 +720,7 @@ async def provision_agent(
await _reset_session(session_key, client_config)
async def provision_main_agent(
async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
agent: Agent,
gateway: Gateway,
auth_token: str,
@@ -704,12 +730,15 @@ async def provision_main_agent(
force_bootstrap: bool = False,
reset_session: bool = False,
) -> None:
"""Provision or update the gateway main agent workspace."""
if not gateway.url:
return
if not gateway.main_session_key:
raise ValueError("gateway main_session_key is required")
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent")
await ensure_session(
gateway.main_session_key, config=client_config, label="Main Agent",
)
agent_id = await _gateway_default_agent_id(
client_config,
@@ -763,6 +792,7 @@ async def cleanup_agent(
agent: Agent,
gateway: Gateway,
) -> str | None:
"""Remove an agent from gateway config and delete its session."""
if not gateway.url:
return None
if not gateway.workspace_root:

View File

@@ -1,30 +1,41 @@
"""Helpers for ensuring each board has a provisioned lead agent."""
from __future__ import annotations
from typing import Any
from typing import TYPE_CHECKING, Any
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.time import utcnow
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
def lead_session_key(board: Board) -> str:
"""Return the deterministic main session key for a board lead agent."""
return f"agent:lead-{board.id}:main"
def lead_agent_name(_: Board) -> str:
"""Return the default display name for board lead agents."""
return "Lead Agent"
async def ensure_board_lead_agent(
async def ensure_board_lead_agent( # noqa: PLR0913
session: AsyncSession,
*,
board: Board,
@@ -35,11 +46,12 @@ async def ensure_board_lead_agent(
identity_profile: dict[str, str] | None = None,
action: str = "provision",
) -> tuple[Agent, bool]:
"""Ensure a board has a lead agent; return `(agent, created)`."""
existing = (
await session.exec(
select(Agent)
.where(Agent.board_id == board.id)
.where(col(Agent.is_board_lead).is_(True))
.where(col(Agent.is_board_lead).is_(True)),
)
).first()
if existing:
@@ -66,7 +78,11 @@ async def ensure_board_lead_agent(
}
if identity_profile:
merged_identity_profile.update(
{key: value.strip() for key, value in identity_profile.items() if value.strip()}
{
key: value.strip()
for key, value in identity_profile.items()
if value.strip()
},
)
agent = Agent(
@@ -89,11 +105,16 @@ async def ensure_board_lead_agent(
try:
await provision_agent(agent, board, gateway, raw_token, user, action=action)
if agent.openclaw_session_id:
await ensure_session(agent.openclaw_session_id, config=config, label=agent.name)
await ensure_session(
agent.openclaw_session_id,
config=config,
label=agent.name,
)
await send_message(
(
f"Hello {agent.name}. Your workspace has been provisioned.\n\n"
"Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run it once "
"Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run "
"it once "
"then delete it. Begin heartbeats after startup."
),
session_key=agent.openclaw_session_id,

View File

@@ -1,17 +1,17 @@
"""Helpers for assembling denormalized board snapshot response payloads."""
from __future__ import annotations
from datetime import timedelta
from uuid import UUID
from typing import TYPE_CHECKING
from sqlalchemy import case, func
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.time import utcnow
from app.models.agents import Agent
from app.models.approvals import Approval
from app.models.board_memory import BoardMemory
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.tasks import Task
from app.schemas.agents import AgentRead
@@ -25,6 +25,13 @@ from app.services.task_dependencies import (
dependency_status_by_id,
)
if TYPE_CHECKING:
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board
OFFLINE_AFTER = timedelta(minutes=10)
@@ -48,9 +55,15 @@ def _agent_to_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
model = AgentRead.model_validate(agent, from_attributes=True)
computed_status = _computed_agent_status(agent)
is_gateway_main = bool(
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys
agent.openclaw_session_id
and agent.openclaw_session_id in main_session_keys,
)
return model.model_copy(
update={
"status": computed_status,
"is_gateway_main": is_gateway_main,
},
)
return model.model_copy(update={"status": computed_status, "is_gateway_main": is_gateway_main})
def _memory_to_read(memory: BoardMemory) -> BoardMemoryRead:
@@ -72,7 +85,9 @@ def _task_to_card(
card = TaskCardRead.model_validate(task, from_attributes=True)
approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0))
assignee = (
agent_name_by_id.get(task.assigned_agent_id) if task.assigned_agent_id is not None else None
agent_name_by_id.get(task.assigned_agent_id)
if task.assigned_agent_id
else None
)
depends_on_task_ids = deps_by_task_id.get(task.id, [])
blocked_by_task_ids = blocked_by_dependency_ids(
@@ -89,21 +104,26 @@ def _task_to_card(
"depends_on_task_ids": depends_on_task_ids,
"blocked_by_task_ids": blocked_by_task_ids,
"is_blocked": bool(blocked_by_task_ids),
}
},
)
async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnapshot:
"""Build a board snapshot with tasks, agents, approvals, and chat history."""
board_read = BoardRead.model_validate(board, from_attributes=True)
tasks = list(
await Task.objects.filter_by(board_id=board.id)
.order_by(col(Task.created_at).desc())
.all(session)
.all(session),
)
task_ids = [task.id for task in tasks]
deps_by_task_id = await dependency_ids_by_task_id(session, board_id=board.id, task_ids=task_ids)
deps_by_task_id = await dependency_ids_by_task_id(
session,
board_id=board.id,
task_ids=task_ids,
)
all_dependency_ids: list[UUID] = []
for values in deps_by_task_id.values():
all_dependency_ids.extend(values)
@@ -127,9 +147,9 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
await session.exec(
select(func.count(col(Approval.id)))
.where(col(Approval.board_id) == board.id)
.where(col(Approval.status) == "pending")
)
).one()
.where(col(Approval.status) == "pending"),
),
).one(),
)
approvals = (
@@ -146,12 +166,14 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
select(
col(Approval.task_id),
func.count(col(Approval.id)).label("total"),
func.sum(case((col(Approval.status) == "pending", 1), else_=0)).label("pending"),
func.sum(
case((col(Approval.status) == "pending", 1), else_=0),
).label("pending"),
)
.where(col(Approval.board_id) == board.id)
.where(col(Approval.task_id).is_not(None))
.group_by(col(Approval.task_id))
)
.group_by(col(Approval.task_id)),
),
)
for task_id, total, pending in rows:
if task_id is None:

View File

@@ -1,26 +1,33 @@
"""Policy helpers for lead-agent approval and planning decisions."""
from __future__ import annotations
import hashlib
from typing import Mapping
CONFIDENCE_THRESHOLD = 80
MIN_PLANNING_SIGNALS = 2
def compute_confidence(rubric_scores: Mapping[str, int]) -> int:
"""Compute aggregate confidence from rubric score components."""
return int(sum(rubric_scores.values()))
def approval_required(*, confidence: int, is_external: bool, is_risky: bool) -> bool:
"""Return whether an action must go through explicit approval."""
return is_external or is_risky or confidence < CONFIDENCE_THRESHOLD
def infer_planning(signals: Mapping[str, bool]) -> bool:
"""Infer planning intent from boolean heuristic signals."""
# Require at least two planning signals to avoid spam on general boards.
truthy = [key for key, value in signals.items() if value]
return len(truthy) >= 2
return len(truthy) >= MIN_PLANNING_SIGNALS
def task_fingerprint(title: str, description: str | None, board_id: str) -> str:
"""Build a stable hash key for deduplicating similar board tasks."""
normalized_title = title.strip().lower()
normalized_desc = (description or "").strip().lower()
seed = f"{board_id}::{normalized_title}::{normalized_desc}"

View File

@@ -1,18 +1,24 @@
"""Helpers for extracting and matching `@mention` tokens in text."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from app.models.agents import Agent
if TYPE_CHECKING:
from app.models.agents import Agent
# Mention tokens are single, space-free words (e.g. "@alex", "@lead").
MENTION_PATTERN = re.compile(r"@([A-Za-z][\w-]{0,31})")
def extract_mentions(message: str) -> set[str]:
"""Extract normalized mention handles from a message body."""
return {match.group(1).lower() for match in MENTION_PATTERN.finditer(message)}
def matches_agent_mention(agent: Agent, mentions: set[str]) -> bool:
"""Return whether a mention set targets the provided agent."""
if not mentions:
return False

View File

@@ -1,14 +1,14 @@
"""Organization membership and board-access service helpers."""
# ruff: noqa: D101, D103
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable
from uuid import UUID
from typing import TYPE_CHECKING, Iterable
from fastapi import HTTPException, status
from sqlalchemy import func, or_
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.time import utcnow
from app.db import crud
@@ -19,7 +19,17 @@ from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization
from app.models.users import User
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
if TYPE_CHECKING:
from uuid import UUID
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession
from app.schemas.organizations import (
OrganizationBoardAccessSpec,
OrganizationMemberAccessUpdate,
)
DEFAULT_ORG_NAME = "Personal"
ADMIN_ROLES = {"owner", "admin"}
@@ -63,7 +73,9 @@ async def get_member(
).first(session)
async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None:
async def get_first_membership(
session: AsyncSession, user_id: UUID,
) -> OrganizationMember | None:
return (
await OrganizationMember.objects.filter_by(user_id=user_id)
.order_by(col(OrganizationMember.created_at).asc())
@@ -79,7 +91,9 @@ async def set_active_organization(
) -> OrganizationMember:
member = await get_member(session, user_id=user.id, organization_id=organization_id)
if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
)
if user.active_organization_id != organization_id:
user.active_organization_id = organization_id
session.add(user)
@@ -154,9 +168,10 @@ async def accept_invite(
access_rows = list(
await session.exec(
select(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
)
)
col(OrganizationInviteBoardAccess.organization_invite_id)
== invite.id,
),
),
)
for row in access_rows:
session.add(
@@ -167,7 +182,7 @@ async def accept_invite(
can_write=row.can_write,
created_at=now,
updated_at=now,
)
),
)
invite.accepted_by_user_id = user.id
@@ -182,7 +197,9 @@ async def accept_invite(
return member
async def ensure_member_for_user(session: AsyncSession, user: User) -> OrganizationMember:
async def ensure_member_for_user(
session: AsyncSession, user: User,
) -> OrganizationMember:
existing = await get_active_membership(session, user)
if existing is not None:
return existing
@@ -196,7 +213,9 @@ async def ensure_member_for_user(session: AsyncSession, user: User) -> Organizat
now = utcnow()
member_count = (
await session.exec(
select(func.count()).where(col(OrganizationMember.organization_id) == org.id)
select(func.count()).where(
col(OrganizationMember.organization_id) == org.id,
),
)
).one()
is_first = int(member_count or 0) == 0
@@ -257,30 +276,40 @@ async def require_board_access(
board: Board,
write: bool,
) -> OrganizationMember:
member = await get_member(session, user_id=user.id, organization_id=board.organization_id)
member = await get_member(
session, user_id=user.id, organization_id=board.organization_id,
)
if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
)
if not await has_board_access(session, member=member, board=board, write=write):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied",
)
return member
def board_access_filter(member: OrganizationMember, *, write: bool) -> ColumnElement[bool]:
def board_access_filter(
member: OrganizationMember, *, write: bool,
) -> ColumnElement[bool]:
if write and member_all_boards_write(member):
return col(Board.organization_id) == member.organization_id
if not write and member_all_boards_read(member):
return col(Board.organization_id) == member.organization_id
access_stmt = select(OrganizationBoardAccess.board_id).where(
col(OrganizationBoardAccess.organization_member_id) == member.id
col(OrganizationBoardAccess.organization_member_id) == member.id,
)
if write:
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
access_stmt = access_stmt.where(
col(OrganizationBoardAccess.can_write).is_(True),
)
else:
access_stmt = access_stmt.where(
or_(
col(OrganizationBoardAccess.can_read).is_(True),
col(OrganizationBoardAccess.can_write).is_(True),
)
),
)
return col(Board.id).in_(access_stmt)
@@ -295,21 +324,25 @@ async def list_accessible_board_ids(
not write and member_all_boards_read(member)
):
ids = await session.exec(
select(Board.id).where(col(Board.organization_id) == member.organization_id)
select(Board.id).where(
col(Board.organization_id) == member.organization_id,
),
)
return list(ids)
access_stmt = select(OrganizationBoardAccess.board_id).where(
col(OrganizationBoardAccess.organization_member_id) == member.id
col(OrganizationBoardAccess.organization_member_id) == member.id,
)
if write:
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True))
access_stmt = access_stmt.where(
col(OrganizationBoardAccess.can_write).is_(True),
)
else:
access_stmt = access_stmt.where(
or_(
col(OrganizationBoardAccess.can_read).is_(True),
col(OrganizationBoardAccess.can_write).is_(True),
)
),
)
board_ids = await session.exec(access_stmt)
return list(board_ids)
@@ -337,18 +370,17 @@ async def apply_member_access_update(
if update.all_boards_read or update.all_boards_write:
return
rows: list[OrganizationBoardAccess] = []
for entry in update.board_access:
rows.append(
OrganizationBoardAccess(
organization_member_id=member.id,
board_id=entry.board_id,
can_read=entry.can_read,
can_write=entry.can_write,
created_at=now,
updated_at=now,
)
rows = [
OrganizationBoardAccess(
organization_member_id=member.id,
board_id=entry.board_id,
can_read=entry.can_read,
can_write=entry.can_write,
created_at=now,
updated_at=now,
)
for entry in update.board_access
]
session.add_all(rows)
@@ -367,18 +399,17 @@ async def apply_invite_board_access(
if invite.all_boards_read or invite.all_boards_write:
return
now = utcnow()
rows: list[OrganizationInviteBoardAccess] = []
for entry in entries:
rows.append(
OrganizationInviteBoardAccess(
organization_invite_id=invite.id,
board_id=entry.board_id,
can_read=entry.can_read,
can_write=entry.can_write,
created_at=now,
updated_at=now,
)
rows = [
OrganizationInviteBoardAccess(
organization_invite_id=invite.id,
board_id=entry.board_id,
can_read=entry.can_read,
can_write=entry.can_write,
created_at=now,
updated_at=now,
)
for entry in entries
]
session.add_all(rows)
@@ -423,9 +454,9 @@ async def apply_invite_to_member(
access_rows = list(
await session.exec(
select(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
)
)
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
),
),
)
for row in access_rows:
existing = (
@@ -433,7 +464,7 @@ async def apply_invite_to_member(
select(OrganizationBoardAccess).where(
col(OrganizationBoardAccess.organization_member_id) == member.id,
col(OrganizationBoardAccess.board_id) == row.board_id,
)
),
)
).first()
can_write = bool(row.can_write)
@@ -447,7 +478,7 @@ async def apply_invite_to_member(
can_write=can_write,
created_at=now,
updated_at=now,
)
),
)
else:
existing.can_read = bool(existing.can_read or can_read)

View File

@@ -1,3 +1,5 @@
"""Service helpers for querying and caching souls.directory content."""
from __future__ import annotations
import time
@@ -11,33 +13,41 @@ SOULS_DIRECTORY_BASE_URL: Final[str] = "https://souls.directory"
SOULS_DIRECTORY_SITEMAP_URL: Final[str] = f"{SOULS_DIRECTORY_BASE_URL}/sitemap.xml"
_SITEMAP_TTL_SECONDS: Final[int] = 60 * 60
_SOUL_URL_MIN_PARTS: Final[int] = 6
@dataclass(frozen=True, slots=True)
class SoulRef:
"""Handle/slug reference pair for a soul entry."""
handle: str
slug: str
@property
def page_url(self) -> str:
"""Return the canonical page URL for this soul."""
return f"{SOULS_DIRECTORY_BASE_URL}/souls/{self.handle}/{self.slug}"
@property
def raw_md_url(self) -> str:
"""Return the raw markdown URL for this soul."""
return f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{self.handle}/{self.slug}.md"
def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
"""Parse sitemap XML and extract valid souls.directory handle/slug refs."""
try:
root = ET.fromstring(sitemap_xml)
# Souls sitemap is fetched from a known trusted host in this service flow.
root = ET.fromstring(sitemap_xml) # noqa: S314
except ET.ParseError:
return []
# Handle both namespaced and non-namespaced sitemap XML.
urls: list[str] = []
for loc in root.iter():
if loc.tag.endswith("loc") and loc.text:
urls.append(loc.text.strip())
urls = [
loc.text.strip()
for loc in root.iter()
if loc.tag.endswith("loc") and loc.text
]
refs: list[SoulRef] = []
for url in urls:
@@ -45,7 +55,7 @@ def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
continue
# Expected: https://souls.directory/souls/{handle}/{slug}
parts = url.split("/")
if len(parts) < 6:
if len(parts) < _SOUL_URL_MIN_PARTS:
continue
handle = parts[4].strip()
slug = parts[5].strip()
@@ -61,7 +71,11 @@ _sitemap_cache: dict[str, object] = {
}
async def list_souls_directory_refs(*, client: httpx.AsyncClient | None = None) -> list[SoulRef]:
async def list_souls_directory_refs(
*,
client: httpx.AsyncClient | None = None,
) -> list[SoulRef]:
"""Return cached sitemap-derived soul refs, refreshing when TTL expires."""
now = time.time()
loaded_raw = _sitemap_cache.get("loaded_at")
loaded_at = loaded_raw if isinstance(loaded_raw, (int, float)) else 0.0
@@ -93,11 +107,15 @@ async def fetch_soul_markdown(
slug: str,
client: httpx.AsyncClient | None = None,
) -> str:
"""Fetch raw markdown content for a specific handle/slug pair."""
normalized_handle = handle.strip().strip("/")
normalized_slug = slug.strip().strip("/")
if normalized_slug.endswith(".md"):
normalized_slug = normalized_slug[: -len(".md")]
url = f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{normalized_handle}/{normalized_slug}.md"
url = (
f"{SOULS_DIRECTORY_BASE_URL}/api/souls/"
f"{normalized_handle}/{normalized_slug}.md"
)
owns_client = client is None
if client is None:
@@ -115,6 +133,7 @@ async def fetch_soul_markdown(
def search_souls(refs: list[SoulRef], *, query: str, limit: int = 20) -> list[SoulRef]:
"""Search refs by case-insensitive handle/slug substring with a hard limit."""
q = query.strip().lower()
if not q:
return refs[: max(0, min(limit, len(refs)))]