forked from phoenix/litellm-mirror
(feat) use regex pattern matching for wildcard routing (#6150)

* use pattern matching for llm deployments
* code quality fix
* fix linting
* add types to PatternMatchRouter
* docs add example config for regex patterns

parent 6005450c8f
commit 89506053a4

6 changed files with 169 additions and 40 deletions
````diff
@@ -375,6 +375,10 @@ model_list:
     litellm_params:
       model: "groq/*"
       api_key: os.environ/GROQ_API_KEY
+  - model_name: "fo::*:static::*" # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
+    litellm_params:
+      model: "openai/fo::*:static::*"
+      api_key: os.environ/OPENAI_API_KEY
 ```
 
 Step 2 - Run litellm proxy
````
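The new `fo::*:static::*` entry relies on the wildcard-to-regex conversion this commit introduces. A minimal sketch of that conversion, mirroring `PatternMatchRouter._pattern_to_regex` from the new file further down (the free-standing helper name here is illustrative):

```python
import re


def pattern_to_regex(pattern: str) -> str:
    # Mirrors PatternMatchRouter._pattern_to_regex: '*' becomes '.*',
    # everything else is escaped literally, and the result is anchored.
    regex = pattern.replace("*", ".*")
    regex = re.escape(regex).replace(r"\.\*", ".*")
    return f"^{regex}$"


# "fo::hi::static::hi" matches the wildcard "fo::*:static::*", exactly as
# the config comment above describes.
assert re.match(pattern_to_regex("fo::*:static::*"), "fo::hi::static::hi")
assert re.match(pattern_to_regex("fo::*:static::*"), "fo::x:other::y") is None
```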
````diff
@@ -411,6 +415,19 @@ curl http://localhost:4000/v1/chat/completions \
   }'
 ```
 
+Test with `fo::*::static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
+
+```shell
+curl http://localhost:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "model": "fo::hi::static::hi",
+    "messages": [
+      {"role": "user", "content": "Hello, Claude!"}
+    ]
+  }'
+```
+
 ### Load Balancing
 
 :::info
````
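The same request can be made through the OpenAI Python SDK pointed at the proxy; a sketch assuming the proxy from Step 2 is listening on localhost:4000 with the `sk-1234` key used in the curl above:

```python
# Equivalent to the curl example above, via the OpenAI Python SDK pointed
# at the LiteLLM proxy (localhost:4000 / sk-1234 are the docs' values).
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="fo::hi::static::hi",  # matches the "fo::*:static::*" wildcard deployment
    messages=[{"role": "user", "content": "Hello, Claude!"}],
)
print(response.choices[0].message.content)
```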
```diff
@@ -5553,7 +5553,7 @@ async def anthropic_response(
         and data["model"] not in router_model_names
         and (
             llm_router.default_deployment is not None
-            or len(llm_router.provider_default_deployments) > 0
+            or len(llm_router.pattern_router.patterns) > 0
         )
     ):  # model in router deployments, calling a specific deployment on the router
         llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
```
```diff
@@ -109,7 +109,7 @@ async def route_request(
         return getattr(litellm, f"{route_type}")(**data)
     elif (
         llm_router.default_deployment is not None
-        or len(llm_router.provider_default_deployments) > 0
+        or len(llm_router.pattern_router.patterns) > 0
     ):
         return getattr(llm_router, f"{route_type}")(**data)
 
```
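Both hunks above make the same one-line change: the proxy's "should this go to the router?" gate now checks the pattern router's registered patterns instead of the removed `provider_default_deployments` dict. A condensed sketch of the gate, not the literal proxy code:

```python
# Condensed sketch of the updated gate shared by both call sites above:
# a request falls through to the router when either a catch-all "*"
# deployment exists or at least one wildcard pattern has been registered.
def should_route_through_router(llm_router) -> bool:
    return (
        llm_router.default_deployment is not None
        or len(llm_router.pattern_router.patterns) > 0
    )
```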
```diff
@@ -124,6 +124,8 @@ from litellm.utils import (
     get_utc_datetime,
 )
 
+from .router_utils.pattern_match_deployments import PatternMatchRouter
+
 
 class RoutingArgs(enum.Enum):
     ttl = 60  # 1min (RPM/TPM expire key)
@@ -344,8 +346,8 @@ class Router:
         self.default_priority = default_priority
         self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
         self.default_max_parallel_requests = default_max_parallel_requests
-        self.provider_default_deployments: Dict[str, List] = {}
         self.provider_default_deployment_ids: List[str] = []
+        self.pattern_router = PatternMatchRouter()
 
         if model_list is not None:
             model_list = copy.deepcopy(model_list)
```
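With `self.pattern_router` in place, a wildcard entry in `model_list` is registered as a pattern during `Router.__init__`. A sketch with a placeholder API key (no network calls happen at construction time):

```python
# Sketch: a wildcard deployment in model_list now lands in
# router.pattern_router instead of the removed provider_default_deployments.
import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "groq/*",
            "litellm_params": {"model": "groq/*", "api_key": "placeholder"},
        }
    ]
)
# Expected to show the anchored regex form of the wildcard:
# {'^groq/.*$': [<deployment dict>]}
print(router.pattern_router.patterns)
```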
```diff
@@ -4147,10 +4149,6 @@ class Router:
             ),
         )
 
-        provider_specific_deployment = re.match(
-            rf"{custom_llm_provider}/\*$", deployment.model_name
-        )
-
         # Check if user is trying to use model_name == "*"
         # this is a catch all model for their specific api key
         if deployment.model_name == "*":
@@ -4161,16 +4159,12 @@ class Router:
             self.default_deployment = deployment.to_json(exclude_none=True)
         # Check if user is using provider specific wildcard routing
         # example model_name = "databricks/*" or model_name = "anthropic/*"
-        elif provider_specific_deployment:
-            if custom_llm_provider in self.provider_default_deployments:
-                self.provider_default_deployments[custom_llm_provider].append(
-                    deployment.to_json(exclude_none=True)
-                )
-            else:
-                self.provider_default_deployments[custom_llm_provider] = [
-                    deployment.to_json(exclude_none=True)
-                ]
+        elif "*" in deployment.model_name:
+            # store this as a regex pattern - all deployments matching this pattern will be sent to this deployment
+            # Store deployment.model_name as a regex pattern
+            self.pattern_router.add_pattern(
+                deployment.model_name, deployment.to_json(exclude_none=True)
+            )
 
             if deployment.model_info.id:
                 self.provider_default_deployment_ids.append(deployment.model_info.id)
```
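Note the broadened condition: registration now treats any `model_name` containing `*` as a routable pattern, not just the `provider/*` shape the old `re.match` allowed. What `add_pattern` stores for such a deployment, using an illustrative deployment dict rather than the full Deployment schema:

```python
from litellm.router_utils.pattern_match_deployments import PatternMatchRouter

# What the new elif branch does for a wildcard deployment: the model_name
# wildcard is converted to an anchored regex and keyed to the deployment.
pattern_router = PatternMatchRouter()
pattern_router.add_pattern(
    "openai/fo::*:static::*",
    {"model_name": "openai/fo::*:static::*"},
)
print(list(pattern_router.patterns))  # ['^openai/fo::.*:static::.*$']
```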
```diff
@@ -4433,7 +4427,7 @@ class Router:
             is_match = True
         elif (
             "model_name" in model
-            and model_group in self.provider_default_deployments
+            and self.pattern_router.route(model_group) is not None
         ):  # wildcard model
             is_match = True
 
```
```diff
@@ -5174,28 +5168,16 @@ class Router:
             model = _item["model"]
 
         if model not in self.model_names:
-            # check if provider/ specific wildcard routing
-            try:
-                (
-                    _,
-                    custom_llm_provider,
-                    _,
-                    _,
-                ) = litellm.get_llm_provider(model=model)
-                # check if custom_llm_provider
-                if custom_llm_provider in self.provider_default_deployments:
-                    _provider_deployments = self.provider_default_deployments[
-                        custom_llm_provider
-                    ]
-                    provider_deployments = []
-                    for deployment in _provider_deployments:
-                        dep = copy.deepcopy(deployment)
-                        dep["litellm_params"]["model"] = model
-                        provider_deployments.append(dep)
-                    return model, provider_deployments
-            except Exception:
-                # get_llm_provider raises exception when provider is unknown
-                pass
+            # check if provider/ specific wildcard routing use pattern matching
+            _pattern_router_response = self.pattern_router.route(model)
+            if _pattern_router_response is not None:
+                provider_deployments = []
+                for deployment in _pattern_router_response:
+                    dep = copy.deepcopy(deployment)
+                    dep["litellm_params"]["model"] = model
+                    provider_deployments.append(dep)
+                return model, provider_deployments
 
         # check if default deployment is set
         if self.default_deployment is not None:
```
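The lookup path is simpler too: instead of resolving a provider via `get_llm_provider` and consulting a per-provider dict, it asks the pattern router directly, then deep-copies each matched deployment and rewrites `litellm_params["model"]` to the concrete requested name. A condensed sketch of that flow:

```python
import copy


# Sketch of the new lookup: matched wildcard deployments are deep-copied and
# their litellm_params.model is overridden with the concrete requested name,
# so the downstream call targets e.g. "openai/fo::hi::static::hi" rather
# than the literal wildcard string.
def resolve(pattern_router, model: str):
    matched = pattern_router.route(model)
    if matched is None:
        return None
    deployments = []
    for deployment in matched:
        dep = copy.deepcopy(deployment)
        dep["litellm_params"]["model"] = model
        deployments.append(dep)
    return model, deployments
```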
litellm/router_utils/pattern_match_deployments.py (new file, 87 lines)

```python
"""
Class to handle llm wildcard routing and regex pattern matching
"""

import re
from typing import Dict, List, Optional


class PatternMatchRouter:
    """
    Class to handle llm wildcard routing and regex pattern matching

    doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing

    This class will store a mapping for regex pattern: List[Deployments]
    """

    def __init__(self):
        self.patterns: Dict[str, List] = {}

    def add_pattern(self, pattern: str, llm_deployment: Dict):
        """
        Add a regex pattern and the corresponding llm deployments to the patterns

        Args:
            pattern: str
            llm_deployment: str or List[str]
        """
        # Convert the pattern to a regex
        regex = self._pattern_to_regex(pattern)
        if regex not in self.patterns:
            self.patterns[regex] = []
        if isinstance(llm_deployment, list):
            self.patterns[regex].extend(llm_deployment)
        else:
            self.patterns[regex].append(llm_deployment)

    def _pattern_to_regex(self, pattern: str) -> str:
        """
        Convert a wildcard pattern to a regex pattern

        example:
        pattern: openai/*
        regex: openai/.*

        pattern: openai/fo::*::static::*
        regex: openai/fo::.*::static::.*

        Args:
            pattern: str

        Returns:
            str: regex pattern
        """
        # Replace '*' with '.*' for regex matching
        regex = pattern.replace("*", ".*")
        # Escape other special characters
        regex = re.escape(regex).replace(r"\.\*", ".*")
        return f"^{regex}$"

    def route(self, request: str) -> Optional[List[Dict]]:
        """
        Route a requested model to the corresponding llm deployments based on the regex pattern

        loop through all the patterns and find the matching pattern
        if a pattern is found, return the corresponding llm deployments
        if no pattern is found, return None

        Args:
            request: str

        Returns:
            Optional[List[Deployment]]: llm deployments
        """
        for pattern, llm_deployments in self.patterns.items():
            if re.match(pattern, request):
                return llm_deployments
        return None  # No matching pattern found


# Example usage:
# router = PatternRouter()
# router.add_pattern('openai/*', [Deployment(), Deployment()])
# router.add_pattern('openai/fo::*::static::*', Deployment())
# print(router.route('openai/gpt-4'))  # Output: [Deployment(), Deployment()]
# print(router.route('openai/fo::hi::static::hi'))  # Output: [Deployment()]
# print(router.route('something/else'))  # Output: None
```
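A runnable version of the commented example at the bottom of the file, with two adjustments: the class is `PatternMatchRouter` (the comments say `PatternRouter`), and because `route()` returns the first match in insertion order, the more specific pattern is registered first so it is not shadowed by `openai/*`. Plain dicts stand in for `Deployment` objects:

```python
from litellm.router_utils.pattern_match_deployments import PatternMatchRouter

# Specific pattern first: route() walks patterns in insertion order and
# returns the first match, so "openai/*" would otherwise win every lookup.
router = PatternMatchRouter()
router.add_pattern("openai/fo::*::static::*", {"id": "dep-3"})
router.add_pattern("openai/*", [{"id": "dep-1"}, {"id": "dep-2"}])

print(router.route("openai/fo::hi::static::hi"))  # [{'id': 'dep-3'}]
print(router.route("openai/gpt-4"))               # [{'id': 'dep-1'}, {'id': 'dep-2'}]
print(router.route("something/else"))             # None
```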
```diff
@@ -73,6 +73,7 @@ async def test_router_provider_wildcard_routing():
     Pass list of orgs in 1 model definition,
     expect a unique deployment for each to be created
     """
+    litellm.set_verbose = True
     router = litellm.Router(
         model_list=[
             {
@@ -124,6 +125,48 @@ async def test_router_provider_wildcard_routing():
     print("response 3 = ", response3)
 
 
+@pytest.mark.asyncio()
+async def test_router_provider_wildcard_routing_regex():
+    """
+    Pass list of orgs in 1 model definition,
+    expect a unique deployment for each to be created
+    """
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "openai/fo::*:static::*",
+                "litellm_params": {
+                    "model": "openai/fo::*:static::*",
+                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                },
+            },
+            {
+                "model_name": "openai/foo3::hello::*",
+                "litellm_params": {
+                    "model": "openai/foo3::hello::*",
+                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                },
+            },
+        ]
+    )
+
+    print("router model list = ", router.get_model_list())
+
+    response1 = await router.acompletion(
+        model="openai/fo::anything-can-be-here::static::anything-can-be-here",
+        messages=[{"role": "user", "content": "hello"}],
+    )
+
+    print("response 1 = ", response1)
+
+    response2 = await router.acompletion(
+        model="openai/foo3::hello::static::anything-can-be-here",
+        messages=[{"role": "user", "content": "hello"}],
+    )
+
+    print("response 2 = ", response2)
+
+
 def test_router_specific_model_via_id():
     """
     Call a specific deployment by it's id
```
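The new regex test exercises end-to-end completion calls against a mock endpoint; the pattern-matching behavior it depends on can also be checked in isolation, as a sketch:

```python
from litellm.router_utils.pattern_match_deployments import PatternMatchRouter

# Isolated check of the wildcard patterns the new test registers; no
# network calls, just the pattern -> deployment lookup.
pm = PatternMatchRouter()
pm.add_pattern("openai/fo::*:static::*", {"model_name": "openai/fo::*:static::*"})
pm.add_pattern("openai/foo3::hello::*", {"model_name": "openai/foo3::hello::*"})

assert pm.route("openai/fo::anything-can-be-here::static::anything-can-be-here") is not None
assert pm.route("openai/foo3::hello::static::anything-can-be-here") is not None
assert pm.route("openai/unrelated-model") is None
```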