feat(generator): tighten stainless config validation

This commit is contained in:
Ashwin Bharambe 2025-11-14 20:04:49 -08:00
parent 38ba5bfb94
commit b40a0c5151

View file

@ -5,7 +5,7 @@ from __future__ import annotations
import argparse import argparse
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Iterator
import yaml import yaml
@ -592,6 +592,9 @@ ALL_RESOURCES = {
} }
HTTP_METHODS = {"get", "post", "put", "patch", "delete", "options", "head"}
@dataclass @dataclass
class Endpoint: class Endpoint:
method: str method: str
@ -599,16 +602,28 @@ class Endpoint:
extra: dict[str, Any] = field(default_factory=dict) extra: dict[str, Any] = field(default_factory=dict)
@classmethod @classmethod
def from_config(cls, value: Any) -> Endpoint: def from_config(cls, value: Any) -> "Endpoint":
if isinstance(value, str): if isinstance(value, str):
method, _, path = value.partition(" ") method, _, path = value.partition(" ")
return cls(method, path) return cls._from_parts(method, path)
if isinstance(value, dict) and "endpoint" in value: if isinstance(value, dict) and "endpoint" in value:
method, _, path = value["endpoint"].partition(" ") method, _, path = value["endpoint"].partition(" ")
extra = {k: v for k, v in value.items() if k != "endpoint"} extra = {k: v for k, v in value.items() if k != "endpoint"}
return cls(method, path, extra) endpoint = cls._from_parts(method, path)
endpoint.extra.update(extra)
return endpoint
raise ValueError(f"Unsupported endpoint value: {value!r}") raise ValueError(f"Unsupported endpoint value: {value!r}")
@classmethod
def _from_parts(cls, method: str, path: str) -> "Endpoint":
method = method.strip().lower()
path = path.strip()
if method not in HTTP_METHODS:
raise ValueError(f"Unsupported HTTP method for Stainless config: {method!r}")
if not path.startswith("/"):
raise ValueError(f"Endpoint path must start with '/': {path!r}")
return cls(method=method, path=path)
def to_config(self) -> Any: def to_config(self) -> Any:
if not self.extra: if not self.extra:
return f"{self.method} {self.path}" return f"{self.method} {self.path}"
@ -617,7 +632,7 @@ class Endpoint:
return data return data
def route_key(self) -> str: def route_key(self) -> str:
return f"{self.method.lower()} {self.path}" return f"{self.method} {self.path}"
@dataclass @dataclass
@ -649,6 +664,14 @@ class Resource:
paths.update(subresource.collect_endpoint_paths()) paths.update(subresource.collect_endpoint_paths())
return paths return paths
def iter_endpoints(self, prefix: str) -> Iterator[tuple[str, str]]:
for method_name, endpoint in self.methods.items():
label = f"{prefix}.{method_name}" if prefix else method_name
yield endpoint.route_key(), label
for sub_name, subresource in self.subresources.items():
sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
yield from subresource.iter_endpoints(sub_prefix)
_RESOURCES = {name: Resource.from_dict(data) for name, data in ALL_RESOURCES.items()} _RESOURCES = {name: Resource.from_dict(data) for name, data in ALL_RESOURCES.items()}
@ -700,8 +723,56 @@ class StainlessConfig:
paths: set[str] = set() paths: set[str] = set()
for resource in self.resources.values(): for resource in self.resources.values():
paths.update(resource.collect_endpoint_paths()) paths.update(resource.collect_endpoint_paths())
paths.update(self.readme_endpoint_paths())
return paths return paths
def readme_endpoint_paths(self) -> set[str]:
example_requests = self.readme.get("example_requests", {}) if self.readme else {}
paths: set[str] = set()
for entry in example_requests.values():
endpoint = entry.get("endpoint") if isinstance(entry, dict) else None
if isinstance(endpoint, str):
method, _, route = endpoint.partition(" ")
method = method.strip().lower()
route = route.strip()
if method and route:
paths.add(f"{method} {route}")
return paths
def endpoint_map(self) -> dict[str, list[str]]:
mapping: dict[str, list[str]] = {}
for resource_name, resource in self.resources.items():
for route, label in resource.iter_endpoints(resource_name):
mapping.setdefault(route, []).append(label)
return mapping
def validate_unique_endpoints(self) -> None:
duplicates: dict[str, list[str]] = {}
for route, labels in self.endpoint_map().items():
top_levels = {label.split(".", 1)[0] for label in labels}
if len(top_levels) > 1:
duplicates[route] = labels
if duplicates:
formatted = "\n".join(
f" - {route} defined in: {', '.join(sorted(labels))}"
for route, labels in sorted(duplicates.items())
)
raise ValueError("Duplicate endpoints found across resources:\n" + formatted)
def validate_readme_endpoints(self) -> None:
resource_paths: set[str] = set()
for resource in self.resources.values():
resource_paths.update(resource.collect_endpoint_paths())
missing = sorted(
path for path in self.readme_endpoint_paths() if path not in resource_paths
)
if missing:
formatted = "\n".join(f" - {path}" for path in missing)
raise ValueError(
"README example endpoints are not present in Stainless resources:\n"
+ formatted
)
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
cfg: dict[str, Any] = {} cfg: dict[str, Any] = {}
for section in SECTION_ORDER: for section in SECTION_ORDER:
@ -721,6 +792,12 @@ class StainlessConfig:
formatted = "\n".join(f" - {path}" for path in missing) formatted = "\n".join(f" - {path}" for path in missing)
raise ValueError("Stainless config references missing endpoints:\n" + formatted) raise ValueError("Stainless config references missing endpoints:\n" + formatted)
def validate(self, openapi_path: Path | None = None) -> None:
self.validate_unique_endpoints()
self.validate_readme_endpoints()
if openapi_path is not None:
self.validate_against_openapi(openapi_path)
def build_config() -> dict[str, Any]: def build_config() -> dict[str, Any]:
return StainlessConfig.make().to_dict() return StainlessConfig.make().to_dict()
@ -729,7 +806,7 @@ def build_config() -> dict[str, Any]:
def write_config(repo_root: Path, openapi_path: Path | None = None) -> Path: def write_config(repo_root: Path, openapi_path: Path | None = None) -> Path:
stainless_config = StainlessConfig.make() stainless_config = StainlessConfig.make()
spec_path = (openapi_path or (repo_root / "client-sdks" / "stainless" / "openapi.yml")).resolve() spec_path = (openapi_path or (repo_root / "client-sdks" / "stainless" / "openapi.yml")).resolve()
stainless_config.validate_against_openapi(spec_path) stainless_config.validate(spec_path)
yaml_text = yaml.safe_dump(stainless_config.to_dict(), sort_keys=False) yaml_text = yaml.safe_dump(stainless_config.to_dict(), sort_keys=False)
output = repo_root / "client-sdks" / "stainless" / "config.yml" output = repo_root / "client-sdks" / "stainless" / "config.yml"
output.write_text(HEADER + yaml_text) output.write_text(HEADER + yaml_text)