fix(proxy_server.py): fix exception raising

This commit is contained in:
Krrish Dholakia 2024-02-12 11:13:02 -08:00
parent 8f989235ea
commit 1a452057af
2 changed files with 24 additions and 10 deletions

View file

@@ -2376,8 +2376,8 @@ async def chat_completion(
response = await litellm.acompletion(**data) response = await litellm.acompletion(**data)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": "Invalid model name passed in"}, detail={"error": "Invalid model name passed in"},
) )
# Post Call Processing # Post Call Processing
@@ -2439,7 +2439,12 @@ async def chat_completion(
traceback.print_exc() traceback.print_exc()
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise e raise ProxyException(
message=getattr(e, "detail", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else: else:
error_traceback = traceback.format_exc() error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}" error_msg = f"{str(e)}\n\n{error_traceback}"
@@ -2593,8 +2598,8 @@ async def embeddings(
response = await litellm.aembedding(**data) response = await litellm.aembedding(**data)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": "Invalid model name passed in"}, detail={"error": "Invalid model name passed in"},
) )
### ALERTING ### ### ALERTING ###
@@ -2613,7 +2618,12 @@ async def embeddings(
) )
traceback.print_exc() traceback.print_exc()
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise e raise ProxyException(
message=getattr(e, "message", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else: else:
error_traceback = traceback.format_exc() error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}" error_msg = f"{str(e)}\n\n{error_traceback}"
@@ -2733,8 +2743,8 @@ async def image_generation(
response = await litellm.aimage_generation(**data) response = await litellm.aimage_generation(**data)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": "Invalid model name passed in"}, detail={"error": "Invalid model name passed in"},
) )
### ALERTING ### ### ALERTING ###
@@ -2753,7 +2763,12 @@ async def image_generation(
) )
traceback.print_exc() traceback.print_exc()
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise e raise ProxyException(
message=getattr(e, "message", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else: else:
error_traceback = traceback.format_exc() error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}" error_msg = f"{str(e)}\n\n{error_traceback}"

View file

@@ -160,7 +160,6 @@ def test_chat_completion_exception_any_model(client):
response = client.post("/chat/completions", json=test_data) response = client.post("/chat/completions", json=test_data)
json_response = response.json() json_response = response.json()
print("keys in json response", json_response.keys())
assert json_response.keys() == {"error"} assert json_response.keys() == {"error"}
# make an openai client to call _make_status_error_from_response # make an openai client to call _make_status_error_from_response