fix(utils.py): fix get_llm_provider to support dynamic params for openai-compatible providers

Krrish Dholakia 2024-07-19 19:36:31 -07:00
parent e2d275f1b7
commit e45956d77e
2 changed files with 22 additions and 16 deletions
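
In practice, this change lets a per-request api_key / api_base passed to litellm.completion() flow through get_llm_provider() and take precedence over environment variables for the OpenAI-compatible providers touched below. A minimal sketch of the intended caller-side behavior, assuming placeholder model, key, and base values that are not part of this commit:

import litellm

# Hypothetical per-request credentials; before this fix, only the *_API_KEY
# environment variables were consulted for these OpenAI-compatible providers.
response = litellm.completion(
    model="groq/llama3-8b-8192",                # placeholder model name
    messages=[{"role": "user", "content": "hi"}],
    api_key="gsk-my-dynamic-key",               # should override GROQ_API_KEY
    api_base="https://api.groq.com/openai/v1",  # optional per-request base URL
)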


@@ -1348,7 +1348,10 @@ def test_completion_fireworks_ai():
         pytest.fail(f"Error occurred: {e}")


-def test_completion_fireworks_ai_bad_api_base():
+@pytest.mark.parametrize(
+    "api_key, api_base", [(None, "my-bad-api-base"), ("my-bad-api-key", None)]
+)
+def test_completion_fireworks_ai_dynamic_params(api_key, api_base):
     try:
         litellm.set_verbose = True
         messages = [
@@ -1361,7 +1364,8 @@ def test_completion_fireworks_ai_bad_api_base():
         response = completion(
             model="fireworks_ai/accounts/fireworks/models/mixtral-8x7b-instruct",
             messages=messages,
-            api_base="my-bad-api-base",
+            api_base=api_base,
+            api_key=api_key,
         )
         pytest.fail(f"This call should have failed!")
     except Exception as e:
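
The parametrized test covers both directions of the fix: a bad per-call api_base and a bad per-call api_key must each cause the request to fail rather than being silently ignored in favor of the environment variables. A standalone sketch of the two cases it exercises, with a placeholder message body since the diff elides the real one:

import litellm
from litellm import completion

litellm.set_verbose = True

for api_key, api_base in [(None, "my-bad-api-base"), ("my-bad-api-key", None)]:
    try:
        completion(
            model="fireworks_ai/accounts/fireworks/models/mixtral-8x7b-instruct",
            messages=[{"role": "user", "content": "Hello!"}],  # placeholder message
            api_base=api_base,
            api_key=api_key,
        )
        print("unexpected: the call should have failed")
    except Exception as e:
        print(f"expected failure: {e}")  # the bad dynamic param was actually used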


@@ -4398,44 +4398,44 @@ def get_llm_provider(
             if custom_llm_provider == "perplexity":
                 # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai
                 api_base = api_base or "https://api.perplexity.ai"
-                dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY")
+                dynamic_api_key = api_key or get_secret("PERPLEXITYAI_API_KEY")
             elif custom_llm_provider == "anyscale":
                 # anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
                 api_base = api_base or "https://api.endpoints.anyscale.com/v1"
-                dynamic_api_key = get_secret("ANYSCALE_API_KEY")
+                dynamic_api_key = api_key or get_secret("ANYSCALE_API_KEY")
             elif custom_llm_provider == "deepinfra":
                 # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
                 api_base = api_base or "https://api.deepinfra.com/v1/openai"
-                dynamic_api_key = get_secret("DEEPINFRA_API_KEY")
+                dynamic_api_key = api_key or get_secret("DEEPINFRA_API_KEY")
             elif custom_llm_provider == "empower":
                 api_base = api_base or "https://app.empower.dev/api/v1"
-                dynamic_api_key = get_secret("EMPOWER_API_KEY")
+                dynamic_api_key = api_key or get_secret("EMPOWER_API_KEY")
             elif custom_llm_provider == "groq":
                 # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
                 api_base = api_base or "https://api.groq.com/openai/v1"
-                dynamic_api_key = get_secret("GROQ_API_KEY")
+                dynamic_api_key = api_key or get_secret("GROQ_API_KEY")
             elif custom_llm_provider == "nvidia_nim":
                 # nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
                 api_base = api_base or "https://integrate.api.nvidia.com/v1"
-                dynamic_api_key = get_secret("NVIDIA_NIM_API_KEY")
+                dynamic_api_key = api_key or get_secret("NVIDIA_NIM_API_KEY")
             elif custom_llm_provider == "volcengine":
                 # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
                 api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
-                dynamic_api_key = get_secret("VOLCENGINE_API_KEY")
+                dynamic_api_key = api_key or get_secret("VOLCENGINE_API_KEY")
             elif custom_llm_provider == "codestral":
                 # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
                 api_base = api_base or "https://codestral.mistral.ai/v1"
-                dynamic_api_key = get_secret("CODESTRAL_API_KEY")
+                dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY")
             elif custom_llm_provider == "deepseek":
                 # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
                 api_base = api_base or "https://api.deepseek.com/v1"
-                dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
+                dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY")
             elif custom_llm_provider == "fireworks_ai":
                 # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
                 if not model.startswith("accounts/fireworks/models"):
                     model = f"accounts/fireworks/models/{model}"
                 api_base = api_base or "https://api.fireworks.ai/inference/v1"
-                dynamic_api_key = (
+                dynamic_api_key = api_key or (
                     get_secret("FIREWORKS_API_KEY")
                     or get_secret("FIREWORKS_AI_API_KEY")
                     or get_secret("FIREWORKSAI_API_KEY")
@@ -4465,10 +4465,10 @@ def get_llm_provider(
             elif custom_llm_provider == "voyage":
                 # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
                 api_base = "https://api.voyageai.com/v1"
-                dynamic_api_key = get_secret("VOYAGE_API_KEY")
+                dynamic_api_key = api_key or get_secret("VOYAGE_API_KEY")
             elif custom_llm_provider == "together_ai":
                 api_base = "https://api.together.xyz/v1"
-                dynamic_api_key = (
+                dynamic_api_key = api_key or (
                     get_secret("TOGETHER_API_KEY")
                     or get_secret("TOGETHER_AI_API_KEY")
                     or get_secret("TOGETHERAI_API_KEY")
@@ -4476,8 +4476,10 @@ def get_llm_provider(
                 )
             elif custom_llm_provider == "friendliai":
                 api_base = "https://inference.friendli.ai/v1"
-                dynamic_api_key = get_secret("FRIENDLIAI_API_KEY") or get_secret(
-                    "FRIENDLI_TOKEN"
+                dynamic_api_key = (
+                    api_key
+                    or get_secret("FRIENDLIAI_API_KEY")
+                    or get_secret("FRIENDLI_TOKEN")
                 )
             if api_base is not None and not isinstance(api_base, str):
                 raise Exception(
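
A quick way to check the fix end to end is to call get_llm_provider() directly and confirm the dynamic key comes back. The sketch below assumes get_llm_provider is exported at the package top level and returns (model, custom_llm_provider, dynamic_api_key, api_base), which is how litellm's own completion() consumes it; the key and base values are placeholders:

from litellm import get_llm_provider

model, provider, dynamic_api_key, api_base = get_llm_provider(
    model="groq/llama3-8b-8192",             # placeholder model name
    api_key="my-dynamic-key",                # per-request key
    api_base="https://my-proxy.example/v1",  # hypothetical base URL override
)
assert provider == "groq"
assert dynamic_api_key == "my-dynamic-key"   # not the GROQ_API_KEY env var
assert api_base == "https://my-proxy.example/v1"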