more idiomatic REST API

This commit is contained in:
Dinesh Yeduguru 2025-01-14 14:52:32 -08:00
parent d0a25dd453
commit b438dad8d2
29 changed files with 2144 additions and 1917 deletions

View file

@ -14,16 +14,13 @@ import signal
import sys
import traceback
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, Union
from typing import Any, List, Union
import yaml
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
@ -31,7 +28,6 @@ from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError
@ -41,13 +37,11 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
@ -56,7 +50,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -178,7 +171,7 @@ async def sse_generator(event_gen):
)
def create_dynamic_typed_route(func: Any, method: str):
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
set_request_provider_data(request.headers)
@ -196,6 +189,7 @@ def create_dynamic_typed_route(func: Any, method: str):
raise translate_exception(e) from e
sig = inspect.signature(func)
new_params = [
inspect.Parameter(
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
@ -203,12 +197,21 @@ def create_dynamic_typed_route(func: Any, method: str):
]
new_params.extend(sig.parameters.values())
path_params = extract_path_params(route)
if method == "post":
# make sure every parameter is annotated with Body() so FASTAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
# Annotate parameters that are in the path with Path(...) and others with Body(...)
new_params = [new_params[0]] + [
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
(
param.replace(
annotation=Annotated[
param.annotation, FastapiPath(..., title=param.name)
]
)
if param.name in path_params
else param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
)
for param in new_params[1:]
]
@ -386,6 +389,7 @@ def main():
create_dynamic_typed_route(
impl_method,
endpoint.method,
endpoint.route,
)
)
@ -409,5 +413,13 @@ def main():
uvicorn.run(app, host=listen_host, port=args.port)
def extract_path_params(route: str) -> List[str]:
segments = route.split("/")
params = [
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
]
return params
if __name__ == "__main__":
main()