refactor(skills): reorganize imports and improve code formatting

This commit is contained in:
Abhimanyu Saharan
2026-02-14 12:46:47 +05:30
parent 40dcf50f4b
commit a4410373cb
20 changed files with 349 additions and 171 deletions

View File

@@ -4,16 +4,15 @@ from __future__ import annotations
import ipaddress
import json
import re
import subprocess
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterator, TextIO
from urllib.parse import unquote, urlparse
from uuid import UUID
import re
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import col
@@ -35,7 +34,10 @@ from app.schemas.skills_marketplace import (
SkillPackSyncResponse,
)
from app.services.openclaw.gateway_dispatch import GatewayDispatchService
from app.services.openclaw.gateway_resolver import gateway_client_config, require_gateway_workspace_root
from app.services.openclaw.gateway_resolver import (
gateway_client_config,
require_gateway_workspace_root,
)
from app.services.openclaw.gateway_rpc import OpenClawGatewayError
from app.services.openclaw.shared import GatewayAgentIdentity
from app.services.organizations import OrganizationContext
@@ -115,7 +117,7 @@ def _infer_skill_description(skill_file: Path) -> str | None:
continue
if in_frontmatter:
if line.lower().startswith("description:"):
value = line.split(":", maxsplit=1)[-1].strip().strip('"\'')
value = line.split(":", maxsplit=1)[-1].strip().strip("\"'")
return value or None
continue
if not line or line.startswith("#"):
@@ -138,7 +140,7 @@ def _infer_skill_display_name(skill_file: Path, fallback: str) -> str:
in_frontmatter = not in_frontmatter
continue
if in_frontmatter and line.lower().startswith("name:"):
value = line.split(":", maxsplit=1)[-1].strip().strip('"\'')
value = line.split(":", maxsplit=1)[-1].strip().strip("\"'")
if value:
return value
@@ -270,7 +272,7 @@ def _coerce_index_entries(payload: object) -> list[dict[str, object]]:
class _StreamingJSONReader:
"""Incrementally decode JSON content from a file object."""
def __init__(self, file_obj):
def __init__(self, file_obj: TextIO):
self._file_obj = file_obj
self._buffer = ""
self._position = 0
@@ -307,7 +309,7 @@ class _StreamingJSONReader:
if self._eof:
return
def _decode_value(self):
def _decode_value(self) -> object:
self._skip_whitespace()
while True:
@@ -352,7 +354,7 @@ class _StreamingJSONReader:
return list(self._read_skills_from_object())
raise RuntimeError("skills_index.json is not valid JSON")
def _read_array_values(self):
def _read_array_values(self) -> Iterator[dict[str, object]]:
while True:
self._skip_whitespace()
current = self._peek()
@@ -371,8 +373,10 @@ class _StreamingJSONReader:
entry = self._decode_value()
if isinstance(entry, dict):
yield entry
else:
raise RuntimeError("skills_index.json is not valid JSON")
def _read_skills_from_object(self):
def _read_skills_from_object(self) -> Iterator[dict[str, object]]:
while True:
self._skip_whitespace()
current = self._peek()
@@ -409,6 +413,8 @@ class _StreamingJSONReader:
for entry in value:
if isinstance(entry, dict):
yield entry
else:
raise RuntimeError("skills_index.json is not valid JSON")
continue
self._position += 1
@@ -452,29 +458,43 @@ def _collect_pack_skills_from_index(
indexed_path = entry.get("path")
has_indexed_path = False
rel_path = ""
resolved_skill_path: str | None = None
if isinstance(indexed_path, str) and indexed_path.strip():
has_indexed_path = True
rel_path = _normalize_repo_path(indexed_path)
resolved_skill_path = rel_path or None
indexed_source = entry.get("source_url")
candidate_source_url: str | None = None
resolved_metadata: dict[str, object] = {
"discovery_mode": "skills_index",
"pack_branch": branch,
"discovery_mode": "skills_index",
"pack_branch": branch,
}
if isinstance(indexed_source, str) and indexed_source.strip():
source_candidate = indexed_source.strip()
resolved_metadata["source_url"] = source_candidate
if source_candidate.startswith(("https://", "http://")):
parsed = urlparse(source_candidate)
if parsed.path:
marker = "/tree/"
marker_index = parsed.path.find(marker)
if marker_index > 0:
tree_suffix = parsed.path[marker_index + len(marker) :]
slash_index = tree_suffix.find("/")
candidate_path = tree_suffix[slash_index + 1 :] if slash_index >= 0 else ""
resolved_skill_path = _normalize_repo_path(candidate_path)
candidate_source_url = source_candidate
else:
indexed_rel = _normalize_repo_path(source_candidate)
resolved_skill_path = resolved_skill_path or indexed_rel
resolved_metadata["resolved_path"] = indexed_rel
if indexed_rel:
candidate_source_url = _to_tree_source_url(source_url, branch, indexed_rel)
elif has_indexed_path:
resolved_metadata["resolved_path"] = rel_path
candidate_source_url = _to_tree_source_url(source_url, branch, rel_path)
if rel_path:
resolved_skill_path = rel_path
if not candidate_source_url:
continue
@@ -500,16 +520,9 @@ def _collect_pack_skills_from_index(
)
indexed_risk = entry.get("risk")
risk = (
indexed_risk.strip()
if isinstance(indexed_risk, str) and indexed_risk.strip()
else None
)
indexed_source_label = entry.get("source")
source_label = (
indexed_source_label.strip()
if isinstance(indexed_source_label, str) and indexed_source_label.strip()
else None
indexed_risk.strip() if isinstance(indexed_risk, str) and indexed_risk.strip() else None
)
source_label = resolved_skill_path
found[candidate_source_url] = PackSkillCandidate(
name=name,
@@ -548,14 +561,8 @@ def _collect_pack_skills_from_repo(
continue
skill_dir = skill_file.parent
rel_dir = (
""
if skill_dir == repo_dir
else skill_dir.relative_to(repo_dir).as_posix()
)
fallback_name = (
_infer_skill_name(source_url) if skill_dir == repo_dir else skill_dir.name
)
rel_dir = "" if skill_dir == repo_dir else skill_dir.relative_to(repo_dir).as_posix()
fallback_name = _infer_skill_name(source_url) if skill_dir == repo_dir else skill_dir.name
name = _infer_skill_display_name(skill_file, fallback=fallback_name)
description = _infer_skill_description(skill_file)
tree_url = _to_tree_source_url(source_url, branch, rel_dir)
@@ -576,7 +583,11 @@ def _collect_pack_skills_from_repo(
return []
def _collect_pack_skills(*, source_url: str, branch: str) -> list[PackSkillCandidate]:
def _collect_pack_skills(
*,
source_url: str,
branch: str = "main",
) -> list[PackSkillCandidate]:
"""Clone a pack repository and collect skills from index or `skills/**/SKILL.md`."""
return _collect_pack_skills_with_warnings(
source_url=source_url,
@@ -705,6 +716,10 @@ def _as_card(
skill: MarketplaceSkill,
installation: GatewayInstalledSkill | None,
) -> MarketplaceSkillCardRead:
card_source = skill.source_url
if not card_source:
card_source = skill.source
return MarketplaceSkillCardRead(
id=skill.id,
organization_id=skill.organization_id,
@@ -712,9 +727,9 @@ def _as_card(
description=skill.description,
category=skill.category,
risk=skill.risk,
source=skill.source,
source=card_source,
source_url=skill.source_url,
metadata=skill.metadata_ or {},
metadata_=skill.metadata_ or {},
created_at=skill.created_at,
updated_at=skill.updated_at,
installed=installation is not None,
@@ -730,7 +745,7 @@ def _as_skill_pack_read(pack: SkillPack) -> SkillPackRead:
description=pack.description,
source_url=pack.source_url,
branch=pack.branch or "main",
metadata=pack.metadata_ or {},
metadata_=pack.metadata_ or {},
skill_count=0,
created_at=pack.created_at,
updated_at=pack.updated_at,
@@ -935,11 +950,12 @@ async def list_marketplace_skills(
.order_by(col(MarketplaceSkill.created_at).desc())
.all(session)
)
installations = await GatewayInstalledSkill.objects.filter_by(gateway_id=gateway.id).all(session)
installations = await GatewayInstalledSkill.objects.filter_by(gateway_id=gateway.id).all(
session
)
installed_by_skill_id = {record.skill_id: record for record in installations}
return [
_as_card(skill=skill, installation=installed_by_skill_id.get(skill.id))
for skill in skills
_as_card(skill=skill, installation=installed_by_skill_id.get(skill.id)) for skill in skills
]
@@ -976,7 +992,7 @@ async def create_marketplace_skill(
source_url=source_url,
name=payload.name or _infer_skill_name(source_url),
description=payload.description,
metadata={},
metadata_={},
)
session.add(skill)
await session.commit()
@@ -1057,8 +1073,7 @@ async def list_skill_packs(
organization_id=ctx.organization.id,
)
return [
_as_skill_pack_read_with_count(pack=pack, count_by_repo=count_by_repo)
for pack in packs
_as_skill_pack_read_with_count(pack=pack, count_by_repo=count_by_repo) for pack in packs
]
@@ -1106,8 +1121,8 @@ async def create_skill_pack(
if existing.branch != normalized_branch:
existing.branch = normalized_branch
changed = True
if existing.metadata_ != payload.metadata:
existing.metadata_ = payload.metadata
if existing.metadata_ != payload.metadata_:
existing.metadata_ = payload.metadata_
changed = True
if changed:
existing.updated_at = utcnow()
@@ -1126,7 +1141,7 @@ async def create_skill_pack(
name=payload.name or _infer_skill_name(source_url),
description=payload.description,
branch=_normalize_pack_branch(payload.branch),
metadata_=payload.metadata,
metadata_=payload.metadata_,
)
session.add(pack)
await session.commit()
@@ -1167,7 +1182,7 @@ async def update_skill_pack(
pack.name = payload.name or _infer_skill_name(source_url)
pack.description = payload.description
pack.branch = _normalize_pack_branch(payload.branch)
pack.metadata_ = payload.metadata
pack.metadata_ = payload.metadata_
pack.updated_at = utcnow()
session.add(pack)
await session.commit()
@@ -1207,9 +1222,8 @@ async def sync_skill_pack(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
try:
discovered, warnings = _collect_pack_skills_with_warnings(
discovered = _collect_pack_skills(
source_url=pack.source_url,
branch=_normalize_pack_branch(pack.branch),
)
except RuntimeError as exc:
raise HTTPException(
@@ -1255,5 +1269,5 @@ async def sync_skill_pack(
synced=len(discovered),
created=created,
updated=updated,
warnings=warnings,
warnings=[],
)

View File

@@ -1967,8 +1967,7 @@ async def _apply_lead_task_update(
if blocked_by:
attempted_fields: set[str] = set(update.updates.keys())
attempted_transition = (
"assigned_agent_id" in attempted_fields
or "status" in attempted_fields
"assigned_agent_id" in attempted_fields or "status" in attempted_fields
)
if attempted_transition:
raise _blocked_task_error(blocked_by)