fix(utils.py): allow passing dynamic api base for openai-compatible endpoints

This commit is contained in:
Krrish Dholakia 2024-07-15 20:00:44 -07:00
parent 9e07af6c02
commit a15ba2592a
2 changed files with 30 additions and 10 deletions

View file

@ -4255,44 +4255,44 @@ def get_llm_provider(
model = model.split("/", 1)[1]
if custom_llm_provider == "perplexity":
# perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai
api_base = "https://api.perplexity.ai"
api_base = api_base or "https://api.perplexity.ai"
dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY")
elif custom_llm_provider == "anyscale":
# anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
api_base = "https://api.endpoints.anyscale.com/v1"
api_base = api_base or "https://api.endpoints.anyscale.com/v1"
dynamic_api_key = get_secret("ANYSCALE_API_KEY")
elif custom_llm_provider == "deepinfra":
# deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepinfra.com/v1/openai
api_base = "https://api.deepinfra.com/v1/openai"
api_base = api_base or "https://api.deepinfra.com/v1/openai"
dynamic_api_key = get_secret("DEEPINFRA_API_KEY")
elif custom_llm_provider == "empower":
api_base = "https://app.empower.dev/api/v1"
api_base = api_base or "https://app.empower.dev/api/v1"
dynamic_api_key = get_secret("EMPOWER_API_KEY")
elif custom_llm_provider == "groq":
# groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
api_base = "https://api.groq.com/openai/v1"
api_base = api_base or "https://api.groq.com/openai/v1"
dynamic_api_key = get_secret("GROQ_API_KEY")
elif custom_llm_provider == "nvidia_nim":
# nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://integrate.api.nvidia.com/v1
api_base = "https://integrate.api.nvidia.com/v1"
api_base = api_base or "https://integrate.api.nvidia.com/v1"
dynamic_api_key = get_secret("NVIDIA_NIM_API_KEY")
elif custom_llm_provider == "volcengine":
# volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://ark.cn-beijing.volces.com/api/v3
api_base = "https://ark.cn-beijing.volces.com/api/v3"
api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
dynamic_api_key = get_secret("VOLCENGINE_API_KEY")
elif custom_llm_provider == "codestral":
# codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
api_base = "https://codestral.mistral.ai/v1"
api_base = api_base or "https://codestral.mistral.ai/v1"
dynamic_api_key = get_secret("CODESTRAL_API_KEY")
elif custom_llm_provider == "deepseek":
# deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
api_base = "https://api.deepseek.com/v1"
api_base = api_base or "https://api.deepseek.com/v1"
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
elif custom_llm_provider == "fireworks_ai":
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
if not model.startswith("accounts/fireworks/models"):
model = f"accounts/fireworks/models/{model}"
api_base = "https://api.fireworks.ai/inference/v1"
api_base = api_base or "https://api.fireworks.ai/inference/v1"
dynamic_api_key = (
get_secret("FIREWORKS_API_KEY")
or get_secret("FIREWORKS_AI_API_KEY")