fix: return 4xx for non-existent resources in GET requests (#1635)

# What does this PR do?

- Removed Optional return types for GET methods
- Raised ValueError when requested resource is not found
- Ensures proper 4xx response for missing resources
- Updated the API generator to check for wrong signatures

```
$ uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
Validating API method return types...

API Method Return Type Validation Errors:

Method ScoringFunctions.get_scoring_function returns Optional type
```

Closes: https://github.com/meta-llama/llama-stack/issues/1630

## Test Plan

Run the server then:

```
curl http://127.0.0.1:8321/v1/models/foo     
{"detail":"Invalid value: Model 'foo' not found"}%  
```

Server log:

```
INFO:     127.0.0.1:52307 - "GET /v1/models/foo HTTP/1.1" 400 Bad Request
09:51:42.654 [END] /v1/models/foo [StatusCode.OK] (134.65ms)
 09:51:42.651 [ERROR] Error executing endpoint route='/v1/models/{model_id:path}' method='get'
Traceback (most recent call last):
  File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 193, in endpoint
    return await maybe_await(value)
  File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 156, in maybe_await
    return await value
  File "/Users/leseb/Documents/AI/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 102, in async_wrapper
    result = await method(self, *args, **kwargs)
  File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/routers/routing_tables.py", line 217, in get_model
    raise ValueError(f"Model '{model_id}' not found")
ValueError: Model 'foo' not found
```

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-03-18 22:06:53 +01:00 committed by GitHub
parent cca9bd6cc3
commit c029fbcd13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 112 additions and 136 deletions

View file

@ -1101,14 +1101,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/Benchmark" "$ref": "#/components/schemas/Benchmark"
},
{
"type": "null"
}
]
} }
} }
} }
@ -1150,14 +1143,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/Dataset" "$ref": "#/components/schemas/Dataset"
},
{
"type": "null"
}
]
} }
} }
} }
@ -1232,14 +1218,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/Model" "$ref": "#/components/schemas/Model"
},
{
"type": "null"
}
]
} }
} }
} }
@ -1314,14 +1293,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/ScoringFn" "$ref": "#/components/schemas/ScoringFn"
},
{
"type": "null"
}
]
} }
} }
} }
@ -1363,14 +1335,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/Shield" "$ref": "#/components/schemas/Shield"
},
{
"type": "null"
}
]
} }
} }
} }
@ -1673,14 +1638,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/PostTrainingJobArtifactsResponse" "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse"
},
{
"type": "null"
}
]
} }
} }
} }
@ -1722,14 +1680,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/PostTrainingJobStatusResponse" "$ref": "#/components/schemas/PostTrainingJobStatusResponse"
},
{
"type": "null"
}
]
} }
} }
} }
@ -1804,14 +1755,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/FileUploadResponse" "$ref": "#/components/schemas/FileUploadResponse"
},
{
"type": "null"
}
]
} }
} }
} }
@ -1913,14 +1857,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/VectorDB" "$ref": "#/components/schemas/VectorDB"
},
{
"type": "null"
}
]
} }
} }
} }
@ -2246,14 +2183,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"oneOf": [
{
"$ref": "#/components/schemas/JobStatus" "$ref": "#/components/schemas/JobStatus"
},
{
"type": "null"
}
]
} }
} }
} }

View file

@ -757,9 +757,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/Benchmark'
- $ref: '#/components/schemas/Benchmark'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -787,9 +785,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/Dataset'
- $ref: '#/components/schemas/Dataset'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -840,9 +836,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/Model'
- $ref: '#/components/schemas/Model'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -893,9 +887,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/ScoringFn'
- $ref: '#/components/schemas/ScoringFn'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -923,9 +915,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/Shield'
- $ref: '#/components/schemas/Shield'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1127,9 +1117,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/PostTrainingJobArtifactsResponse'
- $ref: '#/components/schemas/PostTrainingJobArtifactsResponse'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1157,9 +1145,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/PostTrainingJobStatusResponse'
- $ref: '#/components/schemas/PostTrainingJobStatusResponse'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1210,9 +1196,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/FileUploadResponse'
- $ref: '#/components/schemas/FileUploadResponse'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1281,9 +1265,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/VectorDB'
- $ref: '#/components/schemas/VectorDB'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1509,9 +1491,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
oneOf: $ref: '#/components/schemas/JobStatus'
- $ref: '#/components/schemas/JobStatus'
- type: 'null'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':

View file

@ -12,7 +12,7 @@
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
import sys
import fire import fire
import ruamel.yaml as yaml import ruamel.yaml as yaml
@ -21,7 +21,7 @@ from llama_stack.distribution.stack import LlamaStack # noqa: E402
from .pyopenapi.options import Options # noqa: E402 from .pyopenapi.options import Options # noqa: E402
from .pyopenapi.specification import Info, Server # noqa: E402 from .pyopenapi.specification import Info, Server # noqa: E402
from .pyopenapi.utility import Specification # noqa: E402 from .pyopenapi.utility import Specification, validate_api_method_return_types # noqa: E402
def str_presenter(dumper, data): def str_presenter(dumper, data):
@ -39,6 +39,14 @@ def main(output_dir: str):
if not output_dir.exists(): if not output_dir.exists():
raise ValueError(f"Directory {output_dir} does not exist") raise ValueError(f"Directory {output_dir} does not exist")
# Validate API protocols before generating spec
print("Validating API method return types...")
return_type_errors = validate_api_method_return_types()
if return_type_errors:
print("\nAPI Method Return Type Validation Errors:\n")
for error in return_type_errors:
print(error)
sys.exit(1)
now = str(datetime.now()) now = str(datetime.now())
print( print(
"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now "Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now

View file

@ -6,16 +6,19 @@
import json import json
import typing import typing
import inspect
import os
from pathlib import Path from pathlib import Path
from typing import TextIO from typing import TextIO
from typing import Any, Dict, List, Optional, Protocol, Type, Union, get_type_hints, get_origin, get_args
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
from llama_stack.distribution.resolver import api_protocol_map
from .generator import Generator from .generator import Generator
from .options import Options from .options import Options
from .specification import Document from .specification import Document
THIS_DIR = Path(__file__).parent THIS_DIR = Path(__file__).parent
@ -114,3 +117,37 @@ class Specification:
) )
f.write(html) f.write(html)
def is_optional_type(type_: Any) -> bool:
"""Check if a type is Optional."""
origin = get_origin(type_)
args = get_args(type_)
return origin is Optional or (origin is Union and type(None) in args)
def validate_api_method_return_types() -> List[str]:
"""Validate that all API methods have proper return types."""
errors = []
protocols = api_protocol_map()
for protocol_name, protocol in protocols.items():
methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for method_name, method in methods:
if not hasattr(method, '__webmethod__'):
continue
# Only check GET methods
if method.__webmethod__.method != "GET":
continue
hints = get_type_hints(method)
if 'return' not in hints:
errors.append(f"Method {protocol_name}.{method_name} has no return type annotation")
else:
return_type = hints['return']
if is_optional_type(return_type):
errors.append(f"Method {protocol_name}.{method_name} returns Optional type")
return errors

View file

@ -52,7 +52,7 @@ class Benchmarks(Protocol):
async def get_benchmark( async def get_benchmark(
self, self,
benchmark_id: str, benchmark_id: str,
) -> Optional[Benchmark]: ... ) -> Benchmark: ...
@webmethod(route="/eval/benchmarks", method="POST") @webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark( async def register_benchmark(

View file

@ -201,7 +201,7 @@ class Datasets(Protocol):
async def get_dataset( async def get_dataset(
self, self,
dataset_id: str, dataset_id: str,
) -> Optional[Dataset]: ... ) -> Dataset: ...
@webmethod(route="/datasets", method="GET") @webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ... async def list_datasets(self) -> ListDatasetsResponse: ...

View file

@ -117,7 +117,7 @@ class Eval(Protocol):
""" """
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
"""Get the status of a job. """Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on. :param benchmark_id: The ID of the benchmark to run the evaluation on.

View file

@ -115,7 +115,7 @@ class Files(Protocol):
async def get_upload_session_info( async def get_upload_session_info(
self, self,
upload_id: str, upload_id: str,
) -> Optional[FileUploadResponse]: ) -> FileUploadResponse:
""" """
Returns information about an existsing upload session Returns information about an existsing upload session

View file

@ -66,7 +66,7 @@ class Models(Protocol):
async def get_model( async def get_model(
self, self,
model_id: str, model_id: str,
) -> Optional[Model]: ... ) -> Model: ...
@webmethod(route="/models", method="POST") @webmethod(route="/models", method="POST")
async def register_model( async def register_model(

View file

@ -202,10 +202,10 @@ class PostTraining(Protocol):
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ... async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET") @webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ... async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel", method="POST") @webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ... async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET") @webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ... async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...

View file

@ -135,7 +135,7 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ... async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST") @webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function( async def register_scoring_function(

View file

@ -49,7 +49,7 @@ class Shields(Protocol):
async def list_shields(self) -> ListShieldsResponse: ... async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier:path}", method="GET") @webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ... async def get_shield(self, identifier: str) -> Shield: ...
@webmethod(route="/shields", method="POST") @webmethod(route="/shields", method="POST")
async def register_shield( async def register_shield(

View file

@ -50,7 +50,7 @@ class VectorDBs(Protocol):
async def get_vector_db( async def get_vector_db(
self, self,
vector_db_id: str, vector_db_id: str,
) -> Optional[VectorDB]: ... ) -> VectorDB: ...
@webmethod(route="/vector-dbs", method="POST") @webmethod(route="/vector-dbs", method="POST")
async def register_vector_db( async def register_vector_db(

View file

@ -219,8 +219,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse: async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model")) return ListModelsResponse(data=await self.get_all_with_type("model"))
async def get_model(self, model_id: str) -> Optional[Model]: async def get_model(self, model_id: str) -> Model:
return await self.get_object_by_identifier("model", model_id) model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
async def register_model( async def register_model(
self, self,
@ -267,8 +270,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse: async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Optional[Shield]: async def get_shield(self, identifier: str) -> Shield:
return await self.get_object_by_identifier("shield", identifier) shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield( async def register_shield(
self, self,
@ -303,8 +309,11 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse: async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: async def get_vector_db(self, vector_db_id: str) -> VectorDB:
return await self.get_object_by_identifier("vector_db", vector_db_id) vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise ValueError(f"Vector DB '{vector_db_id}' not found")
return vector_db
async def register_vector_db( async def register_vector_db(
self, self,
@ -355,8 +364,11 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse: async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: async def get_dataset(self, dataset_id: str) -> Dataset:
return await self.get_object_by_identifier("dataset", dataset_id) dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise ValueError(f"Dataset '{dataset_id}' not found")
return dataset
async def register_dataset( async def register_dataset(
self, self,
@ -408,8 +420,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id) scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function( async def register_scoring_function(
self, self,
@ -445,8 +460,11 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse: async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]: async def get_benchmark(self, benchmark_id: str) -> Benchmark:
return await self.get_object_by_identifier("benchmark", benchmark_id) benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark( async def register_benchmark(
self, self,
@ -490,7 +508,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id) tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group '{toolgroup_id}' not found")
return tool_group
async def get_tool(self, tool_name: str) -> Tool: async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name) return await self.get_object_by_identifier("tool", tool_name)