(feat) use regex pattern matching for wildcard routing (#6150)

* use pattern matching for llm deployments

* code quality fix

* fix linting

* add types to PatternMatchRouter

* docs add example config for regex patterns
This commit is contained in:
Ishaan Jaff 2024-10-10 18:24:16 +05:30 committed by GitHub
parent 6005450c8f
commit 89506053a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 169 additions and 40 deletions

View file

@ -375,6 +375,10 @@ model_list:
litellm_params: litellm_params:
model: "groq/*" model: "groq/*"
api_key: os.environ/GROQ_API_KEY api_key: os.environ/GROQ_API_KEY
- model_name: "fo::*:static::*" # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
litellm_params:
model: "openai/fo::*:static::*"
api_key: os.environ/OPENAI_API_KEY
``` ```
Step 2 - Run litellm proxy Step 2 - Run litellm proxy
@ -411,6 +415,19 @@ curl http://localhost:4000/v1/chat/completions \
}' }'
``` ```
Test with `fo::*:static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
```shell
curl http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "fo::hi::static::hi",
"messages": [
{"role": "user", "content": "Hello, Claude!"}
]
}'
```
### Load Balancing ### Load Balancing
:::info :::info

View file

@ -5553,7 +5553,7 @@ async def anthropic_response(
and data["model"] not in router_model_names and data["model"] not in router_model_names
and ( and (
llm_router.default_deployment is not None llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0 or len(llm_router.pattern_router.patterns) > 0
) )
): # model in router deployments, calling a specific deployment on the router ): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))

View file

@ -109,7 +109,7 @@ async def route_request(
return getattr(litellm, f"{route_type}")(**data) return getattr(litellm, f"{route_type}")(**data)
elif ( elif (
llm_router.default_deployment is not None llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0 or len(llm_router.pattern_router.patterns) > 0
): ):
return getattr(llm_router, f"{route_type}")(**data) return getattr(llm_router, f"{route_type}")(**data)

View file

@ -124,6 +124,8 @@ from litellm.utils import (
get_utc_datetime, get_utc_datetime,
) )
from .router_utils.pattern_match_deployments import PatternMatchRouter
class RoutingArgs(enum.Enum): class RoutingArgs(enum.Enum):
ttl = 60 # 1min (RPM/TPM expire key) ttl = 60 # 1min (RPM/TPM expire key)
@ -344,8 +346,8 @@ class Router:
self.default_priority = default_priority self.default_priority = default_priority
self.default_deployment = None # use this to track the users default deployment, when they want to use model = * self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
self.default_max_parallel_requests = default_max_parallel_requests self.default_max_parallel_requests = default_max_parallel_requests
self.provider_default_deployments: Dict[str, List] = {}
self.provider_default_deployment_ids: List[str] = [] self.provider_default_deployment_ids: List[str] = []
self.pattern_router = PatternMatchRouter()
if model_list is not None: if model_list is not None:
model_list = copy.deepcopy(model_list) model_list = copy.deepcopy(model_list)
@ -4147,10 +4149,6 @@ class Router:
), ),
) )
provider_specific_deployment = re.match(
rf"{custom_llm_provider}/\*$", deployment.model_name
)
# Check if user is trying to use model_name == "*" # Check if user is trying to use model_name == "*"
# this is a catch all model for their specific api key # this is a catch all model for their specific api key
if deployment.model_name == "*": if deployment.model_name == "*":
@ -4161,16 +4159,12 @@ class Router:
self.default_deployment = deployment.to_json(exclude_none=True) self.default_deployment = deployment.to_json(exclude_none=True)
# Check if user is using provider specific wildcard routing # Check if user is using provider specific wildcard routing
# example model_name = "databricks/*" or model_name = "anthropic/*" # example model_name = "databricks/*" or model_name = "anthropic/*"
elif provider_specific_deployment: elif "*" in deployment.model_name:
if custom_llm_provider in self.provider_default_deployments: # store this as a regex pattern - all deployments matching this pattern will be sent to this deployment
self.provider_default_deployments[custom_llm_provider].append( # Store deployment.model_name as a regex pattern
deployment.to_json(exclude_none=True) self.pattern_router.add_pattern(
) deployment.model_name, deployment.to_json(exclude_none=True)
else: )
self.provider_default_deployments[custom_llm_provider] = [
deployment.to_json(exclude_none=True)
]
if deployment.model_info.id: if deployment.model_info.id:
self.provider_default_deployment_ids.append(deployment.model_info.id) self.provider_default_deployment_ids.append(deployment.model_info.id)
@ -4433,7 +4427,7 @@ class Router:
is_match = True is_match = True
elif ( elif (
"model_name" in model "model_name" in model
and model_group in self.provider_default_deployments and self.pattern_router.route(model_group) is not None
): # wildcard model ): # wildcard model
is_match = True is_match = True
@ -5174,28 +5168,16 @@ class Router:
model = _item["model"] model = _item["model"]
if model not in self.model_names: if model not in self.model_names:
# check if provider/ specific wildcard routing # check if provider/ specific wildcard routing use pattern matching
try: _pattern_router_response = self.pattern_router.route(model)
(
_, if _pattern_router_response is not None:
custom_llm_provider, provider_deployments = []
_, for deployment in _pattern_router_response:
_, dep = copy.deepcopy(deployment)
) = litellm.get_llm_provider(model=model) dep["litellm_params"]["model"] = model
# check if custom_llm_provider provider_deployments.append(dep)
if custom_llm_provider in self.provider_default_deployments: return model, provider_deployments
_provider_deployments = self.provider_default_deployments[
custom_llm_provider
]
provider_deployments = []
for deployment in _provider_deployments:
dep = copy.deepcopy(deployment)
dep["litellm_params"]["model"] = model
provider_deployments.append(dep)
return model, provider_deployments
except Exception:
# get_llm_provider raises exception when provider is unknown
pass
# check if default deployment is set # check if default deployment is set
if self.default_deployment is not None: if self.default_deployment is not None:

View file

@ -0,0 +1,87 @@
"""
Class to handle llm wildcard routing and regex pattern matching
"""
import re
from typing import Dict, List, Optional, Union
class PatternMatchRouter:
"""
Class to handle llm wildcard routing and regex pattern matching
doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing
This class will store a mapping for regex pattern: List[Deployments]
"""
def __init__(self):
self.patterns: Dict[str, List] = {}
def add_pattern(self, pattern: str, llm_deployment: Dict):
"""
Add a regex pattern and the corresponding llm deployments to the patterns
Args:
pattern: str
llm_deployment: str or List[str]
"""
# Convert the pattern to a regex
regex = self._pattern_to_regex(pattern)
if regex not in self.patterns:
self.patterns[regex] = []
if isinstance(llm_deployment, list):
self.patterns[regex].extend(llm_deployment)
else:
self.patterns[regex].append(llm_deployment)
def _pattern_to_regex(self, pattern: str) -> str:
"""
Convert a wildcard pattern to a regex pattern
example:
pattern: openai/*
regex: openai/.*
pattern: openai/fo::*::static::*
regex: openai/fo::.*::static::.*
Args:
pattern: str
Returns:
str: regex pattern
"""
# Replace '*' with '.*' for regex matching
regex = pattern.replace("*", ".*")
# Escape other special characters
regex = re.escape(regex).replace(r"\.\*", ".*")
return f"^{regex}$"
def route(self, request: str) -> Optional[List[Dict]]:
"""
Route a requested model to the corresponding llm deployments based on the regex pattern
loop through all the patterns and find the matching pattern
if a pattern is found, return the corresponding llm deployments
if no pattern is found, return None
Args:
request: str
Returns:
Optional[List[Deployment]]: llm deployments
"""
for pattern, llm_deployments in self.patterns.items():
if re.match(pattern, request):
return llm_deployments
return None # No matching pattern found
# Example usage:
# router = PatternMatchRouter()
# router.add_pattern('openai/*', [Deployment(), Deployment()])
# router.add_pattern('openai/fo::*::static::*', Deployment())
# print(router.route('openai/gpt-4'))  # Output: [Deployment(), Deployment()]
# print(router.route('openai/fo::hi::static::hi'))  # Output: [Deployment()]
# print(router.route('something/else'))  # Output: None

View file

@ -73,6 +73,7 @@ async def test_router_provider_wildcard_routing():
Pass list of orgs in 1 model definition, Pass list of orgs in 1 model definition,
expect a unique deployment for each to be created expect a unique deployment for each to be created
""" """
litellm.set_verbose = True
router = litellm.Router( router = litellm.Router(
model_list=[ model_list=[
{ {
@ -124,6 +125,48 @@ async def test_router_provider_wildcard_routing():
print("response 3 = ", response3) print("response 3 = ", response3)
@pytest.mark.asyncio()
async def test_router_provider_wildcard_routing_regex():
"""
Pass list of orgs in 1 model definition,
expect a unique deployment for each to be created
"""
router = litellm.Router(
model_list=[
{
"model_name": "openai/fo::*:static::*",
"litellm_params": {
"model": "openai/fo::*:static::*",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
},
},
{
"model_name": "openai/foo3::hello::*",
"litellm_params": {
"model": "openai/foo3::hello::*",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
},
},
]
)
print("router model list = ", router.get_model_list())
response1 = await router.acompletion(
model="openai/fo::anything-can-be-here::static::anything-can-be-here",
messages=[{"role": "user", "content": "hello"}],
)
print("response 1 = ", response1)
response2 = await router.acompletion(
model="openai/foo3::hello::static::anything-can-be-here",
messages=[{"role": "user", "content": "hello"}],
)
print("response 2 = ", response2)
def test_router_specific_model_via_id(): def test_router_specific_model_via_id():
""" """
Call a specific deployment by it's id Call a specific deployment by it's id