memory client works

This commit is contained in:
Ashwin Bharambe 2024-08-24 18:43:49 -07:00
parent a08958c000
commit 8d14d4228b
8 changed files with 164 additions and 86 deletions

View file

@ -5,8 +5,10 @@
# the root directory of this source tree.
import asyncio
import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
@ -28,12 +30,13 @@ import fire
import httpx
import yaml
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
@ -66,6 +69,7 @@ def create_sse_event(data: Any) -> str:
async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
@ -155,9 +159,8 @@ def create_dynamic_passthrough(
return endpoint
def create_dynamic_typed_route(func: Any):
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
request_model = next(iter(hints.values()))
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each Api
@ -168,7 +171,7 @@ def create_dynamic_typed_route(func: Any):
if is_streaming:
async def endpoint(request: request_model):
async def endpoint(**kwargs):
async def sse_generator(event_gen):
try:
async for item in event_gen:
@ -178,10 +181,7 @@ def create_dynamic_typed_route(func: Any):
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
print(e)
import traceback
traceback.print_exc()
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
@ -191,25 +191,36 @@ def create_dynamic_typed_route(func: Any):
)
return StreamingResponse(
sse_generator(func(request)), media_type="text/event-stream"
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(request: request_model):
async def endpoint(**kwargs):
try:
return (
await func(request)
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(request)
else func(**kwargs)
)
except Exception as e:
print(e)
import traceback
traceback.print_exc()
traceback.print_exception(e)
raise translate_exception(e) from e
sig = inspect.signature(func)
if method == "post":
# make sure every parameter is annotated with Body() so FASTAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(annotation=Annotated[param.annotation, Body()])
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
return endpoint
@ -296,7 +307,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method)
create_dynamic_typed_route(impl_method, endpoint.method)
)
for route in app.routes:
@ -307,6 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
attrs=["bold"],
)
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)