mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
memory client works
This commit is contained in:
parent
a08958c000
commit
8d14d4228b
8 changed files with 164 additions and 86 deletions
|
@ -5,8 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import signal
|
||||
import traceback
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
|
@ -28,12 +30,13 @@ import fire
|
|||
import httpx
|
||||
import yaml
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints
|
||||
|
@ -66,6 +69,7 @@ def create_sse_event(data: Any) -> str:
|
|||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
traceback.print_exception(exc)
|
||||
http_exc = translate_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
|
@ -155,9 +159,8 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any):
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
request_model = next(iter(hints.values()))
|
||||
response_model = hints["return"]
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
|
@ -168,7 +171,7 @@ def create_dynamic_typed_route(func: Any):
|
|||
|
||||
if is_streaming:
|
||||
|
||||
async def endpoint(request: request_model):
|
||||
async def endpoint(**kwargs):
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
|
@ -178,10 +181,7 @@ def create_dynamic_typed_route(func: Any):
|
|||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
@ -191,25 +191,36 @@ def create_dynamic_typed_route(func: Any):
|
|||
)
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(request)), media_type="text/event-stream"
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: request_model):
|
||||
async def endpoint(**kwargs):
|
||||
try:
|
||||
return (
|
||||
await func(request)
|
||||
await func(**kwargs)
|
||||
if asyncio.iscoroutinefunction(func)
|
||||
else func(request)
|
||||
else func(**kwargs)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
if method == "post":
|
||||
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||
# do anything too intelligent and ask for some parameters in the query
|
||||
# and some in the body
|
||||
endpoint.__signature__ = sig.replace(
|
||||
parameters=[
|
||||
param.replace(annotation=Annotated[param.annotation, Body()])
|
||||
for param in sig.parameters.values()
|
||||
]
|
||||
)
|
||||
else:
|
||||
endpoint.__signature__ = sig
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
|
@ -296,7 +307,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(impl_method)
|
||||
create_dynamic_typed_route(impl_method, endpoint.method)
|
||||
)
|
||||
|
||||
for route in app.routes:
|
||||
|
@ -307,6 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
attrs=["bold"],
|
||||
)
|
||||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue