mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 13:04:32 +00:00
more idiomatic REST API
This commit is contained in:
parent
d0a25dd453
commit
b438dad8d2
29 changed files with 2144 additions and 1917 deletions
|
|
@ -14,16 +14,13 @@ import signal
|
|||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
from typing import Any, List, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
|
@ -31,7 +28,6 @@ from termcolor import cprint
|
|||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
|
|
@ -41,13 +37,11 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
TelemetryAdapter,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
|
|
@ -56,7 +50,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
|
|
@ -178,7 +171,7 @@ async def sse_generator(event_gen):
|
|||
)
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
|
|
@ -196,6 +189,7 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
|
||||
|
|
@ -203,12 +197,21 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
]
|
||||
new_params.extend(sig.parameters.values())
|
||||
|
||||
path_params = extract_path_params(route)
|
||||
if method == "post":
|
||||
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||
# do anything too intelligent and ask for some parameters in the query
|
||||
# and some in the body
|
||||
# Annotate parameters that are in the path with Path(...) and others with Body(...)
|
||||
new_params = [new_params[0]] + [
|
||||
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
(
|
||||
param.replace(
|
||||
annotation=Annotated[
|
||||
param.annotation, FastapiPath(..., title=param.name)
|
||||
]
|
||||
)
|
||||
if param.name in path_params
|
||||
else param.replace(
|
||||
annotation=Annotated[param.annotation, Body(..., embed=True)]
|
||||
)
|
||||
)
|
||||
for param in new_params[1:]
|
||||
]
|
||||
|
||||
|
|
@ -386,6 +389,7 @@ def main():
|
|||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
endpoint.route,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -409,5 +413,13 @@ def main():
|
|||
uvicorn.run(app, host=listen_host, port=args.port)
|
||||
|
||||
|
||||
def extract_path_params(route: str) -> List[str]:
|
||||
segments = route.split("/")
|
||||
params = [
|
||||
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
|
||||
]
|
||||
return params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue