diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 210a84b03..72b2e6b17 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -1101,14 +1101,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/Benchmark"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/Benchmark"
}
}
}
@@ -1150,14 +1143,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/Dataset"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/Dataset"
}
}
}
@@ -1232,14 +1218,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/Model"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/Model"
}
}
}
@@ -1314,14 +1293,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/ScoringFn"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/ScoringFn"
}
}
}
@@ -1363,14 +1335,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/Shield"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/Shield"
}
}
}
@@ -1673,14 +1638,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse"
}
}
}
@@ -1722,14 +1680,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/PostTrainingJobStatusResponse"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/PostTrainingJobStatusResponse"
}
}
}
@@ -1804,14 +1755,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/FileUploadResponse"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/FileUploadResponse"
}
}
}
@@ -1913,14 +1857,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/VectorDB"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/VectorDB"
}
}
}
@@ -2246,14 +2183,7 @@
"content": {
"application/json": {
"schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/JobStatus"
- },
- {
- "type": "null"
- }
- ]
+ "$ref": "#/components/schemas/JobStatus"
}
}
}
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index a1eb07444..6f4a9528b 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -757,9 +757,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/Benchmark'
- - type: 'null'
+ $ref: '#/components/schemas/Benchmark'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -787,9 +785,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/Dataset'
- - type: 'null'
+ $ref: '#/components/schemas/Dataset'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -840,9 +836,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/Model'
- - type: 'null'
+ $ref: '#/components/schemas/Model'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -893,9 +887,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/ScoringFn'
- - type: 'null'
+ $ref: '#/components/schemas/ScoringFn'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -923,9 +915,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/Shield'
- - type: 'null'
+ $ref: '#/components/schemas/Shield'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -1127,9 +1117,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/PostTrainingJobArtifactsResponse'
- - type: 'null'
+ $ref: '#/components/schemas/PostTrainingJobArtifactsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -1157,9 +1145,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/PostTrainingJobStatusResponse'
- - type: 'null'
+ $ref: '#/components/schemas/PostTrainingJobStatusResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -1210,9 +1196,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/FileUploadResponse'
- - type: 'null'
+ $ref: '#/components/schemas/FileUploadResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -1281,9 +1265,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/VectorDB'
- - type: 'null'
+ $ref: '#/components/schemas/VectorDB'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -1509,9 +1491,7 @@ paths:
content:
application/json:
schema:
- oneOf:
- - $ref: '#/components/schemas/JobStatus'
- - type: 'null'
+ $ref: '#/components/schemas/JobStatus'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py
index a2553f905..879ac95e2 100644
--- a/docs/openapi_generator/generate.py
+++ b/docs/openapi_generator/generate.py
@@ -12,7 +12,7 @@
from datetime import datetime
from pathlib import Path
-
+import sys
import fire
import ruamel.yaml as yaml
@@ -21,7 +21,7 @@ from llama_stack.distribution.stack import LlamaStack # noqa: E402
from .pyopenapi.options import Options # noqa: E402
from .pyopenapi.specification import Info, Server # noqa: E402
-from .pyopenapi.utility import Specification # noqa: E402
+from .pyopenapi.utility import Specification, validate_api_method_return_types # noqa: E402
def str_presenter(dumper, data):
@@ -39,6 +39,14 @@ def main(output_dir: str):
if not output_dir.exists():
raise ValueError(f"Directory {output_dir} does not exist")
+ # Validate API protocols before generating spec
+ print("Validating API method return types...")
+ return_type_errors = validate_api_method_return_types()
+ if return_type_errors:
+ print("\nAPI Method Return Type Validation Errors:\n")
+ for error in return_type_errors:
+ print(error)
+ sys.exit(1)
now = str(datetime.now())
print(
"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now
diff --git a/docs/openapi_generator/pyopenapi/utility.py b/docs/openapi_generator/pyopenapi/utility.py
index f134aab4b..f60a33bb7 100644
--- a/docs/openapi_generator/pyopenapi/utility.py
+++ b/docs/openapi_generator/pyopenapi/utility.py
@@ -6,16 +6,19 @@
import json
import typing
+import inspect
+import os
from pathlib import Path
from typing import TextIO
+from typing import Any, Dict, List, Optional, Protocol, Type, Union, get_type_hints, get_origin, get_args
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
+from llama_stack.distribution.resolver import api_protocol_map
from .generator import Generator
from .options import Options
from .specification import Document
-
THIS_DIR = Path(__file__).parent
@@ -114,3 +117,37 @@ class Specification:
)
f.write(html)
+
+def is_optional_type(type_: Any) -> bool:
+ """Check if a type is Optional."""
+ origin = get_origin(type_)
+ args = get_args(type_)
+ return origin is Optional or (origin is Union and type(None) in args)
+
+
+def validate_api_method_return_types() -> List[str]:
+ """Validate that all API methods have proper return types."""
+ errors = []
+ protocols = api_protocol_map()
+
+ for protocol_name, protocol in protocols.items():
+ methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
+
+ for method_name, method in methods:
+ if not hasattr(method, '__webmethod__'):
+ continue
+
+ # Only check GET methods
+ if method.__webmethod__.method != "GET":
+ continue
+
+ hints = get_type_hints(method)
+
+ if 'return' not in hints:
+ errors.append(f"Method {protocol_name}.{method_name} has no return type annotation")
+ else:
+ return_type = hints['return']
+ if is_optional_type(return_type):
+ errors.append(f"Method {protocol_name}.{method_name} returns Optional type")
+
+ return errors
diff --git a/llama_stack/apis/benchmarks/benchmarks.py b/llama_stack/apis/benchmarks/benchmarks.py
index 39ba355e9..809af8868 100644
--- a/llama_stack/apis/benchmarks/benchmarks.py
+++ b/llama_stack/apis/benchmarks/benchmarks.py
@@ -52,7 +52,7 @@ class Benchmarks(Protocol):
async def get_benchmark(
self,
benchmark_id: str,
- ) -> Optional[Benchmark]: ...
+ ) -> Benchmark: ...
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(
diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
index d033d0b70..616371c7d 100644
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ -201,7 +201,7 @@ class Datasets(Protocol):
async def get_dataset(
self,
dataset_id: str,
- ) -> Optional[Dataset]: ...
+ ) -> Dataset: ...
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index dec018d83..51c38b16a 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -117,7 +117,7 @@ class Eval(Protocol):
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
- async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
+ async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py
index f17fadc8c..65c1ead6a 100644
--- a/llama_stack/apis/files/files.py
+++ b/llama_stack/apis/files/files.py
@@ -115,7 +115,7 @@ class Files(Protocol):
async def get_upload_session_info(
self,
upload_id: str,
- ) -> Optional[FileUploadResponse]:
+ ) -> FileUploadResponse:
"""
Returns information about an existsing upload session
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index 64b9510ea..893ebc179 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -66,7 +66,7 @@ class Models(Protocol):
async def get_model(
self,
model_id: str,
- ) -> Optional[Model]: ...
+ ) -> Model: ...
@webmethod(route="/models", method="POST")
async def register_model(
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index ed15c6de4..636eb7e7b 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -202,10 +202,10 @@ class PostTraining(Protocol):
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
- async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
+ async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET")
- async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
+ async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index 52508d2ec..b02a7a0c4 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -135,7 +135,7 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
- async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
+ async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
index ec1179ac4..67f3bd27b 100644
--- a/llama_stack/apis/shields/shields.py
+++ b/llama_stack/apis/shields/shields.py
@@ -49,7 +49,7 @@ class Shields(Protocol):
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier:path}", method="GET")
- async def get_shield(self, identifier: str) -> Optional[Shield]: ...
+ async def get_shield(self, identifier: str) -> Shield: ...
@webmethod(route="/shields", method="POST")
async def register_shield(
diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py
index 9a4aa322f..fe6c33919 100644
--- a/llama_stack/apis/vector_dbs/vector_dbs.py
+++ b/llama_stack/apis/vector_dbs/vector_dbs.py
@@ -50,7 +50,7 @@ class VectorDBs(Protocol):
async def get_vector_db(
self,
vector_db_id: str,
- ) -> Optional[VectorDB]: ...
+ ) -> VectorDB: ...
@webmethod(route="/vector-dbs", method="POST")
async def register_vector_db(
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 533993421..5dea942f7 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -219,8 +219,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
- async def get_model(self, model_id: str) -> Optional[Model]:
- return await self.get_object_by_identifier("model", model_id)
+ async def get_model(self, model_id: str) -> Model:
+ model = await self.get_object_by_identifier("model", model_id)
+ if model is None:
+ raise ValueError(f"Model '{model_id}' not found")
+ return model
async def register_model(
self,
@@ -267,8 +270,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
- async def get_shield(self, identifier: str) -> Optional[Shield]:
- return await self.get_object_by_identifier("shield", identifier)
+ async def get_shield(self, identifier: str) -> Shield:
+ shield = await self.get_object_by_identifier("shield", identifier)
+ if shield is None:
+ raise ValueError(f"Shield '{identifier}' not found")
+ return shield
async def register_shield(
self,
@@ -303,8 +309,11 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
- async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
- return await self.get_object_by_identifier("vector_db", vector_db_id)
+ async def get_vector_db(self, vector_db_id: str) -> VectorDB:
+ vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
+ if vector_db is None:
+ raise ValueError(f"Vector DB '{vector_db_id}' not found")
+ return vector_db
async def register_vector_db(
self,
@@ -355,8 +364,11 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
- async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
- return await self.get_object_by_identifier("dataset", dataset_id)
+ async def get_dataset(self, dataset_id: str) -> Dataset:
+ dataset = await self.get_object_by_identifier("dataset", dataset_id)
+ if dataset is None:
+ raise ValueError(f"Dataset '{dataset_id}' not found")
+ return dataset
async def register_dataset(
self,
@@ -408,8 +420,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
- async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
- return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
+ async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
+ scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
+ if scoring_fn is None:
+ raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
+ return scoring_fn
async def register_scoring_function(
self,
@@ -445,8 +460,11 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
- async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
- return await self.get_object_by_identifier("benchmark", benchmark_id)
+ async def get_benchmark(self, benchmark_id: str) -> Benchmark:
+ benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
+ if benchmark is None:
+ raise ValueError(f"Benchmark '{benchmark_id}' not found")
+ return benchmark
async def register_benchmark(
self,
@@ -490,7 +508,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
- return await self.get_object_by_identifier("tool_group", toolgroup_id)
+ tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
+ if tool_group is None:
+ raise ValueError(f"Tool group '{toolgroup_id}' not found")
+ return tool_group
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)