Merge pull request #2861 from BerriAI/litellm_add_azure_command_r_plust

[FEAT] add azure command-r-plus
This commit is contained in:
Ishaan Jaff 2024-04-05 15:13:35 -07:00 committed by GitHub
commit 2174b240d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 76 additions and 35 deletions

View file

@ -1,46 +1,24 @@
# Azure AI Studio
## Using Mistral models deployed on Azure AI Studio
## Sample Usage
The `azure/` prefix sends this to Azure
### Sample Usage - setting env vars
Set `MISTRAL_AZURE_API_KEY` and `MISTRAL_AZURE_API_BASE` in your env
```shell
MISTRAL_AZURE_API_KEY = "zE************"
MISTRAL_AZURE_API_BASE = "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1"
Ensure you add `/v1` to your api_base. Your Azure AI Studio `api_base` passed to litellm should look something like this:
```python
api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
```
```python
from litellm import completion
import os
response = completion(
model="mistral/Mistral-large-dfgfj",
messages=[
{"role": "user", "content": "hello from litellm"}
],
import litellm
response = litellm.completion(
model="azure/command-r-plus",
api_base="<your-deployment-base>/v1/",
api_key="eskk******",
messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
print(response)
```
### Sample Usage - passing `api_base` and `api_key` to `litellm.completion`
```python
from litellm import completion
import os
response = completion(
model="mistral/Mistral-large-dfgfj",
api_base="https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com",
api_key = "JGbKodRcTp****",
messages=[
{"role": "user", "content": "hello from litellm"}
],
)
print(response)
```
### [LiteLLM Proxy] Using Mistral Models
## Sample Usage - LiteLLM Proxy
Set this on your litellm proxy config.yaml
```yaml
@ -48,8 +26,17 @@ model_list:
- model_name: mistral
litellm_params:
model: mistral/Mistral-large-dfgfj
api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com
api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
api_key: JGbKodRcTp****
```
## Supported Models
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| command-r-plus | `completion(model="azure/command-r-plus", messages)` |
| command-r | `completion(model="azure/command-r", messages)` |
| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |

View file

@ -260,6 +260,7 @@ open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
cohere_models: List = []
cohere_chat_models: List = []
mistral_chat_models: List = []
anthropic_models: List = []
openrouter_models: List = []
vertex_language_models: List = []
@ -285,6 +286,8 @@ for key, value in model_cost.items():
cohere_models.append(key)
elif value.get("litellm_provider") == "cohere_chat":
cohere_chat_models.append(key)
elif value.get("litellm_provider") == "mistral":
mistral_chat_models.append(key)
elif value.get("litellm_provider") == "anthropic":
anthropic_models.append(key)
elif value.get("litellm_provider") == "openrouter":

View file

@ -474,6 +474,16 @@
"mode": "chat",
"supports_function_calling": true
},
"azure/command-r-plus": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "azure",
"mode": "chat",
"supports_function_calling": true
},
"azure/ada": {
"max_tokens": 8191,
"max_input_tokens": 8191,

View file

@ -53,6 +53,24 @@ def test_completion_custom_provider_model_name():
# test_completion_custom_provider_model_name()
def test_completion_azure_command_r():
    """Smoke-test a completion call against `azure/command-r-plus`.

    The endpoint and key are read from the `AZURE_COHERE_API_BASE` and
    `AZURE_COHERE_API_KEY` environment variables. A `litellm.Timeout` is
    tolerated; any other exception fails the test.
    """
    try:
        litellm.set_verbose = True
        azure_base = os.getenv("AZURE_COHERE_API_BASE")
        azure_key = os.getenv("AZURE_COHERE_API_KEY")
        prompt = [{"role": "user", "content": "What is the meaning of life?"}]
        response = completion(
            model="azure/command-r-plus",
            api_base=azure_base,
            api_key=azure_key,
            messages=prompt,
        )
        print(response)
    except litellm.Timeout:
        # A timeout is not counted as a test failure.
        pass
    except Exception as err:
        pytest.fail(f"Error occurred: {err}")
def test_completion_claude():
litellm.set_verbose = True
litellm.cache = None

View file

@ -652,6 +652,7 @@ def load_vertex_ai_credentials():
@pytest.mark.asyncio
@pytest.mark.skip(reason="Skipping on this PR to test other stuff")
async def test_async_chat_vertex_ai_stream():
try:
load_vertex_ai_credentials()

View file

@ -5588,6 +5588,18 @@ def get_llm_provider(
dynamic_api_key = None
# check if llm provider provided
# AZURE AI-Studio Logic - handles azure/<model> where <model> is a known Cohere chat or Mistral chat model
# If User passes azure/command-r-plus -> we strip the prefix and send "command-r-plus" through the "openai" provider
if model.split("/", 1)[0] == "azure":
model_name = model.split("/", 1)[1]
if (
model_name in litellm.cohere_chat_models
or f"mistral/{model_name}" in litellm.mistral_chat_models
):
custom_llm_provider = "openai"
model = model_name
return model, custom_llm_provider, dynamic_api_key, api_base
if custom_llm_provider:
return model, custom_llm_provider, dynamic_api_key, api_base

View file

@ -474,6 +474,16 @@
"mode": "chat",
"supports_function_calling": true
},
"azure/command-r-plus": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "azure",
"mode": "chat",
"supports_function_calling": true
},
"azure/ada": {
"max_tokens": 8191,
"max_input_tokens": 8191,