From 44840d615d07c8b3f514d00d7cdf8b816b3fa0f1 Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Thu, 7 Nov 2024 23:57:37 +0530
Subject: [PATCH] fix(pattern_match_deployments.py): default to user input if
 unable to map based on wildcards (#6646)

---
 .../router_utils/pattern_match_deployments.py | 22 ++++++++++++++++++----
 .../test_router_pattern_matching.py           | 10 +++++++++-
 .../test_router_helper_utils.py               | 15 ++++++++++++++-
 3 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py
index 039af635c..3896c3a95 100644
--- a/litellm/router_utils/pattern_match_deployments.py
+++ b/litellm/router_utils/pattern_match_deployments.py
@@ -117,11 +117,26 @@ class PatternMatchRouter:
 
         E.g.:
 
+        Case 1:
         model_name: llmengine/* (can be any regex pattern or wildcard pattern)
         litellm_params:
             model: openai/*
 
         if model_name = "llmengine/foo" -> model = "openai/foo"
+
+        Case 2:
+        model_name: llmengine/fo::*::static::*
+        litellm_params:
+            model: openai/fo::*::static::*
+
+        if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"
+
+        Case 3:
+        model_name: *meta.llama3*
+        litellm_params:
+            model: bedrock/meta.llama3*
+
+        if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b"
         """
 
         ## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name
@@ -134,10 +149,9 @@ class PatternMatchRouter:
         dynamic_segments = matched_pattern.groups()
 
         if len(dynamic_segments) > wildcard_count:
-            raise ValueError(
-                f"More wildcards in the deployment model name than the pattern. Wildcard count: {wildcard_count}, dynamic segments count: {len(dynamic_segments)}"
-            )
-
+            return (
+                matched_pattern.string
+            )  # default to the user input, if unable to map based on wildcards.
         # Replace the corresponding wildcards in the litellm model pattern with extracted segments
         for segment in dynamic_segments:
             litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace(
diff --git a/tests/local_testing/test_router_pattern_matching.py b/tests/local_testing/test_router_pattern_matching.py
index 2a6f66105..914e8ecfa 100644
--- a/tests/local_testing/test_router_pattern_matching.py
+++ b/tests/local_testing/test_router_pattern_matching.py
@@ -182,6 +182,14 @@ async def test_route_with_no_matching_pattern():
     )
     assert result.choices[0].message.content == "Works"
 
+    ## WORKS
+    result = await router.acompletion(
+        model="meta.llama3-70b-instruct-v1:0",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response="Works",
+    )
+    assert result.choices[0].message.content == "Works"
+
     ## FAILS
     with pytest.raises(litellm.BadRequestError) as e:
         await router.acompletion(
@@ -198,6 +206,7 @@ async def test_route_with_no_matching_pattern():
             input="Hello, world!",
         )
 
+
 def test_router_pattern_match_e2e():
     """
     Tests the end to end flow of the router
@@ -228,4 +237,3 @@ def test_router_pattern_match_e2e():
         "model": "gpt-4o",
         "messages": [{"role": "user", "content": "Hello, how are you?"}],
     }
-
diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py
index cabb4a899..7e2daa9b5 100644
--- a/tests/router_unit_tests/test_router_helper_utils.py
+++ b/tests/router_unit_tests/test_router_helper_utils.py
@@ -960,6 +960,18 @@ def test_replace_model_in_jsonl(model_list):
             "openai/gpt-3.5-turbo",
             "openai/gpt-3.5-turbo",
         ),
+        (
+            "bedrock/meta.llama3-70b",
+            "*meta.llama3*",
+            "bedrock/meta.llama3-*",
+            "bedrock/meta.llama3-70b",
+        ),
+        (
+            "meta.llama3-70b",
+            "*meta.llama3*",
+            "bedrock/meta.llama3-*",
+            "meta.llama3-70b",
+        ),
     ],
 )
 def test_pattern_match_deployment_set_model_name(
@@ -1000,9 +1012,10 @@ def test_pattern_match_deployment_set_model_name(
     for model in updated_models:
         assert model["litellm_params"]["model"] == expected_model
 
+
 @pytest.mark.asyncio
 async def test_pass_through_moderation_endpoint_factory(model_list):
     router = Router(model_list=model_list)
     response = await router._pass_through_moderation_endpoint_factory(
         original_function=litellm.amoderation, input="this is valid good text"
-    )
\ No newline at end of file
+    )
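
Review note (below the diff): a minimal standalone sketch of the patched
fallback in PatternMatchRouter.set_deployment_model_name. It assumes the
router compiles each "*" in a wildcard model_name into a "(.*)" capture
group; the helper below is an illustrative re-implementation for this note,
not the litellm import.

import re


def set_deployment_model_name(matched_pattern: re.Match, litellm_model: str) -> str:
    # Base case: deployment model has no wildcard -> return it unchanged.
    if "*" not in litellm_model:
        return litellm_model

    wildcard_count = litellm_model.count("*")
    dynamic_segments = matched_pattern.groups()

    # The fix: when the request pattern captured more segments than the
    # deployment model has wildcards (Case 3 in the docstring), fall back
    # to the raw user input instead of raising ValueError.
    if len(dynamic_segments) > wildcard_count:
        return matched_pattern.string

    # Otherwise substitute each captured segment into the wildcards, in order.
    for segment in dynamic_segments:
        litellm_model = litellm_model.replace("*", segment, 1)
    return litellm_model


# "*meta.llama3*" has two wildcards -> a regex with two capture groups.
pattern = re.compile(r"(.*)meta\.llama3(.*)")
match = pattern.match("meta.llama3-70b")
assert match is not None

# Two captured segments ("", "-70b") vs. one "*" in "bedrock/meta.llama3-*":
# this used to raise ValueError; it now returns the user input unchanged.
assert set_deployment_model_name(match, "bedrock/meta.llama3-*") == "meta.llama3-70b"

This mirrors the second new parametrized case in test_router_helper_utils.py:
"meta.llama3-70b" maps back to itself because the single wildcard in
"bedrock/meta.llama3-*" cannot absorb both captured segments.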