Merge pull request #2868 from BerriAI/litellm_add_command_r_on_proxy

Add Azure Command-r-plus on litellm proxy
Ishaan Jaff 2024-04-05 15:13:47 -07:00 committed by GitHub
commit faa0d38087
4 changed files with 147 additions and 23 deletions
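
The gist of the change: Azure serves Command R+ through an OpenAI-compatible endpoint rather than the usual Azure OpenAI deployment API, so the router must construct plain openai.AsyncOpenAI clients pointed at the Azure base URL instead of openai.AsyncAzureOpenAI clients. A minimal sketch of that client construction, assuming the same AZURE_COHERE_API_KEY / AZURE_COHERE_API_BASE environment variables the test below uses (an illustration of the idea, not the router's exact internals):

import os

import openai

# Plain OpenAI client aimed at the Azure-hosted, OpenAI-compatible endpoint.
# Constructing the client makes no network call, so this runs as-is.
client = openai.AsyncOpenAI(
    api_key=os.getenv("AZURE_COHERE_API_KEY"),
    base_url=os.getenv("AZURE_COHERE_API_BASE"),
)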

@@ -447,3 +447,46 @@ def test_openai_with_organization():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
def test_init_clients_azure_command_r_plus():
    # Tests that the router initializes OpenAI clients for azure/command-r-plus.
    # azure/command-r-plus is served through an OpenAI-compatible endpoint, so
    # requests must be sent with openai.OpenAI rather than the Azure client.
    litellm.set_verbose = True
    import logging

    from litellm._logging import verbose_router_logger

    verbose_router_logger.setLevel(logging.DEBUG)
    try:
        print("testing init 4 clients with diff timeouts")
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/command-r-plus",
                    "api_key": os.getenv("AZURE_COHERE_API_KEY"),
                    "api_base": os.getenv("AZURE_COHERE_API_BASE"),
                    "timeout": 0.01,
                    "stream_timeout": 0.000_001,
                    "max_retries": 7,
                },
            },
        ]
        router = Router(model_list=model_list, set_verbose=True)
        for elem in router.model_list:
            model_id = elem["model_info"]["id"]
            async_client = router.cache.get_cache(f"{model_id}_async_client")
            stream_async_client = router.cache.get_cache(
                f"{model_id}_stream_async_client"
            )
            # Assert the async clients are OpenAI clients, not Azure clients.
            # For azure/command-r-plus and azure/mistral the clients MUST be
            # OpenAI clients; this is a quirk introduced on Azure's side.
            assert "openai.AsyncOpenAI" in str(async_client)
            assert "openai.AsyncOpenAI" in str(stream_async_client)
        print("PASSED!")
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
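
For reference, a hedged usage sketch: once a model_list entry like the one above is in place, calls go through the standard Router.completion API under the model alias. The alias and environment variables mirror the test; the aggressive test timeouts are dropped here so a real call can succeed.

import os

from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/command-r-plus",
            "api_key": os.getenv("AZURE_COHERE_API_KEY"),
            "api_base": os.getenv("AZURE_COHERE_API_BASE"),
        },
    }
]
router = Router(model_list=model_list)

# One chat completion routed to azure/command-r-plus via its alias.
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)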