From b40a0c5151bc5d2b17aa143912edc34a37e1c3fd Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 14 Nov 2025 20:04:49 -0800 Subject: [PATCH] feat(generator): tighten stainless config validation --- .../stainless_config/generate_config.py | 89 +++++++++++++++++-- 1 file changed, 83 insertions(+), 6 deletions(-) diff --git a/scripts/openapi_generator/stainless_config/generate_config.py b/scripts/openapi_generator/stainless_config/generate_config.py index cf55536d0..b0d806f6f 100755 --- a/scripts/openapi_generator/stainless_config/generate_config.py +++ b/scripts/openapi_generator/stainless_config/generate_config.py @@ -5,7 +5,7 @@ from __future__ import annotations import argparse from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Iterator import yaml @@ -592,6 +592,9 @@ ALL_RESOURCES = { } +HTTP_METHODS = {"get", "post", "put", "patch", "delete", "options", "head"} + + @dataclass class Endpoint: method: str @@ -599,16 +602,28 @@ class Endpoint: extra: dict[str, Any] = field(default_factory=dict) @classmethod - def from_config(cls, value: Any) -> Endpoint: + def from_config(cls, value: Any) -> "Endpoint": if isinstance(value, str): method, _, path = value.partition(" ") - return cls(method, path) + return cls._from_parts(method, path) if isinstance(value, dict) and "endpoint" in value: method, _, path = value["endpoint"].partition(" ") extra = {k: v for k, v in value.items() if k != "endpoint"} - return cls(method, path, extra) + endpoint = cls._from_parts(method, path) + endpoint.extra.update(extra) + return endpoint raise ValueError(f"Unsupported endpoint value: {value!r}") + @classmethod + def _from_parts(cls, method: str, path: str) -> "Endpoint": + method = method.strip().lower() + path = path.strip() + if method not in HTTP_METHODS: + raise ValueError(f"Unsupported HTTP method for Stainless config: {method!r}") + if not path.startswith("/"): + raise ValueError(f"Endpoint path must start with '/': {path!r}") + return cls(method=method, path=path) + def to_config(self) -> Any: if not self.extra: return f"{self.method} {self.path}" @@ -617,7 +632,7 @@ class Endpoint: return data def route_key(self) -> str: - return f"{self.method.lower()} {self.path}" + return f"{self.method} {self.path}" @dataclass @@ -649,6 +664,14 @@ class Resource: paths.update(subresource.collect_endpoint_paths()) return paths + def iter_endpoints(self, prefix: str) -> Iterator[tuple[str, str]]: + for method_name, endpoint in self.methods.items(): + label = f"{prefix}.{method_name}" if prefix else method_name + yield endpoint.route_key(), label + for sub_name, subresource in self.subresources.items(): + sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name + yield from subresource.iter_endpoints(sub_prefix) + _RESOURCES = {name: Resource.from_dict(data) for name, data in ALL_RESOURCES.items()} @@ -700,8 +723,56 @@ class StainlessConfig: paths: set[str] = set() for resource in self.resources.values(): paths.update(resource.collect_endpoint_paths()) + paths.update(self.readme_endpoint_paths()) return paths + def readme_endpoint_paths(self) -> set[str]: + example_requests = self.readme.get("example_requests", {}) if self.readme else {} + paths: set[str] = set() + for entry in example_requests.values(): + endpoint = entry.get("endpoint") if isinstance(entry, dict) else None + if isinstance(endpoint, str): + method, _, route = endpoint.partition(" ") + method = method.strip().lower() + route = route.strip() + if method and route: + paths.add(f"{method} {route}") + return paths + + def endpoint_map(self) -> dict[str, list[str]]: + mapping: dict[str, list[str]] = {} + for resource_name, resource in self.resources.items(): + for route, label in resource.iter_endpoints(resource_name): + mapping.setdefault(route, []).append(label) + return mapping + + def validate_unique_endpoints(self) -> None: + duplicates: dict[str, list[str]] = {} + for route, labels in self.endpoint_map().items(): + top_levels = {label.split(".", 1)[0] for label in labels} + if len(top_levels) > 1: + duplicates[route] = labels + if duplicates: + formatted = "\n".join( + f" - {route} defined in: {', '.join(sorted(labels))}" + for route, labels in sorted(duplicates.items()) + ) + raise ValueError("Duplicate endpoints found across resources:\n" + formatted) + + def validate_readme_endpoints(self) -> None: + resource_paths: set[str] = set() + for resource in self.resources.values(): + resource_paths.update(resource.collect_endpoint_paths()) + missing = sorted( + path for path in self.readme_endpoint_paths() if path not in resource_paths + ) + if missing: + formatted = "\n".join(f" - {path}" for path in missing) + raise ValueError( + "README example endpoints are not present in Stainless resources:\n" + + formatted + ) + def to_dict(self) -> dict[str, Any]: cfg: dict[str, Any] = {} for section in SECTION_ORDER: @@ -721,6 +792,12 @@ class StainlessConfig: formatted = "\n".join(f" - {path}" for path in missing) raise ValueError("Stainless config references missing endpoints:\n" + formatted) + def validate(self, openapi_path: Path | None = None) -> None: + self.validate_unique_endpoints() + self.validate_readme_endpoints() + if openapi_path is not None: + self.validate_against_openapi(openapi_path) + def build_config() -> dict[str, Any]: return StainlessConfig.make().to_dict() @@ -729,7 +806,7 @@ def build_config() -> dict[str, Any]: def write_config(repo_root: Path, openapi_path: Path | None = None) -> Path: stainless_config = StainlessConfig.make() spec_path = (openapi_path or (repo_root / "client-sdks" / "stainless" / "openapi.yml")).resolve() - stainless_config.validate_against_openapi(spec_path) + stainless_config.validate(spec_path) yaml_text = yaml.safe_dump(stainless_config.to_dict(), sort_keys=False) output = repo_root / "client-sdks" / "stainless" / "config.yml" output.write_text(HEADER + yaml_text)