fix support azure/mistral models

This commit is contained in:
Ishaan Jaff 2024-04-05 09:32:39 -07:00
parent ab60d7c8fb
commit 5ce80d82d3
2 changed files with 12 additions and 7 deletions

View file

@ -260,6 +260,7 @@ open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = [] open_ai_text_completion_models: List = []
cohere_models: List = [] cohere_models: List = []
cohere_chat_models: List = [] cohere_chat_models: List = []
mistral_chat_models: List = []
anthropic_models: List = [] anthropic_models: List = []
openrouter_models: List = [] openrouter_models: List = []
vertex_language_models: List = [] vertex_language_models: List = []
@ -284,6 +285,8 @@ for key, value in model_cost.items():
cohere_models.append(key) cohere_models.append(key)
elif value.get("litellm_provider") == "cohere_chat": elif value.get("litellm_provider") == "cohere_chat":
cohere_chat_models.append(key) cohere_chat_models.append(key)
elif value.get("litellm_provider") == "mistral":
mistral_chat_models.append(key)
elif value.get("litellm_provider") == "anthropic": elif value.get("litellm_provider") == "anthropic":
anthropic_models.append(key) anthropic_models.append(key)
elif value.get("litellm_provider") == "openrouter": elif value.get("litellm_provider") == "openrouter":

View file

@ -5575,12 +5575,14 @@ def get_llm_provider(
# AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
# If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
if model.split("/", 1)[0] == "azure":
model_name = model.split("/", 1)[1]
if ( if (
model.split("/", 1)[0] == "azure" model_name in litellm.cohere_chat_models
and model.split("/", 1)[1] in litellm.cohere_chat_models or model_name in litellm.mistral_chat_models
): ):
custom_llm_provider = "openai" custom_llm_provider = "openai"
model = model.split("/", 1)[1] model = model_name
return model, custom_llm_provider, dynamic_api_key, api_base return model, custom_llm_provider, dynamic_api_key, api_base
if custom_llm_provider: if custom_llm_provider: