fix(utils.py): fix dynamic api base

This commit is contained in:
Krrish Dholakia 2024-08-06 11:27:39 -07:00
parent 036a6821d5
commit 34213edb91
2 changed files with 36 additions and 27 deletions

View file

@ -1,20 +1,28 @@
#### What this tests #### #### What this tests ####
# This tests litellm router with batch completion # This tests litellm router with batch completion
import sys, os, time, openai import asyncio
import traceback, asyncio import os
import sys
import time
import traceback
import openai
import pytest import pytest
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import httpx
from dotenv import load_dotenv
import litellm import litellm
from litellm import Router from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
import os, httpx
load_dotenv() load_dotenv()
@ -54,6 +62,7 @@ async def test_batch_completion_multiple_models(mode):
assert len(response) == 2 assert len(response) == 2
models_in_responses = [] models_in_responses = []
print(f"response: {response}")
for individual_response in response: for individual_response in response:
_model = individual_response["model"] _model = individual_response["model"]
models_in_responses.append(_model) models_in_responses.append(_model)

View file

@ -4491,49 +4491,49 @@ def get_llm_provider(
elif custom_llm_provider == "empower": elif custom_llm_provider == "empower":
api_base = ( api_base = (
api_base api_base
or str(get_secret("EMPOWER_API_BASE")) or get_secret("EMPOWER_API_BASE")
or "https://app.empower.dev/api/v1" or "https://app.empower.dev/api/v1"
) ) # type: ignore
dynamic_api_key = api_key or get_secret("EMPOWER_API_KEY") dynamic_api_key = api_key or get_secret("EMPOWER_API_KEY")
elif custom_llm_provider == "groq": elif custom_llm_provider == "groq":
# groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1 # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
api_base = ( api_base = (
api_base api_base
or str(get_secret("GROQ_API_BASE")) or get_secret("GROQ_API_BASE")
or "https://api.groq.com/openai/v1" or "https://api.groq.com/openai/v1"
) ) # type: ignore
dynamic_api_key = api_key or get_secret("GROQ_API_KEY") dynamic_api_key = api_key or get_secret("GROQ_API_KEY")
elif custom_llm_provider == "nvidia_nim": elif custom_llm_provider == "nvidia_nim":
# nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 # nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
api_base = ( api_base = (
api_base api_base
or str(get_secret("NVIDIA_NIM_API_BASE")) or get_secret("NVIDIA_NIM_API_BASE")
or "https://integrate.api.nvidia.com/v1" or "https://integrate.api.nvidia.com/v1"
) ) # type: ignore
dynamic_api_key = api_key or get_secret("NVIDIA_NIM_API_KEY") dynamic_api_key = api_key or get_secret("NVIDIA_NIM_API_KEY")
elif custom_llm_provider == "volcengine": elif custom_llm_provider == "volcengine":
# volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
api_base = ( api_base = (
api_base api_base
or str(get_secret("VOLCENGINE_API_BASE")) or get_secret("VOLCENGINE_API_BASE")
or "https://ark.cn-beijing.volces.com/api/v3" or "https://ark.cn-beijing.volces.com/api/v3"
) ) # type: ignore
dynamic_api_key = api_key or get_secret("VOLCENGINE_API_KEY") dynamic_api_key = api_key or get_secret("VOLCENGINE_API_KEY")
elif custom_llm_provider == "codestral": elif custom_llm_provider == "codestral":
# codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1 # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
api_base = ( api_base = (
api_base api_base
or str(get_secret("CODESTRAL_API_BASE")) or get_secret("CODESTRAL_API_BASE")
or "https://codestral.mistral.ai/v1" or "https://codestral.mistral.ai/v1"
) ) # type: ignore
dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY") dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY")
elif custom_llm_provider == "deepseek": elif custom_llm_provider == "deepseek":
# deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1 # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
api_base = ( api_base = (
api_base api_base
or str(get_secret("DEEPSEEK_API_BASE")) or get_secret("DEEPSEEK_API_BASE")
or "https://api.deepseek.com/v1" or "https://api.deepseek.com/v1"
) ) # type: ignore
dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY") dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY")
elif custom_llm_provider == "fireworks_ai": elif custom_llm_provider == "fireworks_ai":
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1 # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
@ -4541,9 +4541,9 @@ def get_llm_provider(
model = f"accounts/fireworks/models/{model}" model = f"accounts/fireworks/models/{model}"
api_base = ( api_base = (
api_base api_base
or str(get_secret("FIREWORKS_API_BASE")) or get_secret("FIREWORKS_API_BASE")
or "https://api.fireworks.ai/inference/v1" or "https://api.fireworks.ai/inference/v1"
) ) # type: ignore
dynamic_api_key = api_key or ( dynamic_api_key = api_key or (
get_secret("FIREWORKS_API_KEY") get_secret("FIREWORKS_API_KEY")
or get_secret("FIREWORKS_AI_API_KEY") or get_secret("FIREWORKS_AI_API_KEY")
@ -4551,7 +4551,7 @@ def get_llm_provider(
or get_secret("FIREWORKS_AI_TOKEN") or get_secret("FIREWORKS_AI_TOKEN")
) )
elif custom_llm_provider == "azure_ai": elif custom_llm_provider == "azure_ai":
api_base = api_base or str(get_secret("AZURE_AI_API_BASE")) api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY") dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
elif custom_llm_provider == "github": elif custom_llm_provider == "github":
api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
@ -4579,16 +4579,16 @@ def get_llm_provider(
# voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1 # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
api_base = ( api_base = (
api_base api_base
or str(get_secret("VOYAGE_API_BASE")) or get_secret("VOYAGE_API_BASE")
or "https://api.voyageai.com/v1" or "https://api.voyageai.com/v1"
) ) # type: ignore
dynamic_api_key = api_key or get_secret("VOYAGE_API_KEY") dynamic_api_key = api_key or get_secret("VOYAGE_API_KEY")
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
api_base = ( api_base = (
api_base api_base
or str(get_secret("TOGETHER_AI_API_BASE")) or get_secret("TOGETHER_AI_API_BASE")
or "https://api.together.xyz/v1" or "https://api.together.xyz/v1"
) ) # type: ignore
dynamic_api_key = api_key or ( dynamic_api_key = api_key or (
get_secret("TOGETHER_API_KEY") get_secret("TOGETHER_API_KEY")
or get_secret("TOGETHER_AI_API_KEY") or get_secret("TOGETHER_AI_API_KEY")
@ -4598,9 +4598,9 @@ def get_llm_provider(
elif custom_llm_provider == "friendliai": elif custom_llm_provider == "friendliai":
api_base = ( api_base = (
api_base api_base
or str(get_secret("FRIENDLI_API_BASE")) or get_secret("FRIENDLI_API_BASE")
or "https://inference.friendli.ai/v1" or "https://inference.friendli.ai/v1"
) ) # type: ignore
dynamic_api_key = ( dynamic_api_key = (
api_key api_key
or get_secret("FRIENDLIAI_API_KEY") or get_secret("FRIENDLIAI_API_KEY")