mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
feat(generator): tighten stainless config validation
This commit is contained in:
parent
38ba5bfb94
commit
b40a0c5151
1 changed files with 83 additions and 6 deletions
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Iterator
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
@ -592,6 +592,9 @@ ALL_RESOURCES = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
HTTP_METHODS = {"get", "post", "put", "patch", "delete", "options", "head"}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Endpoint:
|
class Endpoint:
|
||||||
method: str
|
method: str
|
||||||
|
|
@ -599,16 +602,28 @@ class Endpoint:
|
||||||
extra: dict[str, Any] = field(default_factory=dict)
|
extra: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, value: Any) -> Endpoint:
|
def from_config(cls, value: Any) -> "Endpoint":
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
method, _, path = value.partition(" ")
|
method, _, path = value.partition(" ")
|
||||||
return cls(method, path)
|
return cls._from_parts(method, path)
|
||||||
if isinstance(value, dict) and "endpoint" in value:
|
if isinstance(value, dict) and "endpoint" in value:
|
||||||
method, _, path = value["endpoint"].partition(" ")
|
method, _, path = value["endpoint"].partition(" ")
|
||||||
extra = {k: v for k, v in value.items() if k != "endpoint"}
|
extra = {k: v for k, v in value.items() if k != "endpoint"}
|
||||||
return cls(method, path, extra)
|
endpoint = cls._from_parts(method, path)
|
||||||
|
endpoint.extra.update(extra)
|
||||||
|
return endpoint
|
||||||
raise ValueError(f"Unsupported endpoint value: {value!r}")
|
raise ValueError(f"Unsupported endpoint value: {value!r}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_parts(cls, method: str, path: str) -> "Endpoint":
|
||||||
|
method = method.strip().lower()
|
||||||
|
path = path.strip()
|
||||||
|
if method not in HTTP_METHODS:
|
||||||
|
raise ValueError(f"Unsupported HTTP method for Stainless config: {method!r}")
|
||||||
|
if not path.startswith("/"):
|
||||||
|
raise ValueError(f"Endpoint path must start with '/': {path!r}")
|
||||||
|
return cls(method=method, path=path)
|
||||||
|
|
||||||
def to_config(self) -> Any:
|
def to_config(self) -> Any:
|
||||||
if not self.extra:
|
if not self.extra:
|
||||||
return f"{self.method} {self.path}"
|
return f"{self.method} {self.path}"
|
||||||
|
|
@ -617,7 +632,7 @@ class Endpoint:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def route_key(self) -> str:
|
def route_key(self) -> str:
|
||||||
return f"{self.method.lower()} {self.path}"
|
return f"{self.method} {self.path}"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -649,6 +664,14 @@ class Resource:
|
||||||
paths.update(subresource.collect_endpoint_paths())
|
paths.update(subresource.collect_endpoint_paths())
|
||||||
return paths
|
return paths
|
||||||
|
|
||||||
|
def iter_endpoints(self, prefix: str) -> Iterator[tuple[str, str]]:
|
||||||
|
for method_name, endpoint in self.methods.items():
|
||||||
|
label = f"{prefix}.{method_name}" if prefix else method_name
|
||||||
|
yield endpoint.route_key(), label
|
||||||
|
for sub_name, subresource in self.subresources.items():
|
||||||
|
sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
|
||||||
|
yield from subresource.iter_endpoints(sub_prefix)
|
||||||
|
|
||||||
|
|
||||||
_RESOURCES = {name: Resource.from_dict(data) for name, data in ALL_RESOURCES.items()}
|
_RESOURCES = {name: Resource.from_dict(data) for name, data in ALL_RESOURCES.items()}
|
||||||
|
|
||||||
|
|
@ -700,8 +723,56 @@ class StainlessConfig:
|
||||||
paths: set[str] = set()
|
paths: set[str] = set()
|
||||||
for resource in self.resources.values():
|
for resource in self.resources.values():
|
||||||
paths.update(resource.collect_endpoint_paths())
|
paths.update(resource.collect_endpoint_paths())
|
||||||
|
paths.update(self.readme_endpoint_paths())
|
||||||
return paths
|
return paths
|
||||||
|
|
||||||
|
def readme_endpoint_paths(self) -> set[str]:
|
||||||
|
example_requests = self.readme.get("example_requests", {}) if self.readme else {}
|
||||||
|
paths: set[str] = set()
|
||||||
|
for entry in example_requests.values():
|
||||||
|
endpoint = entry.get("endpoint") if isinstance(entry, dict) else None
|
||||||
|
if isinstance(endpoint, str):
|
||||||
|
method, _, route = endpoint.partition(" ")
|
||||||
|
method = method.strip().lower()
|
||||||
|
route = route.strip()
|
||||||
|
if method and route:
|
||||||
|
paths.add(f"{method} {route}")
|
||||||
|
return paths
|
||||||
|
|
||||||
|
def endpoint_map(self) -> dict[str, list[str]]:
|
||||||
|
mapping: dict[str, list[str]] = {}
|
||||||
|
for resource_name, resource in self.resources.items():
|
||||||
|
for route, label in resource.iter_endpoints(resource_name):
|
||||||
|
mapping.setdefault(route, []).append(label)
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
def validate_unique_endpoints(self) -> None:
|
||||||
|
duplicates: dict[str, list[str]] = {}
|
||||||
|
for route, labels in self.endpoint_map().items():
|
||||||
|
top_levels = {label.split(".", 1)[0] for label in labels}
|
||||||
|
if len(top_levels) > 1:
|
||||||
|
duplicates[route] = labels
|
||||||
|
if duplicates:
|
||||||
|
formatted = "\n".join(
|
||||||
|
f" - {route} defined in: {', '.join(sorted(labels))}"
|
||||||
|
for route, labels in sorted(duplicates.items())
|
||||||
|
)
|
||||||
|
raise ValueError("Duplicate endpoints found across resources:\n" + formatted)
|
||||||
|
|
||||||
|
def validate_readme_endpoints(self) -> None:
|
||||||
|
resource_paths: set[str] = set()
|
||||||
|
for resource in self.resources.values():
|
||||||
|
resource_paths.update(resource.collect_endpoint_paths())
|
||||||
|
missing = sorted(
|
||||||
|
path for path in self.readme_endpoint_paths() if path not in resource_paths
|
||||||
|
)
|
||||||
|
if missing:
|
||||||
|
formatted = "\n".join(f" - {path}" for path in missing)
|
||||||
|
raise ValueError(
|
||||||
|
"README example endpoints are not present in Stainless resources:\n"
|
||||||
|
+ formatted
|
||||||
|
)
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
cfg: dict[str, Any] = {}
|
cfg: dict[str, Any] = {}
|
||||||
for section in SECTION_ORDER:
|
for section in SECTION_ORDER:
|
||||||
|
|
@ -721,6 +792,12 @@ class StainlessConfig:
|
||||||
formatted = "\n".join(f" - {path}" for path in missing)
|
formatted = "\n".join(f" - {path}" for path in missing)
|
||||||
raise ValueError("Stainless config references missing endpoints:\n" + formatted)
|
raise ValueError("Stainless config references missing endpoints:\n" + formatted)
|
||||||
|
|
||||||
|
def validate(self, openapi_path: Path | None = None) -> None:
|
||||||
|
self.validate_unique_endpoints()
|
||||||
|
self.validate_readme_endpoints()
|
||||||
|
if openapi_path is not None:
|
||||||
|
self.validate_against_openapi(openapi_path)
|
||||||
|
|
||||||
|
|
||||||
def build_config() -> dict[str, Any]:
|
def build_config() -> dict[str, Any]:
|
||||||
return StainlessConfig.make().to_dict()
|
return StainlessConfig.make().to_dict()
|
||||||
|
|
@ -729,7 +806,7 @@ def build_config() -> dict[str, Any]:
|
||||||
def write_config(repo_root: Path, openapi_path: Path | None = None) -> Path:
|
def write_config(repo_root: Path, openapi_path: Path | None = None) -> Path:
|
||||||
stainless_config = StainlessConfig.make()
|
stainless_config = StainlessConfig.make()
|
||||||
spec_path = (openapi_path or (repo_root / "client-sdks" / "stainless" / "openapi.yml")).resolve()
|
spec_path = (openapi_path or (repo_root / "client-sdks" / "stainless" / "openapi.yml")).resolve()
|
||||||
stainless_config.validate_against_openapi(spec_path)
|
stainless_config.validate(spec_path)
|
||||||
yaml_text = yaml.safe_dump(stainless_config.to_dict(), sort_keys=False)
|
yaml_text = yaml.safe_dump(stainless_config.to_dict(), sort_keys=False)
|
||||||
output = repo_root / "client-sdks" / "stainless" / "config.yml"
|
output = repo_root / "client-sdks" / "stainless" / "config.yml"
|
||||||
output.write_text(HEADER + yaml_text)
|
output.write_text(HEADER + yaml_text)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue