diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md
index e1f929c5a..745da30d0 100644
--- a/docs/my-website/docs/proxy/configs.md
+++ b/docs/my-website/docs/proxy/configs.md
@@ -375,6 +375,10 @@ model_list:
     litellm_params:
       model: "groq/*"
       api_key: os.environ/GROQ_API_KEY
+  - model_name: "fo::*:static::*" # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
+    litellm_params:
+      model: "openai/fo::*:static::*"
+      api_key: os.environ/OPENAI_API_KEY
 ```

 Step 2 - Run litellm proxy
@@ -411,6 +415,19 @@ curl http://localhost:4000/v1/chat/completions \
   }'
 ```

+Test with `fo::*:static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
+```shell
+curl http://localhost:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "model": "fo::hi::static::hi",
+    "messages": [
+      {"role": "user", "content": "Hello, Claude!"}
+    ]
+  }'
+```
+
 ### Load Balancing

 :::info
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 1383e6794..d75a087da 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -5553,7 +5553,7 @@ async def anthropic_response(
             and data["model"] not in router_model_names
             and (
                 llm_router.default_deployment is not None
-                or len(llm_router.provider_default_deployments) > 0
+                or len(llm_router.pattern_router.patterns) > 0
             )
         ):  # model in router deployments, calling a specific deployment on the router
             llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
diff --git a/litellm/proxy/route_llm_request.py b/litellm/proxy/route_llm_request.py
index 63e41f64f..fcf95f6ab 100644
--- a/litellm/proxy/route_llm_request.py
+++ b/litellm/proxy/route_llm_request.py
@@ -109,7 +109,7 @@ async def route_request(
         return getattr(litellm, f"{route_type}")(**data)
     elif (
         llm_router.default_deployment is not None
-        or len(llm_router.provider_default_deployments) > 0
+        or len(llm_router.pattern_router.patterns) > 0
     ):
         return getattr(llm_router, f"{route_type}")(**data)

diff --git a/litellm/router.py b/litellm/router.py
index d73f5d4b3..50db754b6 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -124,6 +124,8 @@ from litellm.utils import (
     get_utc_datetime,
 )

+from .router_utils.pattern_match_deployments import PatternMatchRouter
+

 class RoutingArgs(enum.Enum):
     ttl = 60  # 1min (RPM/TPM expire key)
@@ -344,8 +346,8 @@ class Router:
         self.default_priority = default_priority
         self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
         self.default_max_parallel_requests = default_max_parallel_requests
-        self.provider_default_deployments: Dict[str, List] = {}
         self.provider_default_deployment_ids: List[str] = []
+        self.pattern_router = PatternMatchRouter()

         if model_list is not None:
             model_list = copy.deepcopy(model_list)
@@ -4147,10 +4149,6 @@ class Router:
             ),
         )

-        provider_specific_deployment = re.match(
-            rf"{custom_llm_provider}/\*$", deployment.model_name
-        )
-
         # Check if user is trying to use model_name == "*"
         # this is a catch all model for their specific api key
         if deployment.model_name == "*":
             if deployment.litellm_params.model == "*":
                 # user wants to pass through all requests to litellm.acompletion for unknown deployments
                 self.router_general_settings.pass_through_all_models = True
             else:
                 self.default_deployment = deployment.to_json(exclude_none=True)
         # Check if user is using provider specific wildcard routing
         # example model_name = "databricks/*" or model_name = "anthropic/*"
-        elif provider_specific_deployment:
-            if custom_llm_provider in self.provider_default_deployments:
-                self.provider_default_deployments[custom_llm_provider].append(
-                    deployment.to_json(exclude_none=True)
-                )
-            else:
-                self.provider_default_deployments[custom_llm_provider] = [
-                    deployment.to_json(exclude_none=True)
-                ]
-
+        elif "*" in deployment.model_name:
+            # store deployment.model_name as a regex pattern - all requests matching this pattern will be routed to this deployment
+            self.pattern_router.add_pattern(
+                deployment.model_name, deployment.to_json(exclude_none=True)
+            )
         if deployment.model_info.id:
             self.provider_default_deployment_ids.append(deployment.model_info.id)
@@ -4433,7 +4427,7 @@ class Router:
                 is_match = True
             elif (
                 "model_name" in model
-                and model_group in self.provider_default_deployments
+                and self.pattern_router.route(model_group) is not None
             ):  # wildcard model
                 is_match = True

@@ -5174,28 +5168,16 @@ class Router:
             model = _item["model"]

         if model not in self.model_names:
-            # check if provider/ specific wildcard routing
-            try:
-                (
-                    _,
-                    custom_llm_provider,
-                    _,
-                    _,
-                ) = litellm.get_llm_provider(model=model)
-                # check if custom_llm_provider
-                if custom_llm_provider in self.provider_default_deployments:
-                    _provider_deployments = self.provider_default_deployments[
-                        custom_llm_provider
-                    ]
-                    provider_deployments = []
-                    for deployment in _provider_deployments:
-                        dep = copy.deepcopy(deployment)
-                        dep["litellm_params"]["model"] = model
-                        provider_deployments.append(dep)
-                    return model, provider_deployments
-            except Exception:
-                # get_llm_provider raises exception when provider is unknown
-                pass
+            # check if provider/ specific wildcard routing - use pattern matching
+            _pattern_router_response = self.pattern_router.route(model)
+
+            if _pattern_router_response is not None:
+                provider_deployments = []
+                for deployment in _pattern_router_response:
+                    dep = copy.deepcopy(deployment)
+                    dep["litellm_params"]["model"] = model
+                    provider_deployments.append(dep)
+                return model, provider_deployments

         # check if default deployment is set
         if self.default_deployment is not None:
diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py
new file mode 100644
index 000000000..c9ea8c1a7
--- /dev/null
+++ b/litellm/router_utils/pattern_match_deployments.py
@@ -0,0 +1,87 @@
+"""
+Class to handle llm wildcard routing and regex pattern matching
+"""
+
+import re
+from typing import Dict, List, Optional
+
+
+class PatternMatchRouter:
+    """
+    Class to handle llm wildcard routing and regex pattern matching
+
+    doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing
+
+    This class will store a mapping for regex pattern: List[Deployments]
+    """
+
+    def __init__(self):
+        self.patterns: Dict[str, List] = {}
+
+    def add_pattern(self, pattern: str, llm_deployment: Dict):
+        """
+        Add a regex pattern and the corresponding llm deployments to the patterns
+
+        Args:
+            pattern: str
+            llm_deployment: Dict or List[Dict]
+        """
+        # Convert the pattern to a regex
+        regex = self._pattern_to_regex(pattern)
+        if regex not in self.patterns:
+            self.patterns[regex] = []
+        if isinstance(llm_deployment, list):
+            self.patterns[regex].extend(llm_deployment)
+        else:
+            self.patterns[regex].append(llm_deployment)
+
+    def _pattern_to_regex(self, pattern: str) -> str:
+        """
+        Convert a wildcard pattern to a regex pattern
+
+        example:
+            pattern: openai/*
+            regex: openai/.*
+
+            pattern: openai/fo::*::static::*
+            regex: openai/fo::.*::static::.*
+
+        Args:
+            pattern: str
+
+        Returns:
+            str: regex pattern
+        """
+        # Replace '*' with '.*' for regex matching
+        regex = pattern.replace("*", ".*")
+        # Escape other special characters
+        regex = re.escape(regex).replace(r"\.\*", ".*")
+        return f"^{regex}$"
+
+    def route(self, request: str) -> Optional[List[Dict]]:
+        """
+        Route a requested model to the corresponding llm deployments based on the regex pattern
+
+        loop through all the patterns and find the matching pattern
+        if a pattern is found, return the corresponding llm deployments
+        if no pattern is found, return None
+
+        Args:
+            request: str
+
+        Returns:
+            Optional[List[Dict]]: llm deployments
+        """
+        for pattern, llm_deployments in self.patterns.items():
+            if re.match(pattern, request):
+                return llm_deployments
+        return None  # No matching pattern found
+
+
+# Example usage:
+# router = PatternMatchRouter()
+# router.add_pattern('openai/fo::*::static::*', Deployment())
+# router.add_pattern('openai/*', [Deployment(), Deployment()])
+# print(router.route('openai/fo::hi::static::hi'))  # Output: [Deployment()]
+# print(router.route('openai/gpt-4'))  # Output: [Deployment(), Deployment()]
+# print(router.route('something/else'))  # Output: None
diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py
index 760c14d68..57ef196ff 100644
--- a/tests/local_testing/test_router.py
+++ b/tests/local_testing/test_router.py
@@ -73,6 +73,7 @@ async def test_router_provider_wildcard_routing():
     Pass list of orgs in 1 model definition,
     expect a unique deployment for each to be created
     """
+    litellm.set_verbose = True
     router = litellm.Router(
         model_list=[
             {
@@ -124,6 +125,48 @@ async def test_router_provider_wildcard_routing():
     print("response 3 = ", response3)


+@pytest.mark.asyncio()
+async def test_router_provider_wildcard_routing_regex():
+    """
+    Pass wildcard model names with regex-style patterns (e.g. "openai/fo::*:static::*"),
+    expect requests matching each pattern to be routed to that deployment
+    """
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "openai/fo::*:static::*",
+                "litellm_params": {
+                    "model": "openai/fo::*:static::*",
+                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                },
+            },
+            {
+                "model_name": "openai/foo3::hello::*",
+                "litellm_params": {
+                    "model": "openai/foo3::hello::*",
+                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                },
+            },
+        ]
+    )
+
+    print("router model list = ", router.get_model_list())
+
+    response1 = await router.acompletion(
+        model="openai/fo::anything-can-be-here::static::anything-can-be-here",
+        messages=[{"role": "user", "content": "hello"}],
+    )
+
+    print("response 1 = ", response1)
+
+    response2 = await router.acompletion(
+        model="openai/foo3::hello::static::anything-can-be-here",
+        messages=[{"role": "user", "content": "hello"}],
+    )
+
+    print("response 2 = ", response2)
+
+
 def test_router_specific_model_via_id():
     """
     Call a specific deployment by it's id
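To make the routing behavior above easier to follow in isolation, here is a small self-contained sketch of the wildcard-to-regex conversion and first-match lookup that `PatternMatchRouter` performs. The pattern choices mirror the config example above, but the deployment values are placeholder strings rather than real deployment dicts:

```python
import re
from typing import Dict, List, Optional


def pattern_to_regex(pattern: str) -> str:
    # Same idea as PatternMatchRouter._pattern_to_regex:
    # turn each '*' into '.*', escape everything else, and anchor the result
    regex = pattern.replace("*", ".*")
    regex = re.escape(regex).replace(r"\.\*", ".*")
    return f"^{regex}$"


# regex pattern -> deployments (placeholder strings stand in for deployment dicts)
patterns: Dict[str, List[str]] = {
    pattern_to_regex("groq/*"): ["groq-wildcard-deployment"],
    pattern_to_regex("fo::*:static::*"): ["openai-fo-static-deployment"],
}


def route(request: str) -> Optional[List[str]]:
    # First registered pattern that matches the requested model wins
    for regex, deployments in patterns.items():
        if re.match(regex, request):
            return deployments
    return None


print(route("groq/llama3-8b-8192"))  # ['groq-wildcard-deployment']
print(route("fo::hi::static::hi"))   # ['openai-fo-static-deployment']
print(route("gpt-4o"))               # None - falls through to default/explicit deployments
```

Because `route()` returns the deployments of the first registered pattern that matches, overlapping patterns (for example `openai/*` alongside `openai/fo::*:static::*`) resolve in registration order, which suggests listing narrower patterns before broad catch-alls in `model_list`.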