From a6e48db8b06ca0cbe0adebed43e2e46580414bea Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 19 Jul 2024 19:36:31 -0700 Subject: [PATCH] fix(utils.py): fix get_llm_provider to support dynamic params for openai-compatible providers --- litellm/tests/test_completion.py | 8 ++++++-- litellm/utils.py | 30 ++++++++++++++++-------------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 87efa86bef..34eebb7124 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1348,7 +1348,10 @@ def test_completion_fireworks_ai(): pytest.fail(f"Error occurred: {e}") -def test_completion_fireworks_ai_bad_api_base(): +@pytest.mark.parametrize( + "api_key, api_base", [(None, "my-bad-api-base"), ("my-bad-api-key", None)] +) +def test_completion_fireworks_ai_dynamic_params(api_key, api_base): try: litellm.set_verbose = True messages = [ @@ -1361,7 +1364,8 @@ def test_completion_fireworks_ai_bad_api_base(): response = completion( model="fireworks_ai/accounts/fireworks/models/mixtral-8x7b-instruct", messages=messages, - api_base="my-bad-api-base", + api_base=api_base, + api_key=api_key, ) pytest.fail(f"This call should have failed!") except Exception as e: diff --git a/litellm/utils.py b/litellm/utils.py index 809613a091..d8afdbc757 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4398,44 +4398,44 @@ def get_llm_provider( if custom_llm_provider == "perplexity": # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai api_base = api_base or "https://api.perplexity.ai" - dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY") + dynamic_api_key = api_key or get_secret("PERPLEXITYAI_API_KEY") elif custom_llm_provider == "anyscale": # anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = api_base or "https://api.endpoints.anyscale.com/v1" - dynamic_api_key = get_secret("ANYSCALE_API_KEY") + dynamic_api_key = api_key or get_secret("ANYSCALE_API_KEY") elif custom_llm_provider == "deepinfra": # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = api_base or "https://api.deepinfra.com/v1/openai" - dynamic_api_key = get_secret("DEEPINFRA_API_KEY") + dynamic_api_key = api_key or get_secret("DEEPINFRA_API_KEY") elif custom_llm_provider == "empower": api_base = api_base or "https://app.empower.dev/api/v1" - dynamic_api_key = get_secret("EMPOWER_API_KEY") + dynamic_api_key = api_key or get_secret("EMPOWER_API_KEY") elif custom_llm_provider == "groq": # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1 api_base = api_base or "https://api.groq.com/openai/v1" - dynamic_api_key = get_secret("GROQ_API_KEY") + dynamic_api_key = api_key or get_secret("GROQ_API_KEY") elif custom_llm_provider == "nvidia_nim": # nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = api_base or "https://integrate.api.nvidia.com/v1" - dynamic_api_key = get_secret("NVIDIA_NIM_API_KEY") + dynamic_api_key = api_key or get_secret("NVIDIA_NIM_API_KEY") elif custom_llm_provider == "volcengine": # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3" - dynamic_api_key = get_secret("VOLCENGINE_API_KEY") + dynamic_api_key = api_key or get_secret("VOLCENGINE_API_KEY") elif custom_llm_provider == "codestral": # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1 api_base = api_base or "https://codestral.mistral.ai/v1" - dynamic_api_key = get_secret("CODESTRAL_API_KEY") + dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY") elif custom_llm_provider == "deepseek": # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1 api_base = api_base or "https://api.deepseek.com/v1" - dynamic_api_key = get_secret("DEEPSEEK_API_KEY") + dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY") elif custom_llm_provider == "fireworks_ai": # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1 if not model.startswith("accounts/fireworks/models"): model = f"accounts/fireworks/models/{model}" api_base = api_base or "https://api.fireworks.ai/inference/v1" - dynamic_api_key = ( + dynamic_api_key = api_key or ( get_secret("FIREWORKS_API_KEY") or get_secret("FIREWORKS_AI_API_KEY") or get_secret("FIREWORKSAI_API_KEY") @@ -4465,10 +4465,10 @@ def get_llm_provider( elif custom_llm_provider == "voyage": # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1 api_base = "https://api.voyageai.com/v1" - dynamic_api_key = get_secret("VOYAGE_API_KEY") + dynamic_api_key = api_key or get_secret("VOYAGE_API_KEY") elif custom_llm_provider == "together_ai": api_base = "https://api.together.xyz/v1" - dynamic_api_key = ( + dynamic_api_key = api_key or ( get_secret("TOGETHER_API_KEY") or get_secret("TOGETHER_AI_API_KEY") or get_secret("TOGETHERAI_API_KEY") @@ -4476,8 +4476,10 @@ def get_llm_provider( ) elif custom_llm_provider == "friendliai": api_base = "https://inference.friendli.ai/v1" - dynamic_api_key = get_secret("FRIENDLIAI_API_KEY") or get_secret( - "FRIENDLI_TOKEN" + dynamic_api_key = ( + api_key + or get_secret("FRIENDLIAI_API_KEY") + or get_secret("FRIENDLI_TOKEN") ) if api_base is not None and not isinstance(api_base, str): raise Exception(