memory client works

This commit is contained in:
Ashwin Bharambe 2024-08-24 18:43:49 -07:00
parent a08958c000
commit 8d14d4228b
8 changed files with 164 additions and 86 deletions

View file

@ -11,6 +11,8 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.memory.api.endpoints import Memory
from llama_toolchain.memory.providers import available_memory_providers
from llama_toolchain.safety.api.endpoints import Safety
from llama_toolchain.safety.providers import available_safety_providers
@ -47,6 +49,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.inference: Inference,
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
Api.memory: Memory,
}
for api, protocol in protocols.items():
@ -60,9 +63,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
webmethod = method.__webmethod__
route = webmethod.route
# use `post` for all methods right now until we fix up the `webmethod` openapi
# annotation and write our own openapi generator
endpoints.append(ApiEndpoint(route=route, method="post", name=name))
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
@ -82,4 +89,5 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
Api.memory: {a.provider_id: a for a in available_memory_providers()},
}

View file

@ -53,6 +53,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
),
DistributionSpec(
spec_id="test-memory",
description="Just a test distribution spec for testing memory bank APIs",
provider_specs={
Api.memory: providers[Api.memory]["meta-reference-faiss"],
},

View file

@ -5,8 +5,10 @@
# the root directory of this source tree.
import asyncio
import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
@ -28,12 +30,13 @@ import fire
import httpx
import yaml
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
@ -66,6 +69,7 @@ def create_sse_event(data: Any) -> str:
async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
@ -155,9 +159,8 @@ def create_dynamic_passthrough(
return endpoint
def create_dynamic_typed_route(func: Any):
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
request_model = next(iter(hints.values()))
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each Api
@ -168,7 +171,7 @@ def create_dynamic_typed_route(func: Any):
if is_streaming:
async def endpoint(request: request_model):
async def endpoint(**kwargs):
async def sse_generator(event_gen):
try:
async for item in event_gen:
@ -178,10 +181,7 @@ def create_dynamic_typed_route(func: Any):
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
print(e)
import traceback
traceback.print_exc()
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
@ -191,25 +191,36 @@ def create_dynamic_typed_route(func: Any):
)
return StreamingResponse(
sse_generator(func(request)), media_type="text/event-stream"
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(request: request_model):
async def endpoint(**kwargs):
try:
return (
await func(request)
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(request)
else func(**kwargs)
)
except Exception as e:
print(e)
import traceback
traceback.print_exc()
traceback.print_exception(e)
raise translate_exception(e) from e
sig = inspect.signature(func)
if method == "post":
# make sure every parameter is annotated with Body() so FASTAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(annotation=Annotated[param.annotation, Body()])
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
return endpoint
@ -296,7 +307,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method)
create_dynamic_typed_route(impl_method, endpoint.method)
)
for route in app.routes:
@ -307,6 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
attrs=["bold"],
)
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)