diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 36b191c90..d4bddd9a0 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -8,9 +8,15 @@ model_list:
     litellm_params:
       model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
       api_key: "os.environ/FIREWORKS"
-  - model_name: "*"
+  # provider specific wildcard routing
+  - model_name: "anthropic/*"
     litellm_params:
-      model: "*"
+      model: "anthropic/*"
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: "groq/*"
+    litellm_params:
+      model: "groq/*"
+      api_key: os.environ/GROQ_API_KEY
   - model_name: "*"
     litellm_params:
       model: openai/*
diff --git a/litellm/router.py b/litellm/router.py
index 9afd78322..dc030d369 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -4469,13 +4469,7 @@ class Router:
             )
             model = self.model_group_alias[model]

-        if model not in self.model_names and self.default_deployment is not None:
-            updated_deployment = copy.deepcopy(
-                self.default_deployment
-            )  # self.default_deployment
-            updated_deployment["litellm_params"]["model"] = model
-            return model, updated_deployment
-        elif model not in self.model_names:
+        if model not in self.model_names:
            # check if provider/ specific wildcard routing
            try:
                (
@@ -4499,6 +4493,14 @@
                # get_llm_provider raises exception when provider is unknown
                pass

+        # check if default deployment is set
+        if self.default_deployment is not None:
+            updated_deployment = copy.deepcopy(
+                self.default_deployment
+            )  # self.default_deployment
+            updated_deployment["litellm_params"]["model"] = model
+            return model, updated_deployment
+
        ## get healthy deployments
        ### get all deployments
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index 4912ebbbf..57113d350 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -86,12 +86,16 @@ model_list:
      model: openai/*
      api_key: os.environ/OPENAI_API_KEY

-  # Pass through all llm requests to litellm.completion/litellm.embedding
-  # if user passes model="anthropic/claude-3-opus-20240229" proxy will make requests to anthropic claude-3-opus-20240229 using ANTHROPIC_API_KEY
-  - model_name: "*"
+
+  # provider specific wildcard routing
+  - model_name: "anthropic/*"
    litellm_params:
-      model: "*"
-
+      model: "anthropic/*"
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: "groq/*"
+    litellm_params:
+      model: "groq/*"
+      api_key: os.environ/GROQ_API_KEY
  - model_name: mistral-embed
    litellm_params:
      model: mistral/mistral-embed
diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py
index a77da8d52..932b32551 100644
--- a/tests/test_openai_endpoints.py
+++ b/tests/test_openai_endpoints.py
@@ -119,7 +119,9 @@ async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
        print()

        if status != 200:
-            raise Exception(f"Request did not return a 200 status code: {status}")
+            raise Exception(
+                f"Request did not return a 200 status code: {status}, response text={response_text}"
+            )

        response_header_check(
            response
@@ -485,6 +487,12 @@ async def test_proxy_all_models():
            session=session, key=LITELLM_MASTER_KEY, model="groq/llama3-8b-8192"
        )

+        await chat_completion(
+            session=session,
+            key=LITELLM_MASTER_KEY,
+            model="anthropic/claude-3-sonnet-20240229",
+        )
+

@pytest.mark.asyncio
async def test_batch_chat_completions():
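
A minimal client-side sketch of what the new provider-specific wildcard routes enable, assuming the proxy is running locally at http://0.0.0.0:4000 with master key sk-1234 (both hypothetical values, not taken from this diff); the model name mirrors the call added to test_proxy_all_models in tests/test_openai_endpoints.py.

# Hypothetical usage sketch: send an OpenAI-compatible request to the proxy with a
# provider-prefixed model so it matches the "anthropic/*" wildcard deployment.
# base_url and api_key below are assumptions for illustration only.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="anthropic/claude-3-sonnet-20240229",  # routed via the "anthropic/*" entry
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)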