From 80ecf0829c8211271072f25a05cc4b8ad14cedae Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 12 Oct 2024 16:01:21 +0530 Subject: [PATCH] (fix) provider wildcard routing - when models specified without provider prefix (#6173) * fix wildcard routing scenario * fix pattern matching hits --- litellm/router.py | 22 +++++++++++++++++++++- tests/local_testing/test_router.py | 5 +++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index 50db754b6..537c14ddc 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -5169,8 +5169,28 @@ class Router: if model not in self.model_names: # check if provider/ specific wildcard routing use pattern matching - _pattern_router_response = self.pattern_router.route(model) + custom_llm_provider: Optional[str] = None + try: + ( + _, + custom_llm_provider, + _, + _, + ) = litellm.get_llm_provider(model=model) + except Exception: + # get_llm_provider raises exception when provider is unknown + pass + """ + self.pattern_router.route(model): + does exact pattern matching. 
Example openai/gpt-3.5-turbo gets routed to pattern openai/* + + self.pattern_router.route(f"{custom_llm_provider}/{model}"): + does pattern matching using litellm.get_llm_provider(), example claude-3-5-sonnet-20240620 gets routed to anthropic/* since 'claude-3-5-sonnet-20240620' is an Anthropic Model + """ + _pattern_router_response = self.pattern_router.route( + model + ) or self.pattern_router.route(f"{custom_llm_provider}/{model}") if _pattern_router_response is not None: provider_deployments = [] for deployment in _pattern_router_response: diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index 57ef196ff..42148d9ab 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -124,6 +124,11 @@ async def test_router_provider_wildcard_routing(): print("response 3 = ", response3) + response4 = await router.acompletion( + model="claude-3-5-sonnet-20240620", + messages=[{"role": "user", "content": "hello"}], + ) + @pytest.mark.asyncio() async def test_router_provider_wildcard_routing_regex():