diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py
index d45e3b041..9b96d31fc 100644
--- a/llama_toolchain/distribution/server.py
+++ b/llama_toolchain/distribution/server.py
@@ -19,11 +19,12 @@
 import httpx
 import yaml
 
 from dotenv import load_dotenv
-from fastapi import FastAPI, Request
-from fastapi.responses import StreamingResponse
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 
 from .datatypes import PassthroughApiAdapter
@@ -58,6 +59,22 @@ def create_sse_event(data: Any) -> str:
     return f"data: {data}\n\n"
 
 
+async def global_exception_handler(request: Request, exc: Exception):
+    http_exc = translate_exception(exc)
+
+    return JSONResponse(
+        status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
+    )
+
+
+def translate_exception(exc: Exception) -> HTTPException:
+    if isinstance(exc, ValidationError):
+        return RequestValidationError(exc.raw_errors)
+
+    # Add more custom exception translations here
+    return HTTPException(status_code=500, detail="Internal server error")
+
+
 async def passthrough(
     request: Request,
     downstream_url: str,
@@ -122,17 +139,34 @@ def create_dynamic_typed_route(func: Any):
     if is_streaming:
 
         async def endpoint(request: request_model):
-            async def event_generator():
-                async for item in func(request):
-                    yield create_sse_event(item)
-                    await asyncio.sleep(0.001)
+            async def sse_generator(event_gen):
+                try:
+                    async for item in event_gen:
+                        yield create_sse_event(item)
+                        await asyncio.sleep(0.01)
+                except asyncio.CancelledError:
+                    print("Generator cancelled")
+                    await event_gen.aclose()
+                except Exception as e:
+                    yield create_sse_event(
+                        {
+                            "error": {
+                                "message": str(translate_exception(e)),
+                            },
+                        }
+                    )
 
-            return StreamingResponse(event_generator(), media_type="text/event-stream")
+            return StreamingResponse(
+                sse_generator(func(request)), media_type="text/event-stream"
+            )
 
     else:
 
         async def endpoint(request: request_model):
-            return func(request)
+            try:
+                return func(request)
+            except Exception as e:
+                raise translate_exception(e) from e
 
     return endpoint
 
@@ -188,6 +222,7 @@ def main(
         attrs=["bold"],
     )
 
+    app.exception_handler(Exception)(global_exception_handler)
     signal.signal(signal.SIGINT, handle_sigint)
 
     import uvicorn
diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index 331580190..1dfe47b24 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -46,12 +46,25 @@ class InferenceClient(Inference):
             headers={"Content-Type": "application/json"},
             timeout=20,
         ) as response:
+            if response.status_code != 200:
+                content = await response.aread()
+                cprint(
+                    f"Error: HTTP {response.status_code} {content.decode()}", "red"
+                )
+                return
+
             async for line in response.aiter_lines():
                 if line.startswith("data:"):
                     data = line[len("data: ") :]
                     try:
                         if request.stream:
-                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                            if "error" in data:
+                                cprint(data, "red")
+                                continue
+
+                            yield ChatCompletionResponseStreamChunk(
+                                **json.loads(data)
+                            )
                         else:
                             yield ChatCompletionResponse(**json.loads(data))
                     except Exception as e:
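
The patch routes every uncaught server-side exception through `translate_exception` and `global_exception_handler`, so non-streaming clients always receive a uniform `{"error": {"detail": ...}}` JSON body. Below is a minimal standalone sketch of that pattern; it keeps only the generic 500 fallback (the `ValidationError` branch is omitted), and the `/ping` route is a hypothetical stand-in, not part of llama_toolchain.

```python
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

app = FastAPI()


def translate_exception(exc: Exception) -> HTTPException:
    # Fallback mapping, mirroring the patch: anything unrecognized
    # becomes a generic 500 rather than leaking a stack trace.
    return HTTPException(status_code=500, detail="Internal server error")


async def global_exception_handler(request: Request, exc: Exception):
    http_exc = translate_exception(exc)
    # Every error reaches the client in the same JSON envelope.
    return JSONResponse(
        status_code=http_exc.status_code,
        content={"error": {"detail": http_exc.detail}},
    )


@app.get("/ping")
async def ping():
    raise RuntimeError("boom")  # handled by global_exception_handler


# Same registration style as the patch: the decorator factory is
# invoked explicitly rather than used with @-syntax.
app.exception_handler(Exception)(global_exception_handler)
```

With this in place, `GET /ping` returns status 500 with body `{"error": {"detail": "Internal server error"}}` instead of a bare traceback.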
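On the streaming path, `sse_generator` does two jobs: a client disconnect (`asyncio.CancelledError`) closes the wrapped generator so its cleanup runs, and any other failure is emitted as a final `{"error": ...}` SSE event instead of silently dropping the connection. The self-contained sketch below demonstrates both behaviors; `fake_events`, the `demo` driver, and the JSON-dumping body of `create_sse_event` are illustrative assumptions (the patch only shows the helper's final `f"data: {data}\n\n"` line).

```python
import asyncio
import json
from typing import Any, AsyncGenerator


def create_sse_event(data: Any) -> str:
    # Assumption: events are JSON-serialized before SSE framing.
    return f"data: {json.dumps(data)}\n\n"


async def sse_generator(event_gen) -> AsyncGenerator[str, None]:
    try:
        async for item in event_gen:
            yield create_sse_event(item)
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        # Client went away: close the wrapped generator so its
        # finally blocks and context managers get to run.
        await event_gen.aclose()
    except Exception as e:
        # Mid-stream failures become one last SSE event instead of
        # an abruptly severed HTTP response.
        yield create_sse_event({"error": {"message": str(e)}})


async def fake_events():
    yield {"delta": "hello"}
    raise RuntimeError("backend fell over")  # surfaces as an error event


async def demo() -> None:
    async for chunk in sse_generator(fake_events()):
        print(chunk, end="")


if __name__ == "__main__":
    asyncio.run(demo())
```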
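Taken together, the client-side changes handle two distinct error shapes on the wire: a non-200 status carrying the JSON envelope from `global_exception_handler`, and in-stream `data:` events carrying an `"error"` key. A hedged sketch of a generic consumer handling both follows; the URL and payload are placeholders, and it JSON-parses each event before checking for `"error"` rather than substring-matching the raw line as `InferenceClient` does.

```python
import asyncio
import json

import httpx


async def consume_sse(url: str, payload: dict) -> None:
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload, timeout=20) as response:
            if response.status_code != 200:
                # Failures before streaming starts arrive as the JSON
                # envelope produced by global_exception_handler.
                body = await response.aread()
                print(f"Error: HTTP {response.status_code} {body.decode()}")
                return

            async for line in response.aiter_lines():
                if not line.startswith("data:"):
                    continue
                event = json.loads(line[len("data: "):])
                if "error" in event:
                    # Failures after streaming started arrive in-band.
                    print("stream error:", event["error"]["message"])
                    continue
                print("event:", event)


if __name__ == "__main__":
    # Hypothetical endpoint and payload, for illustration only.
    asyncio.run(consume_sse("http://localhost:5000/example/stream", {"stream": True}))
```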