diff --git a/litellm/router.py b/litellm/router.py
index 5a4d83885..51fb12ea8 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -17,6 +17,7 @@ import inspect
 import json
 import logging
 import random
+import re
 import threading
 import time
 import traceback
@@ -310,6 +311,7 @@ class Router:
         )
         self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
         self.default_max_parallel_requests = default_max_parallel_requests
+        self.provider_default_deployments: Dict[str, List] = {}
 
         if model_list is not None:
             model_list = copy.deepcopy(model_list)
@@ -3607,6 +3609,10 @@ class Router:
             ),
         )
 
+        provider_specific_deployment = re.match(
+            f"{custom_llm_provider}/*", deployment.model_name
+        )
+
         # Check if user is trying to use model_name == "*"
         # this is a catch all model for their specific api key
         if deployment.model_name == "*":
@@ -3615,6 +3621,17 @@ class Router:
                 self.router_general_settings.pass_through_all_models = True
             else:
                 self.default_deployment = deployment.to_json(exclude_none=True)
+        # Check if user is using provider specific wildcard routing
+        # example model_name = "databricks/*" or model_name = "anthropic/*"
+        elif provider_specific_deployment:
+            if custom_llm_provider in self.provider_default_deployments:
+                self.provider_default_deployments[custom_llm_provider].append(
+                    deployment.to_json(exclude_none=True)
+                )
+            else:
+                self.provider_default_deployments[custom_llm_provider] = [
+                    deployment.to_json(exclude_none=True)
+                ]
 
         # Azure GPT-Vision Enhancements, users can pass os.environ/
         data_sources = deployment.litellm_params.get("dataSources", []) or []
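
For reviewers, a minimal usage sketch of the provider-specific wildcard routing this patch enables. The model name and API-key handling below are illustrative assumptions, not part of the diff:

```python
import os

from litellm import Router

# A single wildcard entry covers every model under a provider prefix.
router = Router(
    model_list=[
        {
            "model_name": "anthropic/*",  # provider-specific wildcard (new in this patch)
            "litellm_params": {
                "model": "anthropic/*",
                # hypothetical key handling for the sketch
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
            },
        }
    ]
)

# Any anthropic/<model> request can now resolve via provider_default_deployments:
# response = router.completion(
#     model="anthropic/claude-3-haiku-20240307",
#     messages=[{"role": "user", "content": "hello"}],
# )
```

One observation on the matching: `re.match(f"{custom_llm_provider}/*", deployment.model_name)` interprets `/*` as a regex (zero or more slashes), so this branch is satisfied by any `model_name` that merely begins with the provider prefix, not only a literal `provider/*`; because the exact `"*"` case is handled first, this appears benign for typical configs, but a stricter pattern such as `rf"{custom_llm_provider}/\*$"` may be worth considering.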