mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-10 05:24:39 +00:00
memory client works
This commit is contained in:
parent
a08958c000
commit
8d14d4228b
8 changed files with 164 additions and 86 deletions
|
@ -11,6 +11,8 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
|
|||
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
||||
from llama_toolchain.inference.api.endpoints import Inference
|
||||
from llama_toolchain.inference.providers import available_inference_providers
|
||||
from llama_toolchain.memory.api.endpoints import Memory
|
||||
from llama_toolchain.memory.providers import available_memory_providers
|
||||
from llama_toolchain.safety.api.endpoints import Safety
|
||||
from llama_toolchain.safety.providers import available_safety_providers
|
||||
|
||||
|
@ -47,6 +49,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
Api.inference: Inference,
|
||||
Api.safety: Safety,
|
||||
Api.agentic_system: AgenticSystem,
|
||||
Api.memory: Memory,
|
||||
}
|
||||
|
||||
for api, protocol in protocols.items():
|
||||
|
@ -60,9 +63,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
webmethod = method.__webmethod__
|
||||
route = webmethod.route
|
||||
|
||||
# use `post` for all methods right now until we fix up the `webmethod` openapi
|
||||
# annotation and write our own openapi generator
|
||||
endpoints.append(ApiEndpoint(route=route, method="post", name=name))
|
||||
if webmethod.method == "GET":
|
||||
method = "get"
|
||||
elif webmethod.method == "DELETE":
|
||||
method = "delete"
|
||||
else:
|
||||
method = "post"
|
||||
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
||||
|
||||
apis[api] = endpoints
|
||||
|
||||
|
@ -82,4 +89,5 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|||
Api.inference: inference_providers_by_id,
|
||||
Api.safety: safety_providers_by_id,
|
||||
Api.agentic_system: agentic_system_providers_by_id,
|
||||
Api.memory: {a.provider_id: a for a in available_memory_providers()},
|
||||
}
|
||||
|
|
|
@ -53,6 +53,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
),
|
||||
DistributionSpec(
|
||||
spec_id="test-memory",
|
||||
description="Just a test distribution spec for testing memory bank APIs",
|
||||
provider_specs={
|
||||
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
||||
},
|
||||
|
|
|
@ -5,8 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import signal
|
||||
import traceback
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
|
@ -28,12 +30,13 @@ import fire
|
|||
import httpx
|
||||
import yaml
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints
|
||||
|
@ -66,6 +69,7 @@ def create_sse_event(data: Any) -> str:
|
|||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
traceback.print_exception(exc)
|
||||
http_exc = translate_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
|
@ -155,9 +159,8 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any):
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
request_model = next(iter(hints.values()))
|
||||
response_model = hints["return"]
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
|
@ -168,7 +171,7 @@ def create_dynamic_typed_route(func: Any):
|
|||
|
||||
if is_streaming:
|
||||
|
||||
async def endpoint(request: request_model):
|
||||
async def endpoint(**kwargs):
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
|
@ -178,10 +181,7 @@ def create_dynamic_typed_route(func: Any):
|
|||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
@ -191,25 +191,36 @@ def create_dynamic_typed_route(func: Any):
|
|||
)
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(request)), media_type="text/event-stream"
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: request_model):
|
||||
async def endpoint(**kwargs):
|
||||
try:
|
||||
return (
|
||||
await func(request)
|
||||
await func(**kwargs)
|
||||
if asyncio.iscoroutinefunction(func)
|
||||
else func(request)
|
||||
else func(**kwargs)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
if method == "post":
|
||||
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||
# do anything too intelligent and ask for some parameters in the query
|
||||
# and some in the body
|
||||
endpoint.__signature__ = sig.replace(
|
||||
parameters=[
|
||||
param.replace(annotation=Annotated[param.annotation, Body()])
|
||||
for param in sig.parameters.values()
|
||||
]
|
||||
)
|
||||
else:
|
||||
endpoint.__signature__ = sig
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
|
@ -296,7 +307,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(impl_method)
|
||||
create_dynamic_typed_route(impl_method, endpoint.method)
|
||||
)
|
||||
|
||||
for route in app.routes:
|
||||
|
@ -307,6 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
attrs=["bold"],
|
||||
)
|
||||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue