mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
refactor use 1 util for llm routing
This commit is contained in:
parent
d50f26d73d
commit
58828403ea
1 changed files with 42 additions and 256 deletions
|
@ -3236,58 +3236,15 @@ async def completion(
|
|||
)
|
||||
|
||||
### ROUTE THE REQUESTs ###
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
llm_response = asyncio.create_task(litellm.atext_completion(**data))
|
||||
elif (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_response = asyncio.create_task(
|
||||
llm_router.atext_completion(**data, specific_deployment=True)
|
||||
)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
||||
): # model in router model list
|
||||
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and llm_router.router_general_settings.pass_through_all_models is True
|
||||
):
|
||||
llm_response = asyncio.create_task(litellm.atext_completion(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.provider_default_deployments) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
llm_response = asyncio.create_task(litellm.atext_completion(**data))
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "completion: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
llm_call = await route_request(
|
||||
data=data,
|
||||
route_type="atext_completion",
|
||||
llm_router=llm_router,
|
||||
user_model=user_model,
|
||||
)
|
||||
|
||||
# Await the llm_response task
|
||||
response = await llm_response
|
||||
response = await llm_call
|
||||
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
|
@ -3501,59 +3458,13 @@ async def embeddings(
|
|||
)
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
tasks.append(litellm.aembedding(**data))
|
||||
elif "user_config" in data:
|
||||
# initialize a new router instance. make request using this Router
|
||||
router_config = data.pop("user_config")
|
||||
user_router = litellm.Router(**router_config)
|
||||
tasks.append(user_router.aembedding(**data))
|
||||
elif (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
tasks.append(llm_router.aembedding(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
tasks.append(
|
||||
llm_router.aembedding(**data)
|
||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
tasks.append(llm_router.aembedding(**data, specific_deployment=True))
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
tasks.append(llm_router.aembedding(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and llm_router.router_general_settings.pass_through_all_models is True
|
||||
):
|
||||
tasks.append(litellm.aembedding(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.provider_default_deployments) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
tasks.append(llm_router.aembedding(**data))
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
tasks.append(litellm.aembedding(**data))
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "embeddings: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
llm_call = await route_request(
|
||||
data=data,
|
||||
route_type="aembedding",
|
||||
llm_router=llm_router,
|
||||
user_model=user_model,
|
||||
)
|
||||
tasks.append(llm_call)
|
||||
|
||||
# wait for call to end
|
||||
llm_responses = asyncio.gather(
|
||||
|
@ -3684,46 +3595,13 @@ async def image_generation(
|
|||
)
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
response = await litellm.aimage_generation(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
response = await llm_router.aimage_generation(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.aimage_generation(
|
||||
**data, specific_deployment=True
|
||||
)
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
response = await llm_router.aimage_generation(
|
||||
**data
|
||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.provider_default_deployments) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.aimage_generation(**data)
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
response = await litellm.aimage_generation(**data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "image_generation: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
llm_call = await route_request(
|
||||
data=data,
|
||||
route_type="aimage_generation",
|
||||
llm_router=llm_router,
|
||||
user_model=user_model,
|
||||
)
|
||||
response = await llm_call
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
|
@ -3831,44 +3709,13 @@ async def audio_speech(
|
|||
)
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
response = await litellm.aspeech(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
response = await llm_router.aspeech(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.aspeech(**data, specific_deployment=True)
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
response = await llm_router.aspeech(
|
||||
**data
|
||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.provider_default_deployments) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.aspeech(**data)
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
response = await litellm.aspeech(**data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "audio_speech: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
llm_call = await route_request(
|
||||
data=data,
|
||||
route_type="aspeech",
|
||||
llm_router=llm_router,
|
||||
user_model=user_model,
|
||||
)
|
||||
response = await llm_call
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
|
@ -4001,47 +3848,13 @@ async def audio_transcriptions(
|
|||
)
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
response = await litellm.atranscription(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
response = await llm_router.atranscription(**data)
|
||||
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.atranscription(
|
||||
**data, specific_deployment=True
|
||||
)
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
response = await llm_router.atranscription(
|
||||
**data
|
||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.provider_default_deployments) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.atranscription(**data)
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
response = await litellm.atranscription(**data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "audio_transcriptions: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
llm_call = await route_request(
|
||||
data=data,
|
||||
route_type="atranscription",
|
||||
llm_router=llm_router,
|
||||
user_model=user_model,
|
||||
)
|
||||
response = await llm_call
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
|
@ -5257,40 +5070,13 @@ async def moderations(
|
|||
start_time = time.time()
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
response = await litellm.amoderation(**data)
|
||||
elif (
|
||||
llm_router is not None and data.get("model") in router_model_names
|
||||
): # model in router model list
|
||||
response = await llm_router.amoderation(**data)
|
||||
elif (
|
||||
llm_router is not None and data.get("model") in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.amoderation(**data, specific_deployment=True)
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data.get("model") in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
response = await llm_router.amoderation(
|
||||
**data
|
||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data.get("model") not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.provider_default_deployments) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.amoderation(**data)
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
response = await litellm.amoderation(**data)
|
||||
else:
|
||||
# /moderations does not need a "model" passed
|
||||
# see https://platform.openai.com/docs/api-reference/moderations
|
||||
response = await litellm.amoderation(**data)
|
||||
llm_call = await route_request(
|
||||
data=data,
|
||||
route_type="amoderation",
|
||||
llm_router=llm_router,
|
||||
user_model=user_model,
|
||||
)
|
||||
response = await llm_call
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue