forked from phoenix-oss/llama-stack-mirror
fix: return 4xx for non-existent resources in GET requests (#1635)
# What does this PR do? - Removed Optional return types for GET methods - Raised ValueError when requested resource is not found - Ensures proper 4xx response for missing resources - Updated the API generator to check for wrong signatures ``` $ uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh Validating API method return types... API Method Return Type Validation Errors: Method ScoringFunctions.get_scoring_function returns Optional type ``` Closes: https://github.com/meta-llama/llama-stack/issues/1630 ## Test Plan Run the server then: ``` curl http://127.0.0.1:8321/v1/models/foo {"detail":"Invalid value: Model 'foo' not found"}% ``` Server log: ``` INFO: 127.0.0.1:52307 - "GET /v1/models/foo HTTP/1.1" 400 Bad Request 09:51:42.654 [END] /v1/models/foo [StatusCode.OK] (134.65ms) 09:51:42.651 [ERROR] Error executing endpoint route='/v1/models/{model_id:path}' method='get' Traceback (most recent call last): File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 193, in endpoint return await maybe_await(value) File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 156, in maybe_await return await value File "/Users/leseb/Documents/AI/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 102, in async_wrapper result = await method(self, *args, **kwargs) File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/routers/routing_tables.py", line 217, in get_model raise ValueError(f"Model '{model_id}' not found") ValueError: Model 'foo' not found ``` Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
cca9bd6cc3
commit
c029fbcd13
14 changed files with 112 additions and 136 deletions
|
@ -12,7 +12,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
import fire
|
||||
import ruamel.yaml as yaml
|
||||
|
||||
|
@ -21,7 +21,7 @@ from llama_stack.distribution.stack import LlamaStack # noqa: E402
|
|||
|
||||
from .pyopenapi.options import Options # noqa: E402
|
||||
from .pyopenapi.specification import Info, Server # noqa: E402
|
||||
from .pyopenapi.utility import Specification # noqa: E402
|
||||
from .pyopenapi.utility import Specification, validate_api_method_return_types # noqa: E402
|
||||
|
||||
|
||||
def str_presenter(dumper, data):
|
||||
|
@ -39,6 +39,14 @@ def main(output_dir: str):
|
|||
if not output_dir.exists():
|
||||
raise ValueError(f"Directory {output_dir} does not exist")
|
||||
|
||||
# Validate API protocols before generating spec
|
||||
print("Validating API method return types...")
|
||||
return_type_errors = validate_api_method_return_types()
|
||||
if return_type_errors:
|
||||
print("\nAPI Method Return Type Validation Errors:\n")
|
||||
for error in return_type_errors:
|
||||
print(error)
|
||||
sys.exit(1)
|
||||
now = str(datetime.now())
|
||||
print(
|
||||
"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now
|
||||
|
|
|
@ -6,16 +6,19 @@
|
|||
|
||||
import json
|
||||
import typing
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
from typing import Any, Dict, List, Optional, Protocol, Type, Union, get_type_hints, get_origin, get_args
|
||||
|
||||
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
|
||||
from .generator import Generator
|
||||
from .options import Options
|
||||
from .specification import Document
|
||||
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
|
@ -114,3 +117,37 @@ class Specification:
|
|||
)
|
||||
|
||||
f.write(html)
|
||||
|
||||
def is_optional_type(type_: Any) -> bool:
|
||||
"""Check if a type is Optional."""
|
||||
origin = get_origin(type_)
|
||||
args = get_args(type_)
|
||||
return origin is Optional or (origin is Union and type(None) in args)
|
||||
|
||||
|
||||
def validate_api_method_return_types() -> List[str]:
|
||||
"""Validate that all API methods have proper return types."""
|
||||
errors = []
|
||||
protocols = api_protocol_map()
|
||||
|
||||
for protocol_name, protocol in protocols.items():
|
||||
methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
|
||||
for method_name, method in methods:
|
||||
if not hasattr(method, '__webmethod__'):
|
||||
continue
|
||||
|
||||
# Only check GET methods
|
||||
if method.__webmethod__.method != "GET":
|
||||
continue
|
||||
|
||||
hints = get_type_hints(method)
|
||||
|
||||
if 'return' not in hints:
|
||||
errors.append(f"Method {protocol_name}.{method_name} has no return type annotation")
|
||||
else:
|
||||
return_type = hints['return']
|
||||
if is_optional_type(return_type):
|
||||
errors.append(f"Method {protocol_name}.{method_name} returns Optional type")
|
||||
|
||||
return errors
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue