From f136f802b1e1596e907559c3539aa344cd6d06bc Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 23 Sep 2024 21:39:47 -0700 Subject: [PATCH] Somewhat better error handling --- llama_stack/distribution/server/server.py | 37 +++++++++++++++++++---- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index f09e1c586..38218ab8b 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -35,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -42,9 +45,6 @@ from llama_stack.providers.utils.telemetry.tracing import ( SpanStatus, start_trace, ) -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.distribution import ( @@ -90,10 +90,35 @@ async def global_exception_handler(request: Request, exc: Exception): def translate_exception(exc: Exception) -> HTTPException: if isinstance(exc, ValidationError): - return RequestValidationError(exc.raw_errors) + exc = RequestValidationError(exc.raw_errors) - # Add more custom exception translations here - return HTTPException(status_code=500, detail="Internal server error") + if isinstance(exc, RequestValidationError): + return HTTPException( + status_code=400, + detail={ + "errors": [ + { + "loc": list(error["loc"]), + "msg": error["msg"], + "type": error["type"], + } + for error in exc.errors() + ] + }, + ) + elif isinstance(exc, ValueError): + return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") + elif isinstance(exc, PermissionError): + return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") + elif isinstance(exc, TimeoutError): + return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") + elif isinstance(exc, NotImplementedError): + return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}") + else: + return HTTPException( + status_code=500, + detail="Internal server error: An unexpected error occurred.", + ) async def passthrough(