Merge pull request #1946 from BerriAI/litellm_proxy_routing_fix

Litellm proxy routing fix
This commit is contained in:
Krish Dholakia 2024-02-12 18:12:52 -08:00 committed by GitHub
commit c4c1703c44
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 37 additions and 8 deletions

View file

@ -2373,8 +2373,13 @@ async def chat_completion(
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.acompletion(**data, specific_deployment=True)
else: # router is not set
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.acompletion(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
# Post Call Processing
data["litellm_status"] = "success" # used for alerting
@ -2435,7 +2440,12 @@ async def chat_completion(
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
raise ProxyException(
message=getattr(e, "detail", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
@ -2585,8 +2595,13 @@ async def embeddings(
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aembedding(**data, specific_deployment=True)
else:
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aembedding(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -2604,7 +2619,12 @@ async def embeddings(
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
raise ProxyException(
message=getattr(e, "message", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
@ -2720,8 +2740,13 @@ async def image_generation(
response = await llm_router.aimage_generation(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
else:
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aimage_generation(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -2739,7 +2764,12 @@ async def image_generation(
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
raise ProxyException(
message=getattr(e, "message", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"