Merge pull request #4134 from BerriAI/litellm_azure_ai_route

Azure AI support all models
Krish Dholakia 2024-06-11 18:24:05 -07:00 committed by GitHub
commit a53ba9b2fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 186 additions and 54 deletions


@ -3,53 +3,155 @@ import TabItem from '@theme/TabItem';
# Azure AI Studio
LiteLLM supports all models on Azure AI Studio.

**Ensure the following:**

1. The API Base passed ends in the `/v1/` prefix. Example:

   ```python
   api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
   ```

2. The `model` passed is listed in [supported models](#supported-models). You **do not** need to pass your deployment name to litellm. Example: `model=azure/Mistral-large-nmefg`
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
### ENV VAR
```python
import os

os.environ["AZURE_AI_API_KEY"] = ""
os.environ["AZURE_AI_API_BASE"] = ""
```
### Example Call
```python
from litellm import completion
import os
## set ENV variables
os.environ["AZURE_API_API_KEY"] = "azure ai key"
os.environ["AZURE_AI_API_BASE"] = "azure ai base url" # e.g.: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/
# predibase llama-3 call
response = completion(
model="azure_ai/command-r-plus",
messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
## Sample Usage - LiteLLM Proxy
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: mistral
    litellm_params:
      model: azure/mistral-large-latest
      api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
      api_key: JGbKodRcTp****
  - model_name: command-r-plus
    litellm_params:
      model: azure_ai/command-r-plus
      api_key: os.environ/AZURE_AI_API_KEY
      api_base: os.environ/AZURE_AI_API_BASE
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
    api_key="sk-1234",  # pass litellm proxy key, if you're using virtual keys
    base_url="http://0.0.0.0:4000"  # litellm-proxy-base url
)

response = client.chat.completions.create(
    model="command-r-plus",
    messages=[
        {
            "role": "system",
            "content": "Be a good human!"
        },
        {
            "role": "user",
            "content": "What do you know about earth?"
        }
    ]
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "command-r-plus",
    "messages": [
        {
            "role": "system",
            "content": "Be a good human!"
        },
        {
            "role": "user",
            "content": "What do you know about earth?"
        }
    ]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Passing additional params - max_tokens, temperature
See all litellm.completion supported params [here](../completion/input.md#translated-openai-params)
```python
# !pip install litellm
from litellm import completion
import os
## set ENV variables
os.environ["AZURE_AI_API_KEY"] = "azure ai api key"
os.environ["AZURE_AI_API_BASE"] = "azure ai api base"
# command r plus call
response = completion(
    model="azure_ai/command-r-plus",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    max_tokens=20,
    temperature=0.5
)
```
**proxy**
```yaml
model_list:
  - model_name: command-r-plus
    litellm_params:
      model: azure_ai/command-r-plus
      api_key: os.environ/AZURE_AI_API_KEY
      api_base: os.environ/AZURE_AI_API_BASE
      max_tokens: 20
      temperature: 0.5
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml
```
@ -103,9 +205,6 @@ response = litellm.completion(
</Tabs>
</TabItem>
</Tabs>
## Function Calling
<Tabs>
@ -115,8 +214,8 @@ response = litellm.completion(
from litellm import completion
# set env
os.environ["AZURE_MISTRAL_API_KEY"] = "your-api-key"
os.environ["AZURE_MISTRAL_API_BASE"] = "your-api-base"
os.environ["AZURE_AI_API_KEY"] = "your-api-key"
os.environ["AZURE_AI_API_BASE"] = "your-api-base"
tools = [
{
@ -141,9 +240,7 @@ tools = [
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
model="azure/mistral-large-latest",
api_base=os.getenv("AZURE_MISTRAL_API_BASE")
api_key=os.getenv("AZURE_MISTRAL_API_KEY")
model="azure_ai/mistral-large-latest",
messages=messages,
tools=tools,
tool_choice="auto",
@ -206,10 +303,12 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## Supported Models
LiteLLM supports **ALL** Azure AI models. Here are a few examples:
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
| Cohere command-r | `completion(model="azure/command-r", messages)` |
| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |
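
Any other model deployed on Azure AI Studio can be called the same way: point the `AZURE_AI_*` environment variables at the deployment and prefix the model with `azure_ai/`. A minimal sketch, assuming a hypothetical serverless deployment named `Llama-3-70b-xyz` (placeholder endpoint, not a real one):

```python
import os
from litellm import completion

# placeholders - point these at your own Azure AI Studio serverless deployment
os.environ["AZURE_AI_API_KEY"] = "azure ai key"
os.environ["AZURE_AI_API_BASE"] = "https://Llama-3-70b-xyz-serverless.eastus2.inference.ai.azure.com/"

response = completion(
    model="azure_ai/Llama-3-70b-xyz",  # hypothetical deployment name
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```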


@ -1,4 +1,4 @@
# Clarifai
Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are supported on Clarifai.
## Pre-Requisites


@ -125,11 +125,12 @@ See all litellm.completion supported params [here](../completion/input.md#transl
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"
os.environ["DATABRICKS_API_KEY"] = "databricks key"
os.environ["DATABRICKS_API_BASE"] = "databricks api base"
# predibae llama-3 call
# databricks dbrx call
response = completion(
model="predibase/llama3-8b-instruct",
model="databricks/databricks-dbrx-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}],
max_tokens=20,
temperature=0.5


@ -62,6 +62,7 @@ post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
log_raw_request_response: bool = False
store_audit_logs = False # Enterprise feature, allow users to see audit logs
## end of callbacks #############
@ -406,6 +407,7 @@ openai_compatible_providers: List = [
"xinference",
"together_ai",
"fireworks_ai",
"azure_ai",
]
@ -610,6 +612,7 @@ provider_list: List = [
"baseten",
"azure",
"azure_text",
"azure_ai",
"sagemaker",
"bedrock",
"vllm",


@ -11,6 +11,7 @@ import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union, BinaryIO
from typing_extensions import overload
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx


@ -114,6 +114,27 @@ def test_null_role_response():
    assert response.choices[0].message.role == "assistant"


def test_completion_azure_ai_command_r():
    try:
        import os

        litellm.set_verbose = True

        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")

        response: litellm.ModelResponse = completion(
            model="azure_ai/command-r-plus",
            messages=[{"role": "user", "content": "What is the meaning of life?"}],
        )  # type: ignore

        assert "azure_ai" in response.model
    except litellm.Timeout as e:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_completion_azure_command_r():
    try:
        litellm.set_verbose = True
@ -721,7 +742,11 @@ def test_completion_claude_3_function_plus_image():
    print(response)


@pytest.mark.parametrize(
    "provider",
    ["azure", "azure_ai"],
)
def test_completion_azure_mistral_large_function_calling(provider):
    """
    This primarily tests if the 'Function()' pydantic object correctly handles argument param passed in as a dict vs. string
    """
@ -752,8 +777,9 @@ def test_completion_azure_mistral_large_function_calling():
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
response = completion(
model="azure/mistral-large-latest",
model="{}/mistral-large-latest".format(provider),
api_base=os.getenv("AZURE_MISTRAL_API_BASE"),
api_key=os.getenv("AZURE_MISTRAL_API_KEY"),
messages=messages,


@ -34,14 +34,15 @@ class MyCustomHandler(CustomLogger):
        self.response_cost = 0

    def log_pre_api_call(self, model, messages, kwargs):
        print("Pre-API Call")
        traceback.print_stack()
        self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        print("Post-API Call")

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print("On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")
@ -372,6 +373,7 @@ async def test_async_custom_handler_embedding_optional_param():
    Tests if the openai optional params for embedding - user + encoding_format,
    are logged
    """
    litellm.set_verbose = True
    customHandler_optional_params = MyCustomHandler()
    litellm.callbacks = [customHandler_optional_params]
    response = await litellm.aembedding(


@ -938,7 +938,6 @@ class TextCompletionResponse(OpenAIObject):
        object=None,
        **params,
    ):
        if stream:
            object = "text_completion.chunk"
            choices = [TextChoices()]
@ -947,7 +946,6 @@ class TextCompletionResponse(OpenAIObject):
        if choices is not None and isinstance(choices, list):
            new_choices = []
            for choice in choices:
                if isinstance(choice, TextChoices):
                    _new_choice = choice
                elif isinstance(choice, dict):
@ -1023,7 +1021,6 @@ class ImageObject(OpenAIObject):
    revised_prompt: Optional[str] = None

    def __init__(self, b64_json=None, url=None, revised_prompt=None):
        super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)

    def __contains__(self, key):
@ -1627,7 +1624,6 @@ class Logging:
                    end_time=end_time,
                )
            except Exception as e:
                complete_streaming_response = None
        else:
            self.sync_streaming_chunks.append(result)
@ -2397,7 +2393,6 @@ class Logging:
"async_complete_streaming_response"
in self.model_call_details
):
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=self.model_call_details[
@ -6172,13 +6167,16 @@ def get_api_base(
    if litellm.model_alias_map and model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]
    try:
        (
            model,
            custom_llm_provider,
            dynamic_api_key,
            dynamic_api_base,
        ) = get_llm_provider(
            model=model,
            custom_llm_provider=_optional_params.custom_llm_provider,
            api_base=_optional_params.api_base,
            api_key=_optional_params.api_key,
        )
    except Exception as e:
        verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
@ -6583,6 +6581,9 @@ def get_llm_provider(
or get_secret("FIREWORKSAI_API_KEY")
or get_secret("FIREWORKS_AI_TOKEN")
)
elif custom_llm_provider == "azure_ai":
api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
elif custom_llm_provider == "mistral":
# mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
api_base = (
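
With the `azure_ai` branch above in place, `get_llm_provider` resolves the key and base for `azure_ai/` models from the environment when they aren't passed explicitly. A rough illustration of the expected behavior (placeholder values; not an excerpt from this diff):

```python
import os
import litellm

os.environ["AZURE_AI_API_BASE"] = "https://my-endpoint-serverless.eastus2.inference.ai.azure.com/v1/"  # placeholder
os.environ["AZURE_AI_API_KEY"] = "my-azure-ai-key"  # placeholder

model, provider, api_key, api_base = litellm.get_llm_provider(model="azure_ai/command-r-plus")
print(provider)  # expected: "azure_ai"
print(api_base)  # expected: the AZURE_AI_API_BASE value set above
```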
@ -7454,7 +7455,6 @@ def validate_environment(model: Optional[str] = None) -> dict:
def set_callbacks(callback_list, function_id=None):
    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
    try: