fix(skills): validate pack source URLs + git clone timeouts
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import json
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
@@ -45,6 +46,10 @@ SESSION_DEP = Depends(get_session)
|
||||
ORG_ADMIN_DEP = Depends(require_org_admin)
|
||||
GATEWAY_ID_QUERY = Query(...)
|
||||
|
||||
ALLOWED_PACK_SOURCE_SCHEMES = {"https"}
|
||||
GIT_CLONE_TIMEOUT_SECONDS = 30
|
||||
GIT_REV_PARSE_TIMEOUT_SECONDS = 10
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PackSkillCandidate:
|
||||
@@ -137,6 +142,38 @@ def _normalize_repo_source_url(source_url: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
def _validate_pack_source_url(source_url: str) -> None:
|
||||
"""Validate that a skill pack source URL is safe to clone.
|
||||
|
||||
The current implementation is intentionally conservative:
|
||||
- allow only https URLs
|
||||
- block localhost
|
||||
- block literal private/loopback/link-local IPs
|
||||
|
||||
Note: DNS-based private resolution is not checked here.
|
||||
"""
|
||||
|
||||
parsed = urlparse(source_url)
|
||||
scheme = (parsed.scheme or "").lower()
|
||||
if scheme not in ALLOWED_PACK_SOURCE_SCHEMES:
|
||||
raise ValueError(f"Unsupported pack source URL scheme: {parsed.scheme!r}")
|
||||
|
||||
host = (parsed.hostname or "").strip().lower()
|
||||
if not host:
|
||||
raise ValueError("Pack source URL must include a hostname")
|
||||
|
||||
if host in {"localhost"}:
|
||||
raise ValueError("Pack source URL hostname is not allowed")
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved or ip.is_multicast:
|
||||
raise ValueError("Pack source URL hostname is not allowed")
|
||||
|
||||
|
||||
def _to_tree_source_url(repo_source_url: str, branch: str, rel_path: str) -> str:
|
||||
repo_url = _normalize_repo_source_url(repo_source_url)
|
||||
safe_branch = branch.strip() or "main"
|
||||
@@ -337,9 +374,12 @@ def _collect_pack_skills(source_url: str) -> list[PackSkillCandidate]:
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=GIT_CLONE_TIMEOUT_SECONDS,
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise RuntimeError("git binary not available on the server") from exc
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
raise RuntimeError("timed out cloning pack repository") from exc
|
||||
except subprocess.CalledProcessError as exc:
|
||||
stderr = (exc.stderr or "").strip()
|
||||
detail = stderr or "unable to clone pack repository"
|
||||
@@ -351,8 +391,9 @@ def _collect_pack_skills(source_url: str) -> list[PackSkillCandidate]:
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=GIT_REV_PARSE_TIMEOUT_SECONDS,
|
||||
).stdout.strip()
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired, subprocess.CalledProcessError):
|
||||
branch = "main"
|
||||
|
||||
return _collect_pack_skills_from_repo(
|
||||
@@ -767,6 +808,11 @@ async def create_skill_pack(
|
||||
) -> SkillPackRead:
|
||||
"""Register a new skill pack source URL."""
|
||||
source_url = str(payload.source_url).strip()
|
||||
try:
|
||||
_validate_pack_source_url(source_url)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
|
||||
existing = await SkillPack.objects.filter_by(
|
||||
organization_id=ctx.organization.id,
|
||||
source_url=source_url,
|
||||
@@ -816,6 +862,10 @@ async def update_skill_pack(
|
||||
"""Update a skill pack URL and metadata."""
|
||||
pack = await _require_skill_pack_for_org(pack_id=pack_id, session=session, ctx=ctx)
|
||||
source_url = str(payload.source_url).strip()
|
||||
try:
|
||||
_validate_pack_source_url(source_url)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
|
||||
duplicate = await SkillPack.objects.filter_by(
|
||||
organization_id=ctx.organization.id,
|
||||
@@ -863,6 +913,11 @@ async def sync_skill_pack(
|
||||
"""Clone a pack repository and upsert discovered skills from `skills/**/SKILL.md`."""
|
||||
pack = await _require_skill_pack_for_org(pack_id=pack_id, session=session, ctx=ctx)
|
||||
|
||||
try:
|
||||
_validate_pack_source_url(pack.source_url)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
|
||||
try:
|
||||
discovered = _collect_pack_skills(pack.source_url)
|
||||
except RuntimeError as exc:
|
||||
|
||||
@@ -397,6 +397,66 @@ async def test_sync_pack_clones_and_upserts_skills(monkeypatch: pytest.MonkeyPat
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_skill_pack_rejects_non_https_source_url() -> None:
|
||||
engine = await _make_engine()
|
||||
session_maker = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
try:
|
||||
async with session_maker() as session:
|
||||
organization, _gateway = await _seed_base(session)
|
||||
await session.commit()
|
||||
|
||||
app = _build_test_app(session_maker, organization=organization)
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="http://testserver",
|
||||
) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/skills/packs",
|
||||
json={"source_url": "http://github.com/sickn33/antigravity-awesome-skills"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "scheme" in response.json()["detail"].lower() or "https" in response.json()["detail"].lower()
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_skill_pack_rejects_localhost_source_url() -> None:
|
||||
engine = await _make_engine()
|
||||
session_maker = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
try:
|
||||
async with session_maker() as session:
|
||||
organization, _gateway = await _seed_base(session)
|
||||
await session.commit()
|
||||
|
||||
app = _build_test_app(session_maker, organization=organization)
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="http://testserver",
|
||||
) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/skills/packs",
|
||||
json={"source_url": "https://localhost/skills-pack"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "hostname" in response.json()["detail"].lower() or "not allowed" in response.json()["detail"].lower()
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skill_packs_includes_skill_count() -> None:
|
||||
engine = await _make_engine()
|
||||
|
||||
Reference in New Issue
Block a user