fix: return 4xx for non-existent resources in GET requests (#1635)

# What does this PR do?

- Removed Optional return types for GET methods
- Raised ValueError when requested resource is not found
- Ensures proper 4xx response for missing resources
- Updated the API generator to check for wrong signatures

```
$ uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
Validating API method return types...

API Method Return Type Validation Errors:

Method ScoringFunctions.get_scoring_function returns Optional type
```

Closes: https://github.com/meta-llama/llama-stack/issues/1630

## Test Plan

Run the server then:

```
curl http://127.0.0.1:8321/v1/models/foo     
{"detail":"Invalid value: Model 'foo' not found"}%  
```

Server log:

```
INFO:     127.0.0.1:52307 - "GET /v1/models/foo HTTP/1.1" 400 Bad Request
09:51:42.654 [END] /v1/models/foo [StatusCode.OK] (134.65ms)
 09:51:42.651 [ERROR] Error executing endpoint route='/v1/models/{model_id:path}' method='get'
Traceback (most recent call last):
  File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 193, in endpoint
    return await maybe_await(value)
  File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 156, in maybe_await
    return await value
  File "/Users/leseb/Documents/AI/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 102, in async_wrapper
    result = await method(self, *args, **kwargs)
  File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/routers/routing_tables.py", line 217, in get_model
    raise ValueError(f"Model '{model_id}' not found")
ValueError: Model 'foo' not found
```

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-03-18 22:06:53 +01:00 committed by GitHub
parent cca9bd6cc3
commit c029fbcd13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 112 additions and 136 deletions

View file

@ -1101,14 +1101,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/Benchmark"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/Benchmark"
}
}
}
@ -1150,14 +1143,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/Dataset"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/Dataset"
}
}
}
@ -1232,14 +1218,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/Model"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/Model"
}
}
}
@ -1314,14 +1293,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/ScoringFn"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/ScoringFn"
}
}
}
@ -1363,14 +1335,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/Shield"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/Shield"
}
}
}
@ -1673,14 +1638,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/PostTrainingJobArtifactsResponse"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/PostTrainingJobArtifactsResponse"
}
}
}
@ -1722,14 +1680,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/PostTrainingJobStatusResponse"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/PostTrainingJobStatusResponse"
}
}
}
@ -1804,14 +1755,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/FileUploadResponse"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/FileUploadResponse"
}
}
}
@ -1913,14 +1857,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/VectorDB"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/VectorDB"
}
}
}
@ -2246,14 +2183,7 @@
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"$ref": "#/components/schemas/JobStatus"
},
{
"type": "null"
}
]
"$ref": "#/components/schemas/JobStatus"
}
}
}

View file

@ -757,9 +757,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/Benchmark'
- type: 'null'
$ref: '#/components/schemas/Benchmark'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -787,9 +785,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/Dataset'
- type: 'null'
$ref: '#/components/schemas/Dataset'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -840,9 +836,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/Model'
- type: 'null'
$ref: '#/components/schemas/Model'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -893,9 +887,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/ScoringFn'
- type: 'null'
$ref: '#/components/schemas/ScoringFn'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -923,9 +915,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/Shield'
- type: 'null'
$ref: '#/components/schemas/Shield'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -1127,9 +1117,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/PostTrainingJobArtifactsResponse'
- type: 'null'
$ref: '#/components/schemas/PostTrainingJobArtifactsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -1157,9 +1145,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/PostTrainingJobStatusResponse'
- type: 'null'
$ref: '#/components/schemas/PostTrainingJobStatusResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -1210,9 +1196,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/FileUploadResponse'
- type: 'null'
$ref: '#/components/schemas/FileUploadResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -1281,9 +1265,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/VectorDB'
- type: 'null'
$ref: '#/components/schemas/VectorDB'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -1509,9 +1491,7 @@ paths:
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/JobStatus'
- type: 'null'
$ref: '#/components/schemas/JobStatus'
'400':
$ref: '#/components/responses/BadRequest400'
'429':

View file

@ -12,7 +12,7 @@
from datetime import datetime
from pathlib import Path
import sys
import fire
import ruamel.yaml as yaml
@ -21,7 +21,7 @@ from llama_stack.distribution.stack import LlamaStack # noqa: E402
from .pyopenapi.options import Options # noqa: E402
from .pyopenapi.specification import Info, Server # noqa: E402
from .pyopenapi.utility import Specification # noqa: E402
from .pyopenapi.utility import Specification, validate_api_method_return_types # noqa: E402
def str_presenter(dumper, data):
@ -39,6 +39,14 @@ def main(output_dir: str):
if not output_dir.exists():
raise ValueError(f"Directory {output_dir} does not exist")
# Validate API protocols before generating spec
print("Validating API method return types...")
return_type_errors = validate_api_method_return_types()
if return_type_errors:
print("\nAPI Method Return Type Validation Errors:\n")
for error in return_type_errors:
print(error)
sys.exit(1)
now = str(datetime.now())
print(
"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now

View file

@ -6,16 +6,19 @@
import json
import typing
import inspect
import os
from pathlib import Path
from typing import TextIO
from typing import Any, Dict, List, Optional, Protocol, Type, Union, get_type_hints, get_origin, get_args
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
from llama_stack.distribution.resolver import api_protocol_map
from .generator import Generator
from .options import Options
from .specification import Document
THIS_DIR = Path(__file__).parent
@ -114,3 +117,37 @@ class Specification:
)
f.write(html)
def is_optional_type(type_: Any) -> bool:
"""Check if a type is Optional."""
origin = get_origin(type_)
args = get_args(type_)
return origin is Optional or (origin is Union and type(None) in args)
def validate_api_method_return_types() -> List[str]:
"""Validate that all API methods have proper return types."""
errors = []
protocols = api_protocol_map()
for protocol_name, protocol in protocols.items():
methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for method_name, method in methods:
if not hasattr(method, '__webmethod__'):
continue
# Only check GET methods
if method.__webmethod__.method != "GET":
continue
hints = get_type_hints(method)
if 'return' not in hints:
errors.append(f"Method {protocol_name}.{method_name} has no return type annotation")
else:
return_type = hints['return']
if is_optional_type(return_type):
errors.append(f"Method {protocol_name}.{method_name} returns Optional type")
return errors

View file

@ -52,7 +52,7 @@ class Benchmarks(Protocol):
async def get_benchmark(
self,
benchmark_id: str,
) -> Optional[Benchmark]: ...
) -> Benchmark: ...
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(

View file

@ -201,7 +201,7 @@ class Datasets(Protocol):
async def get_dataset(
self,
dataset_id: str,
) -> Optional[Dataset]: ...
) -> Dataset: ...
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...

View file

@ -117,7 +117,7 @@ class Eval(Protocol):
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.

View file

@ -115,7 +115,7 @@ class Files(Protocol):
async def get_upload_session_info(
self,
upload_id: str,
) -> Optional[FileUploadResponse]:
) -> FileUploadResponse:
"""
Returns information about an existsing upload session

View file

@ -66,7 +66,7 @@ class Models(Protocol):
async def get_model(
self,
model_id: str,
) -> Optional[Model]: ...
) -> Model: ...
@webmethod(route="/models", method="POST")
async def register_model(

View file

@ -202,10 +202,10 @@ class PostTraining(Protocol):
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...

View file

@ -135,7 +135,7 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(

View file

@ -49,7 +49,7 @@ class Shields(Protocol):
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
async def get_shield(self, identifier: str) -> Shield: ...
@webmethod(route="/shields", method="POST")
async def register_shield(

View file

@ -50,7 +50,7 @@ class VectorDBs(Protocol):
async def get_vector_db(
self,
vector_db_id: str,
) -> Optional[VectorDB]: ...
) -> VectorDB: ...
@webmethod(route="/vector-dbs", method="POST")
async def register_vector_db(

View file

@ -219,8 +219,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def get_model(self, model_id: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", model_id)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
async def register_model(
self,
@ -267,8 +270,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield(
self,
@ -303,8 +309,11 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
return await self.get_object_by_identifier("vector_db", vector_db_id)
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise ValueError(f"Vector DB '{vector_db_id}' not found")
return vector_db
async def register_vector_db(
self,
@ -355,8 +364,11 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise ValueError(f"Dataset '{dataset_id}' not found")
return dataset
async def register_dataset(
self,
@ -408,8 +420,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function(
self,
@ -445,8 +460,11 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
return await self.get_object_by_identifier("benchmark", benchmark_id)
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark(
self,
@ -490,7 +508,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group '{toolgroup_id}' not found")
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)