use route_request for making llm call

Ishaan Jaff 2024-08-15 08:16:44 -07:00
parent eb6a0a32f1
commit fdd6664420
3 changed files with 117 additions and 62 deletions


@@ -19,6 +19,9 @@ model_list:
     litellm_params:
       model: mistral/mistral-small-latest
       api_key: "os.environ/MISTRAL_API_KEY"
+  - model_name: bedrock-anthropic
+    litellm_params:
+      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
   - model_name: gemini-1.5-pro-001
     litellm_params:
       model: vertex_ai_beta/gemini-1.5-pro-001
@@ -40,3 +43,6 @@ general_settings:
 litellm_settings:
   fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
   callbacks: ["gcs_bucket"]
+  success_callback: ["langfuse"]
+  langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
+  cache: True
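
For reference, a minimal sketch of calling the newly added bedrock-anthropic model group through the proxy's OpenAI-compatible endpoint. The base URL and key below are placeholders for a locally running proxy, not values from this commit:

import openai

# Point the OpenAI SDK at the LiteLLM proxy (address and key are assumed placeholders).
client = openai.OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="bedrock-anthropic",  # the model_name added to model_list above
    messages=[{"role": "user", "content": "Hello from Claude 3 Sonnet on Bedrock"}],
)
print(response.choices[0].message.content)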


@@ -187,6 +187,7 @@ from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_confi
 from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
     initialize_pass_through_endpoints,
 )
+from litellm.proxy.route_llm_request import route_request
 from litellm.proxy.secret_managers.aws_secret_manager import (
     load_aws_kms,
     load_aws_secret_manager,
@@ -3006,68 +3007,13 @@ async def chat_completion(
         ### ROUTE THE REQUEST ###
         # Do not change this - it should be a constant time fetch - ALWAYS
-        router_model_names = llm_router.model_names if llm_router is not None else []
-        # skip router if user passed their key
-        if "api_key" in data:
-            tasks.append(litellm.acompletion(**data))
-        elif "," in data["model"] and llm_router is not None:
-            if (
-                data.get("fastest_response", None) is not None
-                and data["fastest_response"] == True
-            ):
-                tasks.append(llm_router.abatch_completion_fastest_response(**data))
-            else:
-                _models_csv_string = data.pop("model")
-                _models = [model.strip() for model in _models_csv_string.split(",")]
-                tasks.append(llm_router.abatch_completion(models=_models, **data))
-        elif "user_config" in data:
-            # initialize a new router instance. make request using this Router
-            router_config = data.pop("user_config")
-            user_router = litellm.Router(**router_config)
-            tasks.append(user_router.acompletion(**data))
-        elif (
-            llm_router is not None and data["model"] in router_model_names
-        ):  # model in router model list
-            tasks.append(llm_router.acompletion(**data))
-        elif (
-            llm_router is not None and data["model"] in llm_router.get_model_ids()
-        ):  # model in router model list
-            tasks.append(llm_router.acompletion(**data))
-        elif (
-            llm_router is not None
-            and llm_router.model_group_alias is not None
-            and data["model"] in llm_router.model_group_alias
-        ):  # model set in model_group_alias
-            tasks.append(llm_router.acompletion(**data))
-        elif (
-            llm_router is not None and data["model"] in llm_router.deployment_names
-        ):  # model in router deployments, calling a specific deployment on the router
-            tasks.append(llm_router.acompletion(**data, specific_deployment=True))
-        elif (
-            llm_router is not None
-            and data["model"] not in router_model_names
-            and llm_router.router_general_settings.pass_through_all_models is True
-        ):
-            tasks.append(litellm.acompletion(**data))
-        elif (
-            llm_router is not None
-            and data["model"] not in router_model_names
-            and (
-                llm_router.default_deployment is not None
-                or len(llm_router.provider_default_deployments) > 0
-            )
-        ):  # model in router deployments, calling a specific deployment on the router
-            tasks.append(llm_router.acompletion(**data))
-        elif user_model is not None:  # `litellm --model <your-model-name>`
-            tasks.append(litellm.acompletion(**data))
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail={
-                    "error": "chat_completion: Invalid model name passed in model="
-                    + data.get("model", "")
-                },
-            )
+        llm_call = await route_request(
+            data=data,
+            route_type="acompletion",
+            llm_router=llm_router,
+            user_model=user_model,
+        )
+        tasks.append(llm_call)
         # wait for call to end
         llm_responses = asyncio.gather(
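
Since route_type accepts values like "aembedding" and "atranscription", other endpoints can hand off to the same helper instead of duplicating the routing chain. A hedged sketch of what that reuse might look like in another endpoint (payload, variables, and model name are illustrative, not part of this commit):

# Illustrative only: the same helper driven with a different route_type.
# `llm_router`, `user_model`, and `tasks` are assumed to exist in the endpoint scope.
data = {"model": "my-embedding-model", "input": ["hello world"]}
llm_call = await route_request(
    data=data,
    route_type="aembedding",
    llm_router=llm_router,
    user_model=user_model,
)
tasks.append(llm_call)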


@@ -0,0 +1,103 @@
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    Form,
+    Header,
+    HTTPException,
+    Path,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
+
+import litellm
+from litellm._logging import verbose_logger
+
+if TYPE_CHECKING:
+    from litellm.router import Router as _Router
+
+    LitellmRouter = _Router
+else:
+    LitellmRouter = Any
+
+
+async def route_request(
+    data: dict,
+    llm_router: Optional[LitellmRouter],
+    user_model: Optional[str],
+    route_type: Literal[
+        "acompletion",
+        "atext_completion",
+        "aembedding",
+        "aimage_generation",
+        "aspeech",
+        "atranscription",
+        "amoderation",
+    ],
+):
+    """
+    Common helper to route the request
+    """
+    router_model_names = llm_router.model_names if llm_router is not None else []
+
+    if "api_key" in data:
+        return await getattr(litellm, f"{route_type}")(**data)
+
+    elif "user_config" in data:
+        router_config = data.pop("user_config")
+        user_router = litellm.Router(**router_config)
+        return await getattr(user_router, f"{route_type}")(**data)
+
+    elif (
+        "," in data.get("model", "")
+        and llm_router is not None
+        and route_type == "acompletion"
+    ):
+        if data.get("fastest_response", False):
+            return await llm_router.abatch_completion_fastest_response(**data)
+        else:
+            models = [model.strip() for model in data.pop("model").split(",")]
+            return await llm_router.abatch_completion(models=models, **data)
+
+    elif llm_router is not None:
+        if (
+            data["model"] in router_model_names
+            or data["model"] in llm_router.get_model_ids()
+        ):
+            return await getattr(llm_router, f"{route_type}")(**data)
+
+        elif (
+            llm_router.model_group_alias is not None
+            and data["model"] in llm_router.model_group_alias
+        ):
+            return await getattr(llm_router, f"{route_type}")(**data)
+
+        elif data["model"] in llm_router.deployment_names:
+            return await getattr(llm_router, f"{route_type}")(
+                **data, specific_deployment=True
+            )
+
+        elif data["model"] not in router_model_names:
+            if llm_router.router_general_settings.pass_through_all_models:
+                return await getattr(litellm, f"{route_type}")(**data)
+            elif (
+                llm_router.default_deployment is not None
+                or len(llm_router.provider_default_deployments) > 0
+            ):
+                return await getattr(llm_router, f"{route_type}")(**data)
+
+    elif user_model is not None:
+        return await getattr(litellm, f"{route_type}")(**data)
+
+    raise HTTPException(
+        status_code=status.HTTP_400_BAD_REQUEST,
+        detail={
+            "error": f"{route_type}: Invalid model name passed in model="
+            + data.get("model", "")
+        },
+    )
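
A small self-contained sketch of exercising the new helper outside the proxy: a payload that carries api_key takes the first branch and goes straight to litellm, no router required. The model name and key below are placeholders, not part of this commit:

import asyncio

from litellm.proxy.route_llm_request import route_request

async def main():
    # "api_key" in the payload triggers the direct litellm branch (llm_router may be None).
    response = await route_request(
        data={
            "model": "gpt-4o",  # placeholder model name
            "api_key": "sk-...",  # placeholder key
            "messages": [{"role": "user", "content": "ping"}],
        },
        llm_router=None,
        user_model=None,
        route_type="acompletion",
    )
    print(response)

asyncio.run(main())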