From c029fbcd13ff270888f3e34e5369fb9d750821d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 18 Mar 2025 22:06:53 +0100 Subject: [PATCH] fix: return 4xx for non-existent resources in GET requests (#1635) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Removed Optional return types for GET methods - Raised ValueError when requested resource is not found - Ensures proper 4xx response for missing resources - Updated the API generator to check for wrong signatures ``` $ uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh Validating API method return types... API Method Return Type Validation Errors: Method ScoringFunctions.get_scoring_function returns Optional type ``` Closes: https://github.com/meta-llama/llama-stack/issues/1630 ## Test Plan Run the server then: ``` curl http://127.0.0.1:8321/v1/models/foo {"detail":"Invalid value: Model 'foo' not found"}% ``` Server log: ``` INFO: 127.0.0.1:52307 - "GET /v1/models/foo HTTP/1.1" 400 Bad Request 09:51:42.654 [END] /v1/models/foo [StatusCode.OK] (134.65ms) 09:51:42.651 [ERROR] Error executing endpoint route='/v1/models/{model_id:path}' method='get' Traceback (most recent call last): File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 193, in endpoint return await maybe_await(value) File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 156, in maybe_await return await value File "/Users/leseb/Documents/AI/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 102, in async_wrapper result = await method(self, *args, **kwargs) File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/routers/routing_tables.py", line 217, in get_model raise ValueError(f"Model '{model_id}' not found") ValueError: Model 'foo' not found ``` Signed-off-by: Sébastien Han --- docs/_static/llama-stack-spec.html | 90 +++---------------- docs/_static/llama-stack-spec.yaml | 40 +++------ docs/openapi_generator/generate.py | 12 ++- docs/openapi_generator/pyopenapi/utility.py | 39 +++++++- llama_stack/apis/benchmarks/benchmarks.py | 2 +- llama_stack/apis/datasets/datasets.py | 2 +- llama_stack/apis/eval/eval.py | 2 +- llama_stack/apis/files/files.py | 2 +- llama_stack/apis/models/models.py | 2 +- .../apis/post_training/post_training.py | 4 +- .../scoring_functions/scoring_functions.py | 2 +- llama_stack/apis/shields/shields.py | 2 +- llama_stack/apis/vector_dbs/vector_dbs.py | 2 +- .../distribution/routers/routing_tables.py | 47 +++++++--- 14 files changed, 112 insertions(+), 136 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 210a84b03..72b2e6b17 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -1101,14 +1101,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/Benchmark" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/Benchmark" } } } @@ -1150,14 +1143,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/Dataset" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/Dataset" } } } @@ -1232,14 +1218,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/Model" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/Model" } } } @@ -1314,14 +1293,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/ScoringFn" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/ScoringFn" } } } @@ -1363,14 +1335,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/Shield" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/Shield" } } } @@ -1673,14 +1638,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse" } } } @@ -1722,14 +1680,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/PostTrainingJobStatusResponse" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/PostTrainingJobStatusResponse" } } } @@ -1804,14 +1755,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/FileUploadResponse" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/FileUploadResponse" } } } @@ -1913,14 +1857,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/VectorDB" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/VectorDB" } } } @@ -2246,14 +2183,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/JobStatus" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/JobStatus" } } } diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index a1eb07444..6f4a9528b 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -757,9 +757,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/Benchmark' - - type: 'null' + $ref: '#/components/schemas/Benchmark' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -787,9 +785,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/Dataset' - - type: 'null' + $ref: '#/components/schemas/Dataset' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -840,9 +836,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/Model' - - type: 'null' + $ref: '#/components/schemas/Model' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -893,9 +887,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/ScoringFn' - - type: 'null' + $ref: '#/components/schemas/ScoringFn' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -923,9 +915,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/Shield' - - type: 'null' + $ref: '#/components/schemas/Shield' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1127,9 +1117,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/PostTrainingJobArtifactsResponse' - - type: 'null' + $ref: '#/components/schemas/PostTrainingJobArtifactsResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1157,9 +1145,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/PostTrainingJobStatusResponse' - - type: 'null' + $ref: '#/components/schemas/PostTrainingJobStatusResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1210,9 +1196,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/FileUploadResponse' - - type: 'null' + $ref: '#/components/schemas/FileUploadResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1281,9 +1265,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/VectorDB' - - type: 'null' + $ref: '#/components/schemas/VectorDB' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1509,9 +1491,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/JobStatus' - - type: 'null' + $ref: '#/components/schemas/JobStatus' '400': $ref: '#/components/responses/BadRequest400' '429': diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index a2553f905..879ac95e2 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -12,7 +12,7 @@ from datetime import datetime from pathlib import Path - +import sys import fire import ruamel.yaml as yaml @@ -21,7 +21,7 @@ from llama_stack.distribution.stack import LlamaStack # noqa: E402 from .pyopenapi.options import Options # noqa: E402 from .pyopenapi.specification import Info, Server # noqa: E402 -from .pyopenapi.utility import Specification # noqa: E402 +from .pyopenapi.utility import Specification, validate_api_method_return_types # noqa: E402 def str_presenter(dumper, data): @@ -39,6 +39,14 @@ def main(output_dir: str): if not output_dir.exists(): raise ValueError(f"Directory {output_dir} does not exist") + # Validate API protocols before generating spec + print("Validating API method return types...") + return_type_errors = validate_api_method_return_types() + if return_type_errors: + print("\nAPI Method Return Type Validation Errors:\n") + for error in return_type_errors: + print(error) + sys.exit(1) now = str(datetime.now()) print( "Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now diff --git a/docs/openapi_generator/pyopenapi/utility.py b/docs/openapi_generator/pyopenapi/utility.py index f134aab4b..f60a33bb7 100644 --- a/docs/openapi_generator/pyopenapi/utility.py +++ b/docs/openapi_generator/pyopenapi/utility.py @@ -6,16 +6,19 @@ import json import typing +import inspect +import os from pathlib import Path from typing import TextIO +from typing import Any, Dict, List, Optional, Protocol, Type, Union, get_type_hints, get_origin, get_args from llama_stack.strong_typing.schema import object_to_json, StrictJsonType +from llama_stack.distribution.resolver import api_protocol_map from .generator import Generator from .options import Options from .specification import Document - THIS_DIR = Path(__file__).parent @@ -114,3 +117,37 @@ class Specification: ) f.write(html) + +def is_optional_type(type_: Any) -> bool: + """Check if a type is Optional.""" + origin = get_origin(type_) + args = get_args(type_) + return origin is Optional or (origin is Union and type(None) in args) + + +def validate_api_method_return_types() -> List[str]: + """Validate that all API methods have proper return types.""" + errors = [] + protocols = api_protocol_map() + + for protocol_name, protocol in protocols.items(): + methods = inspect.getmembers(protocol, predicate=inspect.isfunction) + + for method_name, method in methods: + if not hasattr(method, '__webmethod__'): + continue + + # Only check GET methods + if method.__webmethod__.method != "GET": + continue + + hints = get_type_hints(method) + + if 'return' not in hints: + errors.append(f"Method {protocol_name}.{method_name} has no return type annotation") + else: + return_type = hints['return'] + if is_optional_type(return_type): + errors.append(f"Method {protocol_name}.{method_name} returns Optional type") + + return errors diff --git a/llama_stack/apis/benchmarks/benchmarks.py b/llama_stack/apis/benchmarks/benchmarks.py index 39ba355e9..809af8868 100644 --- a/llama_stack/apis/benchmarks/benchmarks.py +++ b/llama_stack/apis/benchmarks/benchmarks.py @@ -52,7 +52,7 @@ class Benchmarks(Protocol): async def get_benchmark( self, benchmark_id: str, - ) -> Optional[Benchmark]: ... + ) -> Benchmark: ... @webmethod(route="/eval/benchmarks", method="POST") async def register_benchmark( diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index d033d0b70..616371c7d 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -201,7 +201,7 @@ class Datasets(Protocol): async def get_dataset( self, dataset_id: str, - ) -> Optional[Dataset]: ... + ) -> Dataset: ... @webmethod(route="/datasets", method="GET") async def list_datasets(self) -> ListDatasetsResponse: ... diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index dec018d83..51c38b16a 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -117,7 +117,7 @@ class Eval(Protocol): """ @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") - async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: + async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus: """Get the status of a job. :param benchmark_id: The ID of the benchmark to run the evaluation on. diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index f17fadc8c..65c1ead6a 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -115,7 +115,7 @@ class Files(Protocol): async def get_upload_session_info( self, upload_id: str, - ) -> Optional[FileUploadResponse]: + ) -> FileUploadResponse: """ Returns information about an existsing upload session diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 64b9510ea..893ebc179 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -66,7 +66,7 @@ class Models(Protocol): async def get_model( self, model_id: str, - ) -> Optional[Model]: ... + ) -> Model: ... @webmethod(route="/models", method="POST") async def register_model( diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index ed15c6de4..636eb7e7b 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -202,10 +202,10 @@ class PostTraining(Protocol): async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ... @webmethod(route="/post-training/job/status", method="GET") - async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ... + async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ... @webmethod(route="/post-training/job/cancel", method="POST") async def cancel_training_job(self, job_uuid: str) -> None: ... @webmethod(route="/post-training/job/artifacts", method="GET") - async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ... + async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ... diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 52508d2ec..b02a7a0c4 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -135,7 +135,7 @@ class ScoringFunctions(Protocol): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") - async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ... + async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ... @webmethod(route="/scoring-functions", method="POST") async def register_scoring_function( diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index ec1179ac4..67f3bd27b 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -49,7 +49,7 @@ class Shields(Protocol): async def list_shields(self) -> ListShieldsResponse: ... @webmethod(route="/shields/{identifier:path}", method="GET") - async def get_shield(self, identifier: str) -> Optional[Shield]: ... + async def get_shield(self, identifier: str) -> Shield: ... @webmethod(route="/shields", method="POST") async def register_shield( diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 9a4aa322f..fe6c33919 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -50,7 +50,7 @@ class VectorDBs(Protocol): async def get_vector_db( self, vector_db_id: str, - ) -> Optional[VectorDB]: ... + ) -> VectorDB: ... @webmethod(route="/vector-dbs", method="POST") async def register_vector_db( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 533993421..5dea942f7 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -219,8 +219,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> ListModelsResponse: return ListModelsResponse(data=await self.get_all_with_type("model")) - async def get_model(self, model_id: str) -> Optional[Model]: - return await self.get_object_by_identifier("model", model_id) + async def get_model(self, model_id: str) -> Model: + model = await self.get_object_by_identifier("model", model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + return model async def register_model( self, @@ -267,8 +270,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> ListShieldsResponse: return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) - async def get_shield(self, identifier: str) -> Optional[Shield]: - return await self.get_object_by_identifier("shield", identifier) + async def get_shield(self, identifier: str) -> Shield: + shield = await self.get_object_by_identifier("shield", identifier) + if shield is None: + raise ValueError(f"Shield '{identifier}' not found") + return shield async def register_shield( self, @@ -303,8 +309,11 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): async def list_vector_dbs(self) -> ListVectorDBsResponse: return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) - async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: - return await self.get_object_by_identifier("vector_db", vector_db_id) + async def get_vector_db(self, vector_db_id: str) -> VectorDB: + vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) + if vector_db is None: + raise ValueError(f"Vector DB '{vector_db_id}' not found") + return vector_db async def register_vector_db( self, @@ -355,8 +364,11 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) - async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: - return await self.get_object_by_identifier("dataset", dataset_id) + async def get_dataset(self, dataset_id: str) -> Dataset: + dataset = await self.get_object_by_identifier("dataset", dataset_id) + if dataset is None: + raise ValueError(f"Dataset '{dataset_id}' not found") + return dataset async def register_dataset( self, @@ -408,8 +420,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) - async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: - return await self.get_object_by_identifier("scoring_function", scoring_fn_id) + async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: + scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) + if scoring_fn is None: + raise ValueError(f"Scoring function '{scoring_fn_id}' not found") + return scoring_fn async def register_scoring_function( self, @@ -445,8 +460,11 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): async def list_benchmarks(self) -> ListBenchmarksResponse: return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) - async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]: - return await self.get_object_by_identifier("benchmark", benchmark_id) + async def get_benchmark(self, benchmark_id: str) -> Benchmark: + benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) + if benchmark is None: + raise ValueError(f"Benchmark '{benchmark_id}' not found") + return benchmark async def register_benchmark( self, @@ -490,7 +508,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: - return await self.get_object_by_identifier("tool_group", toolgroup_id) + tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) + if tool_group is None: + raise ValueError(f"Tool group '{toolgroup_id}' not found") + return tool_group async def get_tool(self, tool_name: str) -> Tool: return await self.get_object_by_identifier("tool", tool_name)