proxy - add azure/command r
Routes Azure AI Studio-hosted Cohere/Mistral chat models (for example azure/command-r-plus) through the OpenAI client instead of the AzureOpenAI client: adds a _is_non_openai_azure_model helper to litellm.utils, uses it in get_llm_provider and in the Router's client init, and adds a router init test for azure/command-r-plus.

commit 9055a071e6 (parent 8df13306ea)
3 changed files with 62 additions and 6 deletions
@@ -1767,6 +1767,11 @@ class Router:
             or "ft:gpt-3.5-turbo" in model_name
             or model_name in litellm.open_ai_embedding_models
         ):
+            if custom_llm_provider == "azure":
+                if litellm.utils._is_non_openai_azure_model(model_name):
+                    custom_llm_provider = "openai"
+                    # remove the azure prefix from model_name
+                    model_name = model_name.replace("azure/", "")
             # glorified / complicated reading of configs
             # user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
             # we do this here because we init clients for Azure, OpenAI and we need to set the right key
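A minimal standalone sketch (not litellm's actual helper) of what this override does when the router initializes clients: for an Azure AI Studio Cohere/Mistral deployment the provider is flipped to "openai" and the "azure/" prefix is stripped, so a plain OpenAI client is created instead of an AzureOpenAI one. The model names in the set below are illustrative assumptions.

# Standalone sketch of the override above; NON_OPENAI_AZURE_MODELS is an
# illustrative stand-in for litellm.utils._is_non_openai_azure_model.
NON_OPENAI_AZURE_MODELS = {"command-r-plus", "command-r"}  # assumed example names

def route_azure_client(custom_llm_provider: str, model_name: str) -> tuple:
    """Return the (provider, model_name) pair the router would build a client for."""
    if custom_llm_provider == "azure":
        base_name = model_name.replace("azure/", "")
        if base_name in NON_OPENAI_AZURE_MODELS:
            # Azure AI Studio's Cohere/Mistral deployments expect plain OpenAI-style
            # requests, so an openai.OpenAI client is needed, not an AzureOpenAI one.
            return "openai", base_name
    return custom_llm_provider, model_name

print(route_azure_client("azure", "azure/command-r-plus"))       # ('openai', 'command-r-plus')
print(route_azure_client("azure", "azure/my-gpt-4-deployment"))  # ('azure', 'azure/my-gpt-4-deployment')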
@@ -440,3 +440,46 @@ def test_openai_with_organization():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
+def test_init_clients_azure_command_r_plus():
+    # This tests that the router uses the OpenAI client for Azure/Command-R+
+    # For azure/command-r-plus we need to use openai.OpenAI because of how Azure requires requests to be sent
+    litellm.set_verbose = True
+    import logging
+
+    from litellm._logging import verbose_router_logger
+
+    verbose_router_logger.setLevel(logging.DEBUG)
+    try:
+        print("testing init 4 clients with diff timeouts")
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/command-r-plus",
+                    "api_key": os.getenv("AZURE_COHERE_API_KEY"),
+                    "api_base": os.getenv("AZURE_COHERE_API_BASE"),
+                    "timeout": 0.01,
+                    "stream_timeout": 0.000_001,
+                    "max_retries": 7,
+                },
+            },
+        ]
+        router = Router(model_list=model_list, set_verbose=True)
+        for elem in router.model_list:
+            model_id = elem["model_info"]["id"]
+            async_client = router.cache.get_cache(f"{model_id}_async_client")
+            stream_async_client = router.cache.get_cache(
+                f"{model_id}_stream_async_client"
+            )
+            # Assert the async clients used are OpenAI clients and not Azure clients.
+            # For Azure/Command-R-Plus and Azure/Mistral the clients NEED to be OpenAI clients;
+            # this is a quirk introduced on Azure's side.
+
+            assert "openai.AsyncOpenAI" in str(async_client)
+            assert "openai.AsyncOpenAI" in str(stream_async_client)
+        print("PASSED !")
+
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
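A hedged usage sketch to go with the test above: routing a completion call through the same kind of azure/command-r-plus deployment. It assumes AZURE_COHERE_API_KEY and AZURE_COHERE_API_BASE point at an Azure AI Studio Cohere deployment (as in the test); the model alias and prompt are illustrative, not part of the commit.

# Usage sketch: call an azure/command-r-plus deployment through the Router.
# Env var names are taken from the test above; the alias "command-r-plus"
# and the prompt are illustrative.
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "command-r-plus",  # alias that callers use
            "litellm_params": {
                "model": "azure/command-r-plus",
                "api_key": os.getenv("AZURE_COHERE_API_KEY"),
                "api_base": os.getenv("AZURE_COHERE_API_BASE"),
            },
        }
    ]
)

response = router.completion(
    model="command-r-plus",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)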
@@ -5563,6 +5563,19 @@ def get_formatted_prompt(
     return prompt
 
 
+def _is_non_openai_azure_model(model: str) -> bool:
+    try:
+        model_name = model.split("/", 1)[1]
+        if (
+            model_name in litellm.cohere_chat_models
+            or f"mistral/{model_name}" in litellm.mistral_chat_models
+        ):
+            return True
+    except:
+        return False
+    return False
+
+
 def get_llm_provider(
     model: str,
     custom_llm_provider: Optional[str] = None,
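Hedged examples of how the helper is expected to classify model strings. The expected True for command-r-plus assumes it is registered in litellm.cohere_chat_models, which the comments and test in this commit imply; the other names are illustrative.

# Expected classification (illustrative; the authoritative lists are
# litellm.cohere_chat_models and litellm.mistral_chat_models).
import litellm

# Azure AI Studio Cohere deployment -> not a regular Azure OpenAI model
print(litellm.utils._is_non_openai_azure_model("azure/command-r-plus"))  # expected: True
# Regular Azure OpenAI deployment name -> False
print(litellm.utils._is_non_openai_azure_model("azure/gpt-4"))  # expected: False
# No "/" in the string -> split(...)[1] raises IndexError, caught by the bare except -> False
print(litellm.utils._is_non_openai_azure_model("command-r-plus"))  # expected: False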
@@ -5576,13 +5589,8 @@ def get_llm_provider(
         # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
         # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
         if model.split("/", 1)[0] == "azure":
-            model_name = model.split("/", 1)[1]
-            if (
-                model_name in litellm.cohere_chat_models
-                or f"mistral/{model_name}" in litellm.mistral_chat_models
-            ):
+            if _is_non_openai_azure_model(model):
                 custom_llm_provider = "openai"
-                model = model_name
                 return model, custom_llm_provider, dynamic_api_key, api_base
 
         if custom_llm_provider:
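A hedged sketch of the resulting get_llm_provider behaviour. The return order (model, provider, dynamic key, api base) is taken from the return statement in the hunk above; the expected values are inferred from this commit rather than verified output, and the direct import/call pattern is an assumption.

# Sketch, assuming get_llm_provider is importable from the litellm package
# (it lives in litellm.utils); expected values are inferred from the diff above.
from litellm import get_llm_provider

model, provider, dynamic_api_key, api_base = get_llm_provider(model="azure/command-r-plus")
print(provider)  # expected: "openai" (Azure AI Studio Cohere model routed to the OpenAI client)

model, provider, dynamic_api_key, api_base = get_llm_provider(model="azure/my-gpt-4-deployment")
print(provider)  # expected: "azure" (regular Azure OpenAI deployment; name is illustrative)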