Improved exception handling

This commit is contained in:
Ashwin Bharambe 2024-08-02 14:54:06 -07:00
parent 493f0d99b2
commit af4710c959
2 changed files with 58 additions and 10 deletions

View file

@ -19,11 +19,12 @@ import httpx
import yaml
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import PassthroughApiAdapter
@ -58,6 +59,22 @@ def create_sse_event(data: Any) -> str:
return f"data: {data}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
    """Turn an unhandled exception into a JSON error response.

    The exception is first mapped to an HTTP-level error by
    ``translate_exception``; the resulting status code and detail are then
    returned to the client as ``{"error": {"detail": ...}}``.

    Args:
        request: The request being served when the exception occurred
            (unused here, but part of the handler signature).
        exc: The exception that escaped the route handler.

    Returns:
        A ``JSONResponse`` carrying the translated status code and detail.
    """
    translated = translate_exception(exc)
    status = translated.status_code
    body = {"error": {"detail": translated.detail}}
    return JSONResponse(status_code=status, content=body)
def translate_exception(exc: Exception) -> HTTPException:
    """Map an arbitrary exception onto an ``HTTPException``.

    Callers read ``status_code`` and ``detail`` from the result (and may
    ``raise`` it), so every branch must return a real ``HTTPException``.

    Args:
        exc: The exception raised while serving a request.

    Returns:
        * ``exc`` itself when it already is an ``HTTPException``;
        * a 422 whose detail lists each validation error (``loc``, ``msg``,
          ``type``) when ``exc`` is a pydantic ``ValidationError``;
        * a generic 500 for anything else.
    """
    # Preserve errors that already carry an HTTP status instead of masking
    # them as a 500.
    if isinstance(exc, HTTPException):
        return exc

    if isinstance(exc, ValidationError):
        # RequestValidationError is not an HTTPException (no status_code /
        # detail), so build the response payload directly. Only the plain
        # loc/msg/type fields are copied so the detail stays JSON-friendly.
        return HTTPException(
            status_code=422,
            detail=[
                {
                    "loc": list(err.get("loc", ())),
                    "msg": err.get("msg", ""),
                    "type": err.get("type", ""),
                }
                for err in exc.errors()
            ],
        )

    # Add more custom exception translations here
    return HTTPException(status_code=500, detail="Internal server error")
async def passthrough(
request: Request,
downstream_url: str,
@ -122,17 +139,34 @@ def create_dynamic_typed_route(func: Any):
if is_streaming:
async def endpoint(request: request_model):
    """Serve a streaming route: relay ``func(request)`` as server-sent events.

    Each item produced by ``func(request)`` is wrapped with
    ``create_sse_event``. If the underlying generator raises, the failure is
    reported in-band as a final ``{"error": {"message": ...}}`` event.
    """

    async def sse_generator(event_gen):
        try:
            async for item in event_gen:
                yield create_sse_event(item)
                await asyncio.sleep(0.01)
        except asyncio.CancelledError:
            # The stream was cancelled; close the upstream generator so it
            # can release whatever it holds.
            print("Generator cancelled")
            await event_gen.aclose()
        except Exception as e:
            yield create_sse_event(
                {
                    "error": {
                        "message": str(translate_exception(e)),
                    },
                }
            )

    # Pass the validated request itself; the previous `request2` was an
    # undefined name and raised NameError on every call.
    return StreamingResponse(
        sse_generator(func(request)), media_type="text/event-stream"
    )
else:
async def endpoint(request: request_model):
    """Serve a non-streaming route by delegating to ``func``.

    Any exception raised by ``func(request)`` is converted with
    ``translate_exception`` and re-raised, chained to the original error.
    """
    # The call must live inside the try block; an earlier bare
    # `return func(request)` would bypass the translation below.
    try:
        return func(request)
    except Exception as e:
        raise translate_exception(e) from e
return endpoint
@ -188,6 +222,7 @@ def main(
attrs=["bold"],
)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
import uvicorn

View file

@ -46,12 +46,25 @@ class InferenceClient(Inference):
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
if response.status_code != 200:
content = await response.aread()
cprint(
f"Error: HTTP {response.status_code} {content.decode()}", "red"
)
return
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
if request.stream:
yield ChatCompletionResponseStreamChunk(**json.loads(data))
if "error" in data:
cprint(data, "red")
continue
yield ChatCompletionResponseStreamChunk(
**json.loads(data)
)
else:
yield ChatCompletionResponse(**json.loads(data))
except Exception as e: