(feat) use regex pattern matching for wildcard routing (#6150)

* use pattern matching for llm deployments

* code quality fix

* fix linting

* add types to PatternMatchRouter

* docs add example config for regex patterns
This commit is contained in:
Ishaan Jaff 2024-10-10 18:24:16 +05:30 committed by GitHub
parent 6005450c8f
commit 89506053a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 169 additions and 40 deletions

View file

@ -124,6 +124,8 @@ from litellm.utils import (
get_utc_datetime,
)
from .router_utils.pattern_match_deployments import PatternMatchRouter
class RoutingArgs(enum.Enum):
ttl = 60 # 1min (RPM/TPM expire key)
@ -344,8 +346,8 @@ class Router:
self.default_priority = default_priority
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
self.default_max_parallel_requests = default_max_parallel_requests
self.provider_default_deployments: Dict[str, List] = {}
self.provider_default_deployment_ids: List[str] = []
self.pattern_router = PatternMatchRouter()
if model_list is not None:
model_list = copy.deepcopy(model_list)
@ -4147,10 +4149,6 @@ class Router:
),
)
provider_specific_deployment = re.match(
rf"{custom_llm_provider}/\*$", deployment.model_name
)
# Check if user is trying to use model_name == "*"
# this is a catch all model for their specific api key
if deployment.model_name == "*":
@ -4161,16 +4159,12 @@ class Router:
self.default_deployment = deployment.to_json(exclude_none=True)
# Check if user is using provider specific wildcard routing
# example model_name = "databricks/*" or model_name = "anthropic/*"
elif provider_specific_deployment:
if custom_llm_provider in self.provider_default_deployments:
self.provider_default_deployments[custom_llm_provider].append(
deployment.to_json(exclude_none=True)
)
else:
self.provider_default_deployments[custom_llm_provider] = [
deployment.to_json(exclude_none=True)
]
elif "*" in deployment.model_name:
# store this as a regex pattern - all deployments matching this pattern will be sent to this deployment
# Store deployment.model_name as a regex pattern
self.pattern_router.add_pattern(
deployment.model_name, deployment.to_json(exclude_none=True)
)
if deployment.model_info.id:
self.provider_default_deployment_ids.append(deployment.model_info.id)
@ -4433,7 +4427,7 @@ class Router:
is_match = True
elif (
"model_name" in model
and model_group in self.provider_default_deployments
and self.pattern_router.route(model_group) is not None
): # wildcard model
is_match = True
@ -5174,28 +5168,16 @@ class Router:
model = _item["model"]
if model not in self.model_names:
# check if provider/ specific wildcard routing
try:
(
_,
custom_llm_provider,
_,
_,
) = litellm.get_llm_provider(model=model)
# check if custom_llm_provider
if custom_llm_provider in self.provider_default_deployments:
_provider_deployments = self.provider_default_deployments[
custom_llm_provider
]
provider_deployments = []
for deployment in _provider_deployments:
dep = copy.deepcopy(deployment)
dep["litellm_params"]["model"] = model
provider_deployments.append(dep)
return model, provider_deployments
except Exception:
# get_llm_provider raises exception when provider is unknown
pass
# check if provider/ specific wildcard routing use pattern matching
_pattern_router_response = self.pattern_router.route(model)
if _pattern_router_response is not None:
provider_deployments = []
for deployment in _pattern_router_response:
dep = copy.deepcopy(deployment)
dep["litellm_params"]["model"] = model
provider_deployments.append(dep)
return model, provider_deployments
# check if default deployment is set
if self.default_deployment is not None: