forked from phoenix/litellm-mirror
(feat) use regex pattern matching for wildcard routing (#6150)
* use pattern matching for llm deployments * code quality fix * fix linting * add types to PatternMatchRouter * docs add example config for regex patterns
This commit is contained in:
parent
6005450c8f
commit
89506053a4
6 changed files with 169 additions and 40 deletions
|
@ -375,6 +375,10 @@ model_list:
|
|||
litellm_params:
|
||||
model: "groq/*"
|
||||
api_key: os.environ/GROQ_API_KEY
|
||||
- model_name: "fo::*:static::*" # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
|
||||
litellm_params:
|
||||
model: "openai/fo::*:static::*"
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
|
||||
Step 2 - Run litellm proxy
|
||||
|
@ -411,6 +415,19 @@ curl http://localhost:4000/v1/chat/completions \
|
|||
}'
|
||||
```
|
||||
|
||||
Test with `fo::*:static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
|
||||
```shell
|
||||
curl http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "fo::hi::static::hi",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### Load Balancing
|
||||
|
||||
:::info
|
||||
|
|
|
@ -5553,7 +5553,7 @@ async def anthropic_response(
|
|||
and data["model"] not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.provider_default_deployments) > 0
|
||||
or len(llm_router.pattern_router.patterns) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
|
||||
|
|
|
@ -109,7 +109,7 @@ async def route_request(
|
|||
return getattr(litellm, f"{route_type}")(**data)
|
||||
elif (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.provider_default_deployments) > 0
|
||||
or len(llm_router.pattern_router.patterns) > 0
|
||||
):
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
|
||||
|
|
|
@ -124,6 +124,8 @@ from litellm.utils import (
|
|||
get_utc_datetime,
|
||||
)
|
||||
|
||||
from .router_utils.pattern_match_deployments import PatternMatchRouter
|
||||
|
||||
|
||||
class RoutingArgs(enum.Enum):
|
||||
ttl = 60 # 1min (RPM/TPM expire key)
|
||||
|
@ -344,8 +346,8 @@ class Router:
|
|||
self.default_priority = default_priority
|
||||
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
|
||||
self.default_max_parallel_requests = default_max_parallel_requests
|
||||
self.provider_default_deployments: Dict[str, List] = {}
|
||||
self.provider_default_deployment_ids: List[str] = []
|
||||
self.pattern_router = PatternMatchRouter()
|
||||
|
||||
if model_list is not None:
|
||||
model_list = copy.deepcopy(model_list)
|
||||
|
@ -4147,10 +4149,6 @@ class Router:
|
|||
),
|
||||
)
|
||||
|
||||
provider_specific_deployment = re.match(
|
||||
rf"{custom_llm_provider}/\*$", deployment.model_name
|
||||
)
|
||||
|
||||
# Check if user is trying to use model_name == "*"
|
||||
# this is a catch all model for their specific api key
|
||||
if deployment.model_name == "*":
|
||||
|
@ -4161,16 +4159,12 @@ class Router:
|
|||
self.default_deployment = deployment.to_json(exclude_none=True)
|
||||
# Check if user is using provider specific wildcard routing
|
||||
# example model_name = "databricks/*" or model_name = "anthropic/*"
|
||||
elif provider_specific_deployment:
|
||||
if custom_llm_provider in self.provider_default_deployments:
|
||||
self.provider_default_deployments[custom_llm_provider].append(
|
||||
deployment.to_json(exclude_none=True)
|
||||
elif "*" in deployment.model_name:
|
||||
# store this as a regex pattern - all deployments matching this pattern will be sent to this deployment
|
||||
# Store deployment.model_name as a regex pattern
|
||||
self.pattern_router.add_pattern(
|
||||
deployment.model_name, deployment.to_json(exclude_none=True)
|
||||
)
|
||||
else:
|
||||
self.provider_default_deployments[custom_llm_provider] = [
|
||||
deployment.to_json(exclude_none=True)
|
||||
]
|
||||
|
||||
if deployment.model_info.id:
|
||||
self.provider_default_deployment_ids.append(deployment.model_info.id)
|
||||
|
||||
|
@ -4433,7 +4427,7 @@ class Router:
|
|||
is_match = True
|
||||
elif (
|
||||
"model_name" in model
|
||||
and model_group in self.provider_default_deployments
|
||||
and self.pattern_router.route(model_group) is not None
|
||||
): # wildcard model
|
||||
is_match = True
|
||||
|
||||
|
@ -5174,28 +5168,16 @@ class Router:
|
|||
model = _item["model"]
|
||||
|
||||
if model not in self.model_names:
|
||||
# check if provider/ specific wildcard routing
|
||||
try:
|
||||
(
|
||||
_,
|
||||
custom_llm_provider,
|
||||
_,
|
||||
_,
|
||||
) = litellm.get_llm_provider(model=model)
|
||||
# check if custom_llm_provider
|
||||
if custom_llm_provider in self.provider_default_deployments:
|
||||
_provider_deployments = self.provider_default_deployments[
|
||||
custom_llm_provider
|
||||
]
|
||||
# check if provider/ specific wildcard routing use pattern matching
|
||||
_pattern_router_response = self.pattern_router.route(model)
|
||||
|
||||
if _pattern_router_response is not None:
|
||||
provider_deployments = []
|
||||
for deployment in _provider_deployments:
|
||||
for deployment in _pattern_router_response:
|
||||
dep = copy.deepcopy(deployment)
|
||||
dep["litellm_params"]["model"] = model
|
||||
provider_deployments.append(dep)
|
||||
return model, provider_deployments
|
||||
except Exception:
|
||||
# get_llm_provider raises exception when provider is unknown
|
||||
pass
|
||||
|
||||
# check if default deployment is set
|
||||
if self.default_deployment is not None:
|
||||
|
|
87
litellm/router_utils/pattern_match_deployments.py
Normal file
87
litellm/router_utils/pattern_match_deployments.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
"""
Class to handle llm wildcard routing and regex pattern matching
"""

import re
from typing import Dict, List, Optional, Union


class PatternMatchRouter:
    """
    Handles LLM wildcard routing via regex pattern matching.

    doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing

    Stores a mapping of regex pattern -> List[deployment dict]. Wildcard
    model names such as "openai/*" or "openai/fo::*::static::*" are
    converted to anchored regexes and matched against requested models.
    """

    def __init__(self) -> None:
        # regex pattern string -> list of deployment dicts registered under it
        self.patterns: Dict[str, List] = {}

    def add_pattern(
        self, pattern: str, llm_deployment: Union[Dict, List[Dict]]
    ) -> None:
        """
        Register llm deployment(s) under a wildcard pattern.

        Args:
            pattern: wildcard model name, e.g. "openai/*"
            llm_deployment: a deployment dict, or a list of deployment dicts
        """
        # Identical patterns share one entry, so their deployments accumulate.
        regex = self._pattern_to_regex(pattern)
        if regex not in self.patterns:
            self.patterns[regex] = []
        if isinstance(llm_deployment, list):
            self.patterns[regex].extend(llm_deployment)
        else:
            self.patterns[regex].append(llm_deployment)

    def _pattern_to_regex(self, pattern: str) -> str:
        """
        Convert a wildcard pattern to an anchored regex pattern.

        example:
            pattern: openai/*
            regex: ^openai/.*$

            pattern: openai/fo::*::static::*
            regex: ^openai/fo::.*::static::.*$

        Args:
            pattern: wildcard model name

        Returns:
            str: regex pattern
        """
        # Escape all regex metacharacters first, then turn each escaped
        # wildcard back into ".*". Escaping first means a literal "." (or
        # any other special char) in the model name cannot be misread as
        # regex syntax.
        regex = re.escape(pattern).replace(r"\*", ".*")
        # Anchor so that "openai/*" does not match e.g. "xopenai/gpt-4"
        return f"^{regex}$"

    def route(self, request: str) -> Optional[List[Dict]]:
        """
        Route a requested model to the matching llm deployments.

        Loops through registered patterns in insertion order and returns
        the deployments of the first pattern that matches; returns None
        when no pattern matches.

        Args:
            request: requested model name

        Returns:
            Optional[List[Dict]]: llm deployments, or None
        """
        for pattern, llm_deployments in self.patterns.items():
            if re.match(pattern, request):
                return llm_deployments
        return None  # No matching pattern found


# Example usage:
# router = PatternMatchRouter()
# router.add_pattern('openai/*', [deployment_a, deployment_b])
# router.add_pattern('openai/fo::*::static::*', deployment_c)
# print(router.route('openai/gpt-4'))  # Output: [deployment_a, deployment_b]
# print(router.route('openai/fo::hi::static::hi'))  # Output: [deployment_c]
# print(router.route('something/else'))  # Output: None
|
|
@ -73,6 +73,7 @@ async def test_router_provider_wildcard_routing():
|
|||
Pass list of orgs in 1 model definition,
|
||||
expect a unique deployment for each to be created
|
||||
"""
|
||||
litellm.set_verbose = True
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
|
@ -124,6 +125,48 @@ async def test_router_provider_wildcard_routing():
|
|||
print("response 3 = ", response3)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
async def test_router_provider_wildcard_routing_regex():
    """
    Configure deployments whose model_name contains mid-string wildcards
    (e.g. "openai/fo::*:static::*"), then call models that match those
    patterns and expect the router's pattern matching to route them.

    NOTE(review): this is an integration-style test — it sends real
    requests to the example endpoint configured in api_base.
    """
    router = litellm.Router(
        model_list=[
            {
                # wildcard segments in the middle of the model name
                "model_name": "openai/fo::*:static::*",
                "litellm_params": {
                    "model": "openai/fo::*:static::*",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                },
            },
            {
                # trailing wildcard only
                "model_name": "openai/foo3::hello::*",
                "litellm_params": {
                    "model": "openai/foo3::hello::*",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                },
            },
        ]
    )

    print("router model list = ", router.get_model_list())

    # matches "openai/fo::*:static::*": both wildcards filled with
    # arbitrary text
    response1 = await router.acompletion(
        model="openai/fo::anything-can-be-here::static::anything-can-be-here",
        messages=[{"role": "user", "content": "hello"}],
    )

    print("response 1 = ", response1)

    # matches "openai/foo3::hello::*" via its trailing wildcard
    response2 = await router.acompletion(
        model="openai/foo3::hello::static::anything-can-be-here",
        messages=[{"role": "user", "content": "hello"}],
    )

    print("response 2 = ", response2)
|
||||
|
||||
|
||||
def test_router_specific_model_via_id():
|
||||
"""
|
||||
Call a specific deployment by it's id
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue