Improved exception handling

This commit is contained in:
Ashwin Bharambe 2024-08-02 14:54:06 -07:00
parent 493f0d99b2
commit af4710c959
2 changed files with 58 additions and 10 deletions

View file

@ -19,11 +19,12 @@ import httpx
import yaml import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi import FastAPI, Request from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from pydantic import BaseModel from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from .datatypes import PassthroughApiAdapter from .datatypes import PassthroughApiAdapter
@ -58,6 +59,22 @@ def create_sse_event(data: Any) -> str:
return f"data: {data}\n\n" return f"data: {data}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
http_exc = translate_exception(exc)
return JSONResponse(
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
)
def translate_exception(exc: Exception) -> HTTPException:
if isinstance(exc, ValidationError):
return RequestValidationError(exc.raw_errors)
# Add more custom exception translations here
return HTTPException(status_code=500, detail="Internal server error")
async def passthrough( async def passthrough(
request: Request, request: Request,
downstream_url: str, downstream_url: str,
@ -122,17 +139,34 @@ def create_dynamic_typed_route(func: Any):
if is_streaming: if is_streaming:
async def endpoint(request: request_model): async def endpoint(request: request_model):
async def event_generator(): async def sse_generator(event_gen):
async for item in func(request): try:
yield create_sse_event(item) async for item in event_gen:
await asyncio.sleep(0.001) yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
return StreamingResponse(event_generator(), media_type="text/event-stream") return StreamingResponse(
sse_generator(func(request2)), media_type="text/event-stream"
)
else: else:
async def endpoint(request: request_model): async def endpoint(request: request_model):
return func(request) try:
return func(request)
except Exception as e:
raise translate_exception(e) from e
return endpoint return endpoint
@ -188,6 +222,7 @@ def main(
attrs=["bold"], attrs=["bold"],
) )
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint) signal.signal(signal.SIGINT, handle_sigint)
import uvicorn import uvicorn

View file

@ -46,12 +46,25 @@ class InferenceClient(Inference):
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) as response: ) as response:
if response.status_code != 200:
content = await response.aread()
cprint(
f"Error: HTTP {response.status_code} {content.decode()}", "red"
)
return
async for line in response.aiter_lines(): async for line in response.aiter_lines():
if line.startswith("data:"): if line.startswith("data:"):
data = line[len("data: ") :] data = line[len("data: ") :]
try: try:
if request.stream: if request.stream:
yield ChatCompletionResponseStreamChunk(**json.loads(data)) if "error" in data:
cprint(data, "red")
continue
yield ChatCompletionResponseStreamChunk(
**json.loads(data)
)
else: else:
yield ChatCompletionResponse(**json.loads(data)) yield ChatCompletionResponse(**json.loads(data))
except Exception as e: except Exception as e: