forked from phoenix/litellm-mirror

Merge pull request #2868 from BerriAI/litellm_add_command_r_on_proxy

Add Azure Command-r-plus on litellm proxy

commit faa0d38087

4 changed files with 147 additions and 23 deletions
@@ -1,13 +1,21 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

 # Azure AI Studio

 ## Sample Usage
-The `azure/` prefix sends this to Azure
-
-Ensure you add `/v1` to your api_base. Your Azure AI studio `api_base` passed to litellm should look something like this
-```python
-api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
-```
+**Ensure the following:**
+
+1. The API Base passed ends in the `/v1/` prefix
+
+   example:
+   ```python
+   api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
+   ```
+
+2. The `model` passed is listed in [supported models](#supported-models). You **DO NOT** need to pass your deployment name to litellm. Example `model=azure/Mistral-large-nmefg`
+
+**Quick Start**
 ```python
 import litellm
 response = litellm.completion(
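The quick-start call above is cut off at the hunk boundary. For reference, a complete call might look roughly like the sketch below; the endpoint and key values are illustrative placeholders, not real credentials.

```python
import litellm

# Illustrative values only - substitute your own Azure AI Studio endpoint and key.
response = litellm.completion(
    model="azure/command-r-plus",  # a model name from the supported-models table, not your deployment name
    api_base="https://<your-endpoint>.inference.ai.azure.com/v1/",  # note the trailing /v1/
    api_key="<your-azure-ai-studio-api-key>",
    messages=[{"role": "user", "content": "What LLM are you?"}],
)
print(response)
```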
@@ -20,23 +28,83 @@ response = litellm.completion(

 ## Sample Usage - LiteLLM Proxy

-Set this on your litellm proxy config.yaml
+1. Add models to your config.yaml
+
 ```yaml
 model_list:
   - model_name: mistral
     litellm_params:
-      model: mistral/Mistral-large-dfgfj
+      model: azure/mistral-large-latest
       api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
       api_key: JGbKodRcTp****
+  - model_name: command-r-plus
+    litellm_params:
+      model: azure/command-r-plus
+      api_key: os.environ/AZURE_COHERE_API_KEY
+      api_base: os.environ/AZURE_COHERE_API_BASE
 ```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Send Request to LiteLLM Proxy Server
+
+<Tabs>
+
+<TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",              # pass litellm proxy key, if you're using virtual keys
+    base_url="http://0.0.0.0:4000"  # litellm-proxy-base url
+)
+
+response = client.chat.completions.create(
+    model="mistral",
+    messages = [
+        {
+            "role": "user",
+            "content": "what llm are you"
+        }
+    ],
+)
+
+print(response)
+```
+</TabItem>
+
+<TabItem value="curl" label="curl">
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+    --header 'Authorization: Bearer sk-1234' \
+    --header 'Content-Type: application/json' \
+    --data '{
+    "model": "mistral",
+    "messages": [
+        {
+            "role": "user",
+            "content": "what llm are you"
+        }
+    ]
+}'
+```
+</TabItem>
+
+</Tabs>

 ## Supported Models

 | Model Name             | Function Call                                              |
 |------------------------|------------------------------------------------------------|
-| command-r-plus         | `completion(model="azure/command-r-plus", messages)`       |
-| command-r              | `completion(model="azure/command-r", messages)`            |
+| Cohere command-r-plus  | `completion(model="azure/command-r-plus", messages)`       |
+| Cohere command-r       | `completion(model="azure/command-r", messages)`            |
 | mistral-large-latest   | `completion(model="azure/mistral-large-latest", messages)` |
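The OpenAI SDK tab above targets the `mistral` entry; the `command-r-plus` entry added by this PR is called the same way. A sketch, assuming the proxy from step 2 is running locally on port 4000 with the config shown:

```python
import openai

# Sketch only: assumes the litellm proxy is running on http://0.0.0.0:4000
# with the command-r-plus entry from the config.yaml above.
client = openai.OpenAI(
    api_key="sk-1234",              # litellm proxy key, if you're using virtual keys
    base_url="http://0.0.0.0:4000",
)

response = client.chat.completions.create(
    model="command-r-plus",  # matches model_name in config.yaml
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response.choices[0].message.content)
```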
@@ -1772,6 +1772,11 @@ class Router:
     or "ft:gpt-3.5-turbo" in model_name
     or model_name in litellm.open_ai_embedding_models
 ):
+    if custom_llm_provider == "azure":
+        if litellm.utils._is_non_openai_azure_model(model_name):
+            custom_llm_provider = "openai"
+            # remove azure prefix from model_name
+            model_name = model_name.replace("azure/", "")
     # glorified / complicated reading of configs
     # user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
     # we do this here because we init clients for Azure, OpenAI and we need to set the right key
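The config-reading comments above refer to litellm's convention of accepting either a literal value or an `os.environ/<VAR_NAME>` reference (as used in the proxy config.yaml earlier). Roughly, that convention behaves like the following hypothetical helper; this is a sketch for illustration, not litellm's actual implementation.

```python
import os


def resolve_config_value(value):
    """Hypothetical sketch of the os.environ/<VAR> convention: values written as
    'os.environ/AZURE_COHERE_API_KEY' are read from the environment at client-init
    time; anything else is passed through unchanged."""
    if isinstance(value, str) and value.startswith("os.environ/"):
        return os.environ.get(value.split("/", 1)[1])
    return value


# e.g. with AZURE_COHERE_API_KEY set in the environment:
# resolve_config_value("os.environ/AZURE_COHERE_API_KEY") -> the key itself
# resolve_config_value("JGbKodRcTp****")                  -> "JGbKodRcTp****"
```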
@@ -447,3 +447,46 @@ def test_openai_with_organization():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_init_clients_azure_command_r_plus():
+    # This tests that the router uses the OpenAI client for Azure/Command-R+
+    # For azure/command-r-plus we need to use openai.OpenAI because of how the Azure provider requires requests being sent
+    litellm.set_verbose = True
+    import logging
+    from litellm._logging import verbose_router_logger
+
+    verbose_router_logger.setLevel(logging.DEBUG)
+    try:
+        print("testing init 4 clients with diff timeouts")
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/command-r-plus",
+                    "api_key": os.getenv("AZURE_COHERE_API_KEY"),
+                    "api_base": os.getenv("AZURE_COHERE_API_BASE"),
+                    "timeout": 0.01,
+                    "stream_timeout": 0.000_001,
+                    "max_retries": 7,
+                },
+            },
+        ]
+        router = Router(model_list=model_list, set_verbose=True)
+        for elem in router.model_list:
+            model_id = elem["model_info"]["id"]
+            async_client = router.cache.get_cache(f"{model_id}_async_client")
+            stream_async_client = router.cache.get_cache(
+                f"{model_id}_stream_async_client"
+            )
+            # Assert the async clients used are OpenAI clients and not Azure
+            # For Azure/Command-R-Plus and Azure/Mistral the clients used NEED to be OpenAI clients
+            # this is weirdness introduced on Azure's side
+
+            assert "openai.AsyncOpenAI" in str(async_client)
+            assert "openai.AsyncOpenAI" in str(stream_async_client)
+        print("PASSED !")
+
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
@@ -5578,6 +5578,19 @@ def get_formatted_prompt(
     return prompt


+def _is_non_openai_azure_model(model: str) -> bool:
+    try:
+        model_name = model.split("/", 1)[1]
+        if (
+            model_name in litellm.cohere_chat_models
+            or f"mistral/{model_name}" in litellm.mistral_chat_models
+        ):
+            return True
+    except:
+        return False
+    return False
+
+
 def get_llm_provider(
     model: str,
     custom_llm_provider: Optional[str] = None,

@@ -5591,13 +5604,8 @@ def get_llm_provider(
     # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
     # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
     if model.split("/", 1)[0] == "azure":
-        model_name = model.split("/", 1)[1]
-        if (
-            model_name in litellm.cohere_chat_models
-            or f"mistral/{model_name}" in litellm.mistral_chat_models
-        ):
+        if _is_non_openai_azure_model(model):
             custom_llm_provider = "openai"
-            model = model_name
             return model, custom_llm_provider, dynamic_api_key, api_base

     if custom_llm_provider:
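Taken together, `_is_non_openai_azure_model` and the simplified `get_llm_provider` branch mean that an `azure/` model which is really a Cohere or Mistral chat model resolves to the `openai` provider while keeping its `azure/` prefix (note the removed `model = model_name` line). A quick sanity check, assuming both helpers are reachable through `litellm.utils` as in the router change above; the expected outputs are inferred from this diff.

```python
import litellm

# Cohere's command-r-plus behind Azure AI Studio counts as a "non-OpenAI Azure" model...
print(litellm.utils._is_non_openai_azure_model("azure/command-r-plus"))  # expected: True
# ...while a regular Azure OpenAI deployment does not.
print(litellm.utils._is_non_openai_azure_model("azure/gpt-4"))           # expected: False

# get_llm_provider now keeps the azure/ prefix but switches the provider to "openai".
model, provider, _api_key, _api_base = litellm.utils.get_llm_provider(model="azure/command-r-plus")
print(model, provider)  # expected: azure/command-r-plus openai
```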