Router: support the pass_through_all_models setting

This commit is contained in:
Ishaan Jaff 2024-07-25 18:34:12 -07:00
parent e67daf79be
commit 8f4c5437b8
2 changed files with 25 additions and 4 deletions

View file

@ -2885,6 +2885,11 @@ async def chat_completion(
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.acompletion(**data))
elif (
llm_router is not None
and llm_router.router_general_settings.pass_through_all_models is True
):
tasks.append(litellm.acompletion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
tasks.append(litellm.acompletion(**data))
else:
@ -3147,6 +3152,11 @@ async def completion(
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
llm_response = asyncio.create_task(litellm.atext_completion(**data))
elif (
llm_router is not None
and llm_router.router_general_settings.pass_through_all_models is True
):
llm_response = asyncio.create_task(litellm.atext_completion(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -3405,6 +3415,11 @@ async def embeddings(
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.aembedding(**data))
elif (
llm_router is not None
and llm_router.router_general_settings.pass_through_all_models is True
):
tasks.append(litellm.aembedding(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
tasks.append(litellm.aembedding(**data))
else:

View file

@ -174,7 +174,9 @@ class Router:
routing_strategy_args: dict = {}, # just for latency-based routing
semaphore: Optional[asyncio.Semaphore] = None,
alerting_config: Optional[AlertingConfig] = None,
router_general_settings: Optional[RouterGeneralSettings] = None,
router_general_settings: Optional[
RouterGeneralSettings
] = RouterGeneralSettings(),
) -> None:
"""
Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@ -253,8 +255,8 @@ class Router:
verbose_router_logger.setLevel(logging.INFO)
elif debug_level == "DEBUG":
verbose_router_logger.setLevel(logging.DEBUG)
self.router_general_settings: Optional[RouterGeneralSettings] = (
router_general_settings
self.router_general_settings: RouterGeneralSettings = (
router_general_settings or RouterGeneralSettings()
)
self.assistants_config = assistants_config
@ -3554,7 +3556,11 @@ class Router:
# Check if user is trying to use model_name == "*"
# this is a catch all model for their specific api key
if deployment.model_name == "*":
self.default_deployment = deployment.to_json(exclude_none=True)
if deployment.litellm_params.model == "*":
# user wants requests for unknown deployments passed straight through to litellm (acompletion / atext_completion / aembedding)
self.router_general_settings.pass_through_all_models = True
else:
self.default_deployment = deployment.to_json(exclude_none=True)
# Azure GPT-Vision Enhancements, users can pass os.environ/
data_sources = deployment.litellm_params.get("dataSources", []) or []