mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
More idiomatic REST API (#765)
# What does this PR do? This PR changes our API to follow more idiomatic REST API approaches of having paths being resources and methods indicating the action being performed. Changes made to generator: 1) removed the prefix check of "get" as its not required and is actually needed for other method types too 2) removed _ check on path since variables can have "_" ## Test Plan LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v tests/client-sdk/agents/test_agents.py
This commit is contained in:
parent
6deef1ece0
commit
7fb2c1c48d
29 changed files with 2144 additions and 1917 deletions
|
@ -14,16 +14,13 @@ import signal
|
|||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
from typing import Any, List, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
@ -31,7 +28,6 @@ from termcolor import cprint
|
|||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
|
@ -41,13 +37,11 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
TelemetryAdapter,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
|
@ -56,7 +50,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
|
@ -178,7 +171,7 @@ async def sse_generator(event_gen):
|
|||
)
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
|
@ -196,6 +189,7 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
|
||||
|
@ -203,12 +197,21 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
]
|
||||
new_params.extend(sig.parameters.values())
|
||||
|
||||
path_params = extract_path_params(route)
|
||||
if method == "post":
|
||||
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||
# do anything too intelligent and ask for some parameters in the query
|
||||
# and some in the body
|
||||
# Annotate parameters that are in the path with Path(...) and others with Body(...)
|
||||
new_params = [new_params[0]] + [
|
||||
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
(
|
||||
param.replace(
|
||||
annotation=Annotated[
|
||||
param.annotation, FastapiPath(..., title=param.name)
|
||||
]
|
||||
)
|
||||
if param.name in path_params
|
||||
else param.replace(
|
||||
annotation=Annotated[param.annotation, Body(..., embed=True)]
|
||||
)
|
||||
)
|
||||
for param in new_params[1:]
|
||||
]
|
||||
|
||||
|
@ -386,6 +389,7 @@ def main():
|
|||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
endpoint.route,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -409,5 +413,13 @@ def main():
|
|||
uvicorn.run(app, host=listen_host, port=args.port)
|
||||
|
||||
|
||||
def extract_path_params(route: str) -> List[str]:
|
||||
segments = route.split("/")
|
||||
params = [
|
||||
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
|
||||
]
|
||||
return params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue