Simplify logic for routing LLM requests

This commit is contained in:
Ishaan Jaff 2024-08-15 08:29:28 -07:00
parent 4cbde8af39
commit 4a30781919

View file

@@ -46,12 +46,12 @@ async def route_request(
router_model_names = llm_router.model_names if llm_router is not None else [] router_model_names = llm_router.model_names if llm_router is not None else []
if "api_key" in data: if "api_key" in data:
return await getattr(litellm, f"{route_type}")(**data) return getattr(litellm, f"{route_type}")(**data)
elif "user_config" in data: elif "user_config" in data:
router_config = data.pop("user_config") router_config = data.pop("user_config")
user_router = litellm.Router(**router_config) user_router = litellm.Router(**router_config)
return await getattr(user_router, f"{route_type}")(**data) return getattr(user_router, f"{route_type}")(**data)
elif ( elif (
"," in data.get("model", "") "," in data.get("model", "")
@@ -59,40 +59,40 @@ async def route_request(
and route_type == "acompletion" and route_type == "acompletion"
): ):
if data.get("fastest_response", False): if data.get("fastest_response", False):
return await llm_router.abatch_completion_fastest_response(**data) return llm_router.abatch_completion_fastest_response(**data)
else: else:
models = [model.strip() for model in data.pop("model").split(",")] models = [model.strip() for model in data.pop("model").split(",")]
return await llm_router.abatch_completion(models=models, **data) return llm_router.abatch_completion(models=models, **data)
elif llm_router is not None: elif llm_router is not None:
if ( if (
data["model"] in router_model_names data["model"] in router_model_names
or data["model"] in llm_router.get_model_ids() or data["model"] in llm_router.get_model_ids()
): ):
return await getattr(llm_router, f"{route_type}")(**data) return getattr(llm_router, f"{route_type}")(**data)
elif ( elif (
llm_router.model_group_alias is not None llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias and data["model"] in llm_router.model_group_alias
): ):
return await getattr(llm_router, f"{route_type}")(**data) return getattr(llm_router, f"{route_type}")(**data)
elif data["model"] in llm_router.deployment_names: elif data["model"] in llm_router.deployment_names:
return await getattr(llm_router, f"{route_type}")( return getattr(llm_router, f"{route_type}")(
**data, specific_deployment=True **data, specific_deployment=True
) )
elif data["model"] not in router_model_names: elif data["model"] not in router_model_names:
if llm_router.router_general_settings.pass_through_all_models: if llm_router.router_general_settings.pass_through_all_models:
return await getattr(litellm, f"{route_type}")(**data) return getattr(litellm, f"{route_type}")(**data)
elif ( elif (
llm_router.default_deployment is not None llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0 or len(llm_router.provider_default_deployments) > 0
): ):
return await getattr(llm_router, f"{route_type}")(**data) return getattr(llm_router, f"{route_type}")(**data)
elif user_model is not None: elif user_model is not None:
return await getattr(litellm, f"{route_type}")(**data) return getattr(litellm, f"{route_type}")(**data)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,