forked from phoenix/litellm-mirror
(feat) use regex pattern matching for wildcard routing (#6150)
* use pattern matching for llm deployments * code quality fix * fix linting * add types to PatternMatchRouter * docs add example config for regex patterns
This commit is contained in:
parent
6005450c8f
commit
89506053a4
6 changed files with 169 additions and 40 deletions
|
@ -124,6 +124,8 @@ from litellm.utils import (
|
|||
get_utc_datetime,
|
||||
)
|
||||
|
||||
from .router_utils.pattern_match_deployments import PatternMatchRouter
|
||||
|
||||
|
||||
class RoutingArgs(enum.Enum):
|
||||
ttl = 60 # 1min (RPM/TPM expire key)
|
||||
|
@ -344,8 +346,8 @@ class Router:
|
|||
self.default_priority = default_priority
|
||||
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
|
||||
self.default_max_parallel_requests = default_max_parallel_requests
|
||||
self.provider_default_deployments: Dict[str, List] = {}
|
||||
self.provider_default_deployment_ids: List[str] = []
|
||||
self.pattern_router = PatternMatchRouter()
|
||||
|
||||
if model_list is not None:
|
||||
model_list = copy.deepcopy(model_list)
|
||||
|
@ -4147,10 +4149,6 @@ class Router:
|
|||
),
|
||||
)
|
||||
|
||||
provider_specific_deployment = re.match(
|
||||
rf"{custom_llm_provider}/\*$", deployment.model_name
|
||||
)
|
||||
|
||||
# Check if user is trying to use model_name == "*"
|
||||
# this is a catch all model for their specific api key
|
||||
if deployment.model_name == "*":
|
||||
|
@ -4161,16 +4159,12 @@ class Router:
|
|||
self.default_deployment = deployment.to_json(exclude_none=True)
|
||||
# Check if user is using provider specific wildcard routing
|
||||
# example model_name = "databricks/*" or model_name = "anthropic/*"
|
||||
elif provider_specific_deployment:
|
||||
if custom_llm_provider in self.provider_default_deployments:
|
||||
self.provider_default_deployments[custom_llm_provider].append(
|
||||
deployment.to_json(exclude_none=True)
|
||||
)
|
||||
else:
|
||||
self.provider_default_deployments[custom_llm_provider] = [
|
||||
deployment.to_json(exclude_none=True)
|
||||
]
|
||||
|
||||
elif "*" in deployment.model_name:
|
||||
# store this as a regex pattern - all deployments matching this pattern will be sent to this deployment
|
||||
# Store deployment.model_name as a regex pattern
|
||||
self.pattern_router.add_pattern(
|
||||
deployment.model_name, deployment.to_json(exclude_none=True)
|
||||
)
|
||||
if deployment.model_info.id:
|
||||
self.provider_default_deployment_ids.append(deployment.model_info.id)
|
||||
|
||||
|
@ -4433,7 +4427,7 @@ class Router:
|
|||
is_match = True
|
||||
elif (
|
||||
"model_name" in model
|
||||
and model_group in self.provider_default_deployments
|
||||
and self.pattern_router.route(model_group) is not None
|
||||
): # wildcard model
|
||||
is_match = True
|
||||
|
||||
|
@ -5174,28 +5168,16 @@ class Router:
|
|||
model = _item["model"]
|
||||
|
||||
if model not in self.model_names:
|
||||
# check if provider/ specific wildcard routing
|
||||
try:
|
||||
(
|
||||
_,
|
||||
custom_llm_provider,
|
||||
_,
|
||||
_,
|
||||
) = litellm.get_llm_provider(model=model)
|
||||
# check if custom_llm_provider
|
||||
if custom_llm_provider in self.provider_default_deployments:
|
||||
_provider_deployments = self.provider_default_deployments[
|
||||
custom_llm_provider
|
||||
]
|
||||
provider_deployments = []
|
||||
for deployment in _provider_deployments:
|
||||
dep = copy.deepcopy(deployment)
|
||||
dep["litellm_params"]["model"] = model
|
||||
provider_deployments.append(dep)
|
||||
return model, provider_deployments
|
||||
except Exception:
|
||||
# get_llm_provider raises exception when provider is unknown
|
||||
pass
|
||||
# check if provider/ specific wildcard routing use pattern matching
|
||||
_pattern_router_response = self.pattern_router.route(model)
|
||||
|
||||
if _pattern_router_response is not None:
|
||||
provider_deployments = []
|
||||
for deployment in _pattern_router_response:
|
||||
dep = copy.deepcopy(deployment)
|
||||
dep["litellm_params"]["model"] = model
|
||||
provider_deployments.append(dep)
|
||||
return model, provider_deployments
|
||||
|
||||
# check if default deployment is set
|
||||
if self.default_deployment is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue