fix(pattern_match_deployments.py): default to user input if unable to map based on wildcards (#6646)

This commit is contained in:
Krish Dholakia 2024-11-07 23:57:37 +05:30 committed by GitHub
parent 9cb02513b4
commit 44840d615d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 41 additions and 6 deletions

View file

@@ -117,11 +117,26 @@ class PatternMatchRouter:
         E.g.:
+        Case 1:
         model_name: llmengine/* (can be any regex pattern or wildcard pattern)
         litellm_params:
             model: openai/*
         if model_name = "llmengine/foo" -> model = "openai/foo"
+        Case 2:
+        model_name: llmengine/fo::*::static::*
+        litellm_params:
+            model: openai/fo::*::static::*
+        if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"
+        Case 3:
+        model_name: *meta.llama3*
+        litellm_params:
+            model: bedrock/meta.llama3*
+        if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b"
         """
         ## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name
@@ -134,10 +149,9 @@ class PatternMatchRouter:
         dynamic_segments = matched_pattern.groups()
         if len(dynamic_segments) > wildcard_count:
-            raise ValueError(
-                f"More wildcards in the deployment model name than the pattern. Wildcard count: {wildcard_count}, dynamic segments count: {len(dynamic_segments)}"
-            )
+            return (
+                matched_pattern.string
+            )  # default to the user input, if unable to map based on wildcards.
         # Replace the corresponding wildcards in the litellm model pattern with extracted segments
         for segment in dynamic_segments:
             litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace(

View file

@@ -182,6 +182,14 @@ async def test_route_with_no_matching_pattern():
         )
         assert result.choices[0].message.content == "Works"
+    ## WORKS
+    result = await router.acompletion(
+        model="meta.llama3-70b-instruct-v1:0",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response="Works",
+    )
+    assert result.choices[0].message.content == "Works"
     ## FAILS
     with pytest.raises(litellm.BadRequestError) as e:
         await router.acompletion(
@@ -198,6 +206,7 @@ async def test_route_with_no_matching_pattern():
             input="Hello, world!",
         )
+
 def test_router_pattern_match_e2e():
     """
     Tests the end to end flow of the router
@@ -228,4 +237,3 @@ def test_router_pattern_match_e2e():
         "model": "gpt-4o",
         "messages": [{"role": "user", "content": "Hello, how are you?"}],
     }
-

View file

@@ -960,6 +960,18 @@ def test_replace_model_in_jsonl(model_list):
             "openai/gpt-3.5-turbo",
             "openai/gpt-3.5-turbo",
         ),
+        (
+            "bedrock/meta.llama3-70b",
+            "*meta.llama3*",
+            "bedrock/meta.llama3-*",
+            "bedrock/meta.llama3-70b",
+        ),
+        (
+            "meta.llama3-70b",
+            "*meta.llama3*",
+            "bedrock/meta.llama3-*",
+            "meta.llama3-70b",
+        ),
     ],
 )
 def test_pattern_match_deployment_set_model_name(
@@ -1000,6 +1012,7 @@ def test_pattern_match_deployment_set_model_name(
     for model in updated_models:
         assert model["litellm_params"]["model"] == expected_model
+
 @pytest.mark.asyncio
 async def test_pass_through_moderation_endpoint_factory(model_list):
     router = Router(model_list=model_list)