mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Improved exception handling
This commit is contained in:
parent
493f0d99b2
commit
af4710c959
2 changed files with 58 additions and 10 deletions
|
@ -19,11 +19,12 @@ import httpx
|
||||||
import yaml
|
import yaml
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from .datatypes import PassthroughApiAdapter
|
from .datatypes import PassthroughApiAdapter
|
||||||
|
@ -58,6 +59,22 @@ def create_sse_event(data: Any) -> str:
|
||||||
return f"data: {data}\n\n"
|
return f"data: {data}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def global_exception_handler(request: Request, exc: Exception):
    """Turn an exception into a JSON error response.

    The exception is passed through ``translate_exception``; the status code
    and detail of the translated exception become the response status and
    the ``{"error": {"detail": ...}}`` body.

    Args:
        request: The request being handled (not inspected here).
        exc: The exception to report.

    Returns:
        A ``JSONResponse`` describing the error.
    """
    translated = translate_exception(exc)
    status = translated.status_code
    payload = {"error": {"detail": translated.detail}}
    return JSONResponse(status_code=status, content=payload)
|
||||||
|
|
||||||
|
|
||||||
|
def translate_exception(exc: Exception) -> HTTPException:
    """Map an arbitrary exception to the ``HTTPException`` to report.

    Callers read ``status_code`` and ``detail`` from the result, so every
    branch must return a real ``HTTPException``.

    Args:
        exc: The exception raised while serving a request.

    Returns:
        A 422 ``HTTPException`` listing the individual validation errors
        when ``exc`` is a pydantic ``ValidationError``; otherwise a generic
        500 ``HTTPException``.
    """
    if isinstance(exc, ValidationError):
        # Build the detail from ``errors()`` rather than wrapping the
        # exception in RequestValidationError: that class is not an
        # HTTPException and carries neither ``status_code`` nor ``detail``.
        return HTTPException(
            status_code=422,
            detail={
                "errors": [
                    {
                        "loc": list(error.get("loc", ())),
                        "msg": error.get("msg"),
                        "type": error.get("type"),
                    }
                    for error in exc.errors()
                ]
            },
        )

    # Add more custom exception translations here
    return HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
async def passthrough(
|
async def passthrough(
|
||||||
request: Request,
|
request: Request,
|
||||||
downstream_url: str,
|
downstream_url: str,
|
||||||
|
@ -122,17 +139,34 @@ def create_dynamic_typed_route(func: Any):
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
|
|
||||||
async def endpoint(request: request_model):
    """Serve a streaming call to ``func`` as a server-sent-event response.

    Each item produced by ``func(request)`` is framed with
    ``create_sse_event`` and sent with media type ``text/event-stream``.
    """

    async def sse_generator(event_gen):
        """Frame the items of ``event_gen`` as SSE events.

        A cancellation closes ``event_gen``. Any other exception raised
        while iterating is reported in-band as a final
        ``{"error": {"message": ...}}`` event instead of propagating.
        """
        try:
            async for item in event_gen:
                yield create_sse_event(item)
                # Short pause between events; also hands control back to
                # the event loop.
                await asyncio.sleep(0.01)
        except asyncio.CancelledError:
            print("Generator cancelled")
            await event_gen.aclose()
        except Exception as e:
            yield create_sse_event(
                {
                    "error": {
                        "message": str(translate_exception(e)),
                    },
                }
            )

    # Pass this endpoint's own ``request`` argument; the previous code
    # referenced an undefined name here and failed with NameError.
    return StreamingResponse(
        sse_generator(func(request)), media_type="text/event-stream"
    )
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
async def endpoint(request: request_model):
    """Serve a non-streaming call to ``func``.

    Returns the result of ``func(request)``. Any exception raised while
    producing that result is re-raised as the exception returned by
    ``translate_exception``, chained to the original.
    """
    import inspect

    try:
        result = func(request)
        # If ``func`` is a coroutine function its body only runs when
        # awaited; await inside the ``try`` so its errors are translated
        # too. Plain (non-awaitable) results are returned unchanged.
        if inspect.isawaitable(result):
            result = await result
        return result
    except Exception as e:
        raise translate_exception(e) from e
|
||||||
|
|
||||||
return endpoint
|
return endpoint
|
||||||
|
|
||||||
|
@ -188,6 +222,7 @@ def main(
|
||||||
attrs=["bold"],
|
attrs=["bold"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.exception_handler(Exception)(global_exception_handler)
|
||||||
signal.signal(signal.SIGINT, handle_sigint)
|
signal.signal(signal.SIGINT, handle_sigint)
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
|
@ -46,12 +46,25 @@ class InferenceClient(Inference):
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
) as response:
|
) as response:
|
||||||
|
if response.status_code != 200:
|
||||||
|
content = await response.aread()
|
||||||
|
cprint(
|
||||||
|
f"Error: HTTP {response.status_code} {content.decode()}", "red"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
async for line in response.aiter_lines():
|
async for line in response.aiter_lines():
|
||||||
if line.startswith("data:"):
|
if line.startswith("data:"):
|
||||||
data = line[len("data: ") :]
|
data = line[len("data: ") :]
|
||||||
try:
|
try:
|
||||||
if request.stream:
|
if request.stream:
|
||||||
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
if "error" in data:
|
||||||
|
cprint(data, "red")
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
**json.loads(data)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield ChatCompletionResponse(**json.loads(data))
|
yield ChatCompletionResponse(**json.loads(data))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue