From 8c75bb5f0fad0690f863031385ed7016b12c1525 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 1 Apr 2024 19:00:24 -0700
Subject: [PATCH 1/3] feat router allow * models

---
 litellm/router.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/litellm/router.py b/litellm/router.py
index 9ddf6e2298..210fb6b5c7 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -191,6 +191,8 @@ class Router:
             redis_cache=redis_cache, in_memory_cache=InMemoryCache()
         )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
 
+        self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
+
         if model_list:
             model_list = copy.deepcopy(model_list)
             self.set_model_list(model_list)
@@ -252,7 +254,6 @@ class Router:
             }
         }
         """
-
         ### ROUTING SETUP ###
         if routing_strategy == "least-busy":
             self.leastbusy_logger = LeastBusyLoggingHandler(
@@ -2078,6 +2079,11 @@ class Router:
                 ),
             )
 
+            # Check if user is trying to use model_name == "*"
+            # this is a catch all model for their specific api key
+            if model["model_name"] == "*":
+                self.default_deployment = model
+
             # Azure GPT-Vision Enhancements, users can pass os.environ/
             data_sources = model.get("litellm_params", {}).get("dataSources", [])
 
@@ -2248,6 +2254,13 @@ class Router:
             )
             model = self.model_group_alias[model]
 
+        if model not in self.model_names and self.default_deployment is not None:
+            updated_deployment = copy.deepcopy(
+                self.default_deployment
+            )  # self.default_deployment
+            updated_deployment["litellm_params"]["model"] = model
+            return updated_deployment
+
         ## get healthy deployments
         ### get all deployments
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]

From 716fcd3ec4203a8a77a8522326b73d419612164c Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 1 Apr 2024 19:07:05 -0700
Subject: [PATCH 2/3] (fix) allow wildcard models

---
 litellm/proxy/proxy_server.py | 36 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 8b6fae40f2..12d0428ba4 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3107,6 +3107,12 @@ async def completion(
             response = await llm_router.atext_completion(
                 **data, specific_deployment=True
             )
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atext_completion(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.atext_completion(**data)
         else:
@@ -3324,6 +3330,12 @@ async def chat_completion(
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
             tasks.append(llm_router.acompletion(**data, specific_deployment=True))
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            tasks.append(llm_router.acompletion(**data))
         elif user_model is not None:  # `litellm --model <your-model-name>`
             tasks.append(litellm.acompletion(**data))
         else:
@@ -3538,6 +3550,12 @@ async def embeddings(
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
             response = await llm_router.aembedding(**data, specific_deployment=True)
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.aembedding(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.aembedding(**data)
         else:
@@ -3683,6 +3701,12 @@ async def image_generation(
             response = await llm_router.aimage_generation(
                 **data
             )  # ensure this goes the llm_router, router will do the correct alias mapping
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.aimage_generation(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.aimage_generation(**data)
         else:
@@ -3837,6 +3861,12 @@ async def audio_transcriptions(
             response = await llm_router.atranscription(
                 **data
             )  # ensure this goes the llm_router, router will do the correct alias mapping
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atranscription(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.atranscription(**data)
         else:
@@ -3990,6 +4020,12 @@ async def moderations(
             response = await llm_router.amoderation(
                 **data
             )  # ensure this goes the llm_router, router will do the correct alias mapping
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.amoderation(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.amoderation(**data)
         else:

From 71537393d321ba745f33adfe7ec93023620a973b Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 1 Apr 2024 19:46:07 -0700
Subject: [PATCH 3/3] test test_wildcard_openai_routing

---
 litellm/tests/test_router_get_deployments.py | 64 ++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/litellm/tests/test_router_get_deployments.py b/litellm/tests/test_router_get_deployments.py
index 8fa6b9761c..53299575a5 100644
--- a/litellm/tests/test_router_get_deployments.py
+++ b/litellm/tests/test_router_get_deployments.py
@@ -448,3 +448,67 @@ def test_usage_based_routing():
         ), f"Assertion failed: 'chatgpt-high-tpm' does not have about 80% of the total requests in the weighted load balancer. Selection counts {selection_counts}"
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_wildcard_openai_routing():
+    """
+    Initialize router with *, all models go through * and use OPENAI_API_KEY
+    """
+    try:
+        model_list = [
+            {
+                "model_name": "*",
+                "litellm_params": {
+                    "model": "openai/*",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 100,
+            },
+        ]
+
+        router = Router(
+            model_list=model_list,
+        )
+
+        messages = [
+            {"content": "Tell me a joke.", "role": "user"},
+        ]
+
+        selection_counts = defaultdict(int)
+        for _ in range(25):
+            response = await router.acompletion(
+                model="gpt-4",
+                messages=messages,
+                mock_response="good morning",
+            )
+            # print("response1", response)
+
+            selection_counts[response["model"]] += 1
+
+            response = await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=messages,
+                mock_response="good morning",
+            )
+            # print("response2", response)
+
+            selection_counts[response["model"]] += 1
+
+            response = await router.acompletion(
+                model="gpt-4-turbo-preview",
+                messages=messages,
+                mock_response="good morning",
+            )
+            # print("response3", response)
+
+            # print("response", response)
+
+            selection_counts[response["model"]] += 1
+
+        assert selection_counts["gpt-4"] == 25
+        assert selection_counts["gpt-3.5-turbo"] == 25
+        assert selection_counts["gpt-4-turbo-preview"] == 25
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
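
Usage sketch (illustrative only, not part of the commits above): with these patches applied, a caller would exercise the wildcard fallback as below. This mirrors the new test and assumes a patched litellm build, a valid OPENAI_API_KEY, and that Router is importable from litellm; any model name not explicitly listed is served by deep-copying the "*" deployment and substituting the requested model into its litellm_params.

    import asyncio
    import os

    from litellm import Router

    # A single catch-all deployment: requests for model names that are not
    # explicit entries in model_list fall back to this "*" deployment.
    model_list = [
        {
            "model_name": "*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ]

    router = Router(model_list=model_list)


    async def main():
        # "gpt-4" is not listed explicitly, so the router resolves it via the
        # default ("*") deployment and forwards it as openai/gpt-4.
        response = await router.acompletion(
            model="gpt-4",
            messages=[{"role": "user", "content": "Tell me a joke."}],
        )
        print(response["model"])


    asyncio.run(main())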