mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Improved exception handling
This commit is contained in:
parent
493f0d99b2
commit
af4710c959
2 changed files with 58 additions and 10 deletions
|
@ -19,11 +19,12 @@ import httpx
|
|||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
|
||||
from .datatypes import PassthroughApiAdapter
|
||||
|
@ -58,6 +59,22 @@ def create_sse_event(data: Any) -> str:
|
|||
return f"data: {data}\n\n"
|
||||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
    """Convert an unhandled exception into a JSON error response.

    Intended to be registered as an application-level handler for
    ``Exception``. The exception is mapped to an HTTP error with
    ``translate_exception`` and returned in the shape
    ``{"error": {"detail": ...}}``.

    Args:
        request: The request being processed when the exception was raised.
            It is part of the handler signature but is not read here.
        exc: The exception that escaped the route handler.

    Returns:
        A ``JSONResponse`` carrying the translated status code and detail.
    """
    # Function-scope import so this block does not depend on the module's
    # import list.
    import traceback

    # The client only receives the translated (often generic) message, so
    # record the real traceback here; otherwise the root cause is lost.
    traceback.print_exception(type(exc), exc, exc.__traceback__)

    http_exc = translate_exception(exc)

    return JSONResponse(
        status_code=http_exc.status_code,
        content={"error": {"detail": http_exc.detail}},
    )
|
||||
|
||||
|
||||
def translate_exception(exc: Exception) -> HTTPException:
    """Map an arbitrary exception to the ``HTTPException`` sent to clients.

    Args:
        exc: The exception raised while serving a request.

    Returns:
        An ``HTTPException``. An ``HTTPException`` passed in is returned
        unchanged, a pydantic ``ValidationError`` becomes a 422 whose detail
        lists the individual validation errors, and anything else becomes a
        generic 500 so internal details are not leaked.
    """
    # Already an HTTP error: keep its status code and detail instead of
    # downgrading it to a generic 500.
    if isinstance(exc, HTTPException):
        return exc

    if isinstance(exc, ValidationError):
        # Always return an object that really has ``status_code`` and
        # ``detail``; callers read both attributes. ``errors()`` is used
        # rather than ``raw_errors`` because the latter is pydantic-v1 only.
        # Only loc/msg/type are copied so the detail stays JSON-serializable.
        return HTTPException(
            status_code=422,
            detail=[
                {
                    "loc": list(error.get("loc", ())),
                    "msg": error.get("msg"),
                    "type": error.get("type"),
                }
                for error in exc.errors()
            ],
        )

    # Add more custom exception translations here
    return HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
async def passthrough(
|
||||
request: Request,
|
||||
downstream_url: str,
|
||||
|
@ -122,17 +139,34 @@ def create_dynamic_typed_route(func: Any):
|
|||
if is_streaming:
|
||||
|
||||
async def endpoint(request: request_model):
    """Serve a streaming route as a server-sent-events response.

    Calls ``func(request)`` to obtain an async generator and wraps it in a
    ``StreamingResponse`` with media type ``text/event-stream``.
    """

    async def sse_generator(event_gen):
        """Yield each item of ``event_gen`` formatted as an SSE event."""
        try:
            async for item in event_gen:
                yield create_sse_event(item)
                await asyncio.sleep(0.01)
        except asyncio.CancelledError:
            print("Generator cancelled")
            # Close the wrapped generator so it can run its cleanup code.
            await event_gen.aclose()
        except Exception as e:
            # The response has already started, so the status code cannot
            # change; report the failure as a final event instead.
            yield create_sse_event(
                {
                    "error": {
                        "message": str(translate_exception(e)),
                    },
                }
            )

    # Pass the endpoint's own ``request`` parameter; the previous
    # ``request2`` is not defined in this scope.
    return StreamingResponse(
        sse_generator(func(request)), media_type="text/event-stream"
    )
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: request_model):
    """Serve a non-streaming route by calling ``func`` with the request.

    Returns:
        The value produced by ``func``, awaited first if it is awaitable.

    Raises:
        HTTPException: Any exception raised by ``func`` is converted with
            ``translate_exception`` and re-raised, chained to the original.
    """
    # Function-scope import so this block does not depend on the module's
    # import list.
    import inspect

    try:
        result = func(request)
        # Calling a coroutine function only creates the coroutine; its
        # errors surface when it is awaited. Await inside the ``try`` so
        # those errors are translated too.
        if inspect.isawaitable(result):
            result = await result
        return result
    except Exception as e:
        raise translate_exception(e) from e
|
||||
|
||||
return endpoint
|
||||
|
||||
|
@ -188,6 +222,7 @@ def main(
|
|||
attrs=["bold"],
|
||||
)
|
||||
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
|
||||
import uvicorn
|
||||
|
|
|
@ -46,12 +46,25 @@ class InferenceClient(Inference):
|
|||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
content = await response.aread()
|
||||
cprint(
|
||||
f"Error: HTTP {response.status_code} {content.decode()}", "red"
|
||||
)
|
||||
return
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
if request.stream:
|
||||
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
||||
if "error" in data:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
**json.loads(data)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponse(**json.loads(data))
|
||||
except Exception as e:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue