feat(generator): tighten stainless config validation

This commit is contained in:
Ashwin Bharambe 2025-11-14 20:04:49 -08:00
parent 38ba5bfb94
commit b40a0c5151

View file

@ -5,7 +5,7 @@ from __future__ import annotations
import argparse
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from typing import Any, Iterator
import yaml
@ -592,6 +592,9 @@ ALL_RESOURCES = {
}
HTTP_METHODS = {"get", "post", "put", "patch", "delete", "options", "head"}
@dataclass
class Endpoint:
method: str
@ -599,16 +602,28 @@ class Endpoint:
extra: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_config(cls, value: Any) -> Endpoint:
def from_config(cls, value: Any) -> "Endpoint":
if isinstance(value, str):
method, _, path = value.partition(" ")
return cls(method, path)
return cls._from_parts(method, path)
if isinstance(value, dict) and "endpoint" in value:
method, _, path = value["endpoint"].partition(" ")
extra = {k: v for k, v in value.items() if k != "endpoint"}
return cls(method, path, extra)
endpoint = cls._from_parts(method, path)
endpoint.extra.update(extra)
return endpoint
raise ValueError(f"Unsupported endpoint value: {value!r}")
@classmethod
def _from_parts(cls, method: str, path: str) -> "Endpoint":
method = method.strip().lower()
path = path.strip()
if method not in HTTP_METHODS:
raise ValueError(f"Unsupported HTTP method for Stainless config: {method!r}")
if not path.startswith("/"):
raise ValueError(f"Endpoint path must start with '/': {path!r}")
return cls(method=method, path=path)
def to_config(self) -> Any:
if not self.extra:
return f"{self.method} {self.path}"
@ -617,7 +632,7 @@ class Endpoint:
return data
def route_key(self) -> str:
return f"{self.method.lower()} {self.path}"
return f"{self.method} {self.path}"
@dataclass
@ -649,6 +664,14 @@ class Resource:
paths.update(subresource.collect_endpoint_paths())
return paths
def iter_endpoints(self, prefix: str) -> Iterator[tuple[str, str]]:
for method_name, endpoint in self.methods.items():
label = f"{prefix}.{method_name}" if prefix else method_name
yield endpoint.route_key(), label
for sub_name, subresource in self.subresources.items():
sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
yield from subresource.iter_endpoints(sub_prefix)
_RESOURCES = {name: Resource.from_dict(data) for name, data in ALL_RESOURCES.items()}
@ -700,8 +723,56 @@ class StainlessConfig:
paths: set[str] = set()
for resource in self.resources.values():
paths.update(resource.collect_endpoint_paths())
paths.update(self.readme_endpoint_paths())
return paths
def readme_endpoint_paths(self) -> set[str]:
example_requests = self.readme.get("example_requests", {}) if self.readme else {}
paths: set[str] = set()
for entry in example_requests.values():
endpoint = entry.get("endpoint") if isinstance(entry, dict) else None
if isinstance(endpoint, str):
method, _, route = endpoint.partition(" ")
method = method.strip().lower()
route = route.strip()
if method and route:
paths.add(f"{method} {route}")
return paths
def endpoint_map(self) -> dict[str, list[str]]:
mapping: dict[str, list[str]] = {}
for resource_name, resource in self.resources.items():
for route, label in resource.iter_endpoints(resource_name):
mapping.setdefault(route, []).append(label)
return mapping
def validate_unique_endpoints(self) -> None:
duplicates: dict[str, list[str]] = {}
for route, labels in self.endpoint_map().items():
top_levels = {label.split(".", 1)[0] for label in labels}
if len(top_levels) > 1:
duplicates[route] = labels
if duplicates:
formatted = "\n".join(
f" - {route} defined in: {', '.join(sorted(labels))}"
for route, labels in sorted(duplicates.items())
)
raise ValueError("Duplicate endpoints found across resources:\n" + formatted)
def validate_readme_endpoints(self) -> None:
resource_paths: set[str] = set()
for resource in self.resources.values():
resource_paths.update(resource.collect_endpoint_paths())
missing = sorted(
path for path in self.readme_endpoint_paths() if path not in resource_paths
)
if missing:
formatted = "\n".join(f" - {path}" for path in missing)
raise ValueError(
"README example endpoints are not present in Stainless resources:\n"
+ formatted
)
def to_dict(self) -> dict[str, Any]:
cfg: dict[str, Any] = {}
for section in SECTION_ORDER:
@ -721,6 +792,12 @@ class StainlessConfig:
formatted = "\n".join(f" - {path}" for path in missing)
raise ValueError("Stainless config references missing endpoints:\n" + formatted)
def validate(self, openapi_path: Path | None = None) -> None:
self.validate_unique_endpoints()
self.validate_readme_endpoints()
if openapi_path is not None:
self.validate_against_openapi(openapi_path)
def build_config() -> dict[str, Any]:
return StainlessConfig.make().to_dict()
@ -729,7 +806,7 @@ def build_config() -> dict[str, Any]:
def write_config(repo_root: Path, openapi_path: Path | None = None) -> Path:
stainless_config = StainlessConfig.make()
spec_path = (openapi_path or (repo_root / "client-sdks" / "stainless" / "openapi.yml")).resolve()
stainless_config.validate_against_openapi(spec_path)
stainless_config.validate(spec_path)
yaml_text = yaml.safe_dump(stainless_config.to_dict(), sort_keys=False)
output = repo_root / "client-sdks" / "stainless" / "config.yml"
output.write_text(HEADER + yaml_text)