fix(utils.py): fix dynamic api base

This commit is contained in:
Krrish Dholakia 2024-08-06 11:27:39 -07:00
parent 036a6821d5
commit 34213edb91
2 changed files with 36 additions and 27 deletions

View file

@ -4491,49 +4491,49 @@ def get_llm_provider(
elif custom_llm_provider == "empower":
api_base = (
api_base
or str(get_secret("EMPOWER_API_BASE"))
or get_secret("EMPOWER_API_BASE")
or "https://app.empower.dev/api/v1"
)
) # type: ignore
dynamic_api_key = api_key or get_secret("EMPOWER_API_KEY")
elif custom_llm_provider == "groq":
# groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
api_base = (
api_base
or str(get_secret("GROQ_API_BASE"))
or get_secret("GROQ_API_BASE")
or "https://api.groq.com/openai/v1"
)
) # type: ignore
dynamic_api_key = api_key or get_secret("GROQ_API_KEY")
elif custom_llm_provider == "nvidia_nim":
# nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
api_base = (
api_base
or str(get_secret("NVIDIA_NIM_API_BASE"))
or get_secret("NVIDIA_NIM_API_BASE")
or "https://integrate.api.nvidia.com/v1"
)
) # type: ignore
dynamic_api_key = api_key or get_secret("NVIDIA_NIM_API_KEY")
elif custom_llm_provider == "volcengine":
# volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
api_base = (
api_base
or str(get_secret("VOLCENGINE_API_BASE"))
or get_secret("VOLCENGINE_API_BASE")
or "https://ark.cn-beijing.volces.com/api/v3"
)
) # type: ignore
dynamic_api_key = api_key or get_secret("VOLCENGINE_API_KEY")
elif custom_llm_provider == "codestral":
# codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
api_base = (
api_base
or str(get_secret("CODESTRAL_API_BASE"))
or get_secret("CODESTRAL_API_BASE")
or "https://codestral.mistral.ai/v1"
)
) # type: ignore
dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY")
elif custom_llm_provider == "deepseek":
# deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
api_base = (
api_base
or str(get_secret("DEEPSEEK_API_BASE"))
or get_secret("DEEPSEEK_API_BASE")
or "https://api.deepseek.com/v1"
)
) # type: ignore
dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY")
elif custom_llm_provider == "fireworks_ai":
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
@ -4541,9 +4541,9 @@ def get_llm_provider(
model = f"accounts/fireworks/models/{model}"
api_base = (
api_base
or str(get_secret("FIREWORKS_API_BASE"))
or get_secret("FIREWORKS_API_BASE")
or "https://api.fireworks.ai/inference/v1"
)
) # type: ignore
dynamic_api_key = api_key or (
get_secret("FIREWORKS_API_KEY")
or get_secret("FIREWORKS_AI_API_KEY")
@ -4551,7 +4551,7 @@ def get_llm_provider(
or get_secret("FIREWORKS_AI_TOKEN")
)
elif custom_llm_provider == "azure_ai":
api_base = api_base or str(get_secret("AZURE_AI_API_BASE"))
api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
elif custom_llm_provider == "github":
api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
@ -4579,16 +4579,16 @@ def get_llm_provider(
# voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
api_base = (
api_base
or str(get_secret("VOYAGE_API_BASE"))
or get_secret("VOYAGE_API_BASE")
or "https://api.voyageai.com/v1"
)
) # type: ignore
dynamic_api_key = api_key or get_secret("VOYAGE_API_KEY")
elif custom_llm_provider == "together_ai":
api_base = (
api_base
or str(get_secret("TOGETHER_AI_API_BASE"))
or get_secret("TOGETHER_AI_API_BASE")
or "https://api.together.xyz/v1"
)
) # type: ignore
dynamic_api_key = api_key or (
get_secret("TOGETHER_API_KEY")
or get_secret("TOGETHER_AI_API_KEY")
@ -4598,9 +4598,9 @@ def get_llm_provider(
elif custom_llm_provider == "friendliai":
api_base = (
api_base
or str(get_secret("FRIENDLI_API_BASE"))
or get_secret("FRIENDLI_API_BASE")
or "https://inference.friendli.ai/v1"
)
) # type: ignore
dynamic_api_key = (
api_key
or get_secret("FRIENDLIAI_API_KEY")