Merge pull request #4134 from BerriAI/litellm_azure_ai_route

Azure AI support all models
Krish Dholakia 2024-06-11 18:24:05 -07:00 committed by GitHub
commit a53ba9b2fb
8 changed files with 186 additions and 54 deletions


@@ -3,49 +3,151 @@ import TabItem from '@theme/TabItem';

# Azure AI Studio

-**Ensure the following:**
-1. The API Base passed ends in the `/v1/` prefix
-example:
-```python
-api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
-```
-
-2. The `model` passed is listed in [supported models](#supported-models). You **DO NOT** need to pass your deployment name to litellm. Example `model=azure/Mistral-large-nmefg`
+LiteLLM supports all models on Azure AI Studio

## Usage

<Tabs>
<TabItem value="sdk" label="SDK">

+### ENV VAR
+
```python
-import litellm
-response = litellm.completion(
-    model="azure/command-r-plus",
-    api_base="<your-deployment-base>/v1/",
-    api_key="eskk******",
-    messages=[{"role": "user", "content": "What is the meaning of life?"}],
-)
+import os
+os.environ["AZURE_AI_API_KEY"] = ""
+os.environ["AZURE_AI_API_BASE"] = ""
+```
+
+### Example Call
+
+```python
+from litellm import completion
+import os
+## set ENV variables
+os.environ["AZURE_AI_API_KEY"] = "azure ai key"
+os.environ["AZURE_AI_API_BASE"] = "azure ai base url" # e.g.: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/
+
+# Azure AI command-r-plus call
+response = completion(
+    model="azure_ai/command-r-plus",
+    messages = [{ "content": "Hello, how are you?","role": "user"}]
+)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

-## Sample Usage - LiteLLM Proxy
-
1. Add models to your config.yaml

```yaml
model_list:
-  - model_name: mistral
-    litellm_params:
-      model: azure/mistral-large-latest
-      api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
-      api_key: JGbKodRcTp****
  - model_name: command-r-plus
    litellm_params:
-      model: azure/command-r-plus
-      api_key: os.environ/AZURE_COHERE_API_KEY
-      api_base: os.environ/AZURE_COHERE_API_BASE
+      model: azure_ai/command-r-plus
+      api_key: os.environ/AZURE_AI_API_KEY
+      api_base: os.environ/AZURE_AI_API_BASE
```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Send Request to LiteLLM Proxy Server
+
+<Tabs>
+<TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
+    base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+)
+
+response = client.chat.completions.create(
+    model="command-r-plus",
+    messages = [
+        {
+            "role": "system",
+            "content": "Be a good human!"
+        },
+        {
+            "role": "user",
+            "content": "What do you know about earth?"
+        }
+    ]
+)
+
+print(response)
+```
+
+</TabItem>
+<TabItem value="curl" label="curl">
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "command-r-plus",
+    "messages": [
+        {
+            "role": "system",
+            "content": "Be a good human!"
+        },
+        {
+            "role": "user",
+            "content": "What do you know about earth?"
+        }
+    ]
+}'
+```
+
+</TabItem>
+</Tabs>
+
+</TabItem>
+</Tabs>
+
+## Passing additional params - max_tokens, temperature
+
+See all litellm.completion supported params [here](../completion/input.md#translated-openai-params)
+
+```python
+# !pip install litellm
+from litellm import completion
+import os
+## set ENV variables
+os.environ["AZURE_AI_API_KEY"] = "azure ai api key"
+os.environ["AZURE_AI_API_BASE"] = "azure ai api base"
+
+# command-r-plus call
+response = completion(
+    model="azure_ai/command-r-plus",
+    messages = [{ "content": "Hello, how are you?","role": "user"}],
+    max_tokens=20,
+    temperature=0.5
+)
+```
+
+**proxy**
+
+```yaml
+model_list:
+  - model_name: command-r-plus
+    litellm_params:
+      model: azure_ai/command-r-plus
+      api_key: os.environ/AZURE_AI_API_KEY
+      api_base: os.environ/AZURE_AI_API_BASE
+      max_tokens: 20
+      temperature: 0.5
+```
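Beyond the synchronous examples above, the same `azure_ai/` route works with LiteLLM's async client. A minimal sketch (not part of the diff above), assuming the same `AZURE_AI_API_KEY` / `AZURE_AI_API_BASE` environment variables:

```python
# Async variant of the example above; env vars are placeholders.
import asyncio
import os

from litellm import acompletion

os.environ["AZURE_AI_API_KEY"] = "azure ai key"
os.environ["AZURE_AI_API_BASE"] = "azure ai base url"

async def main():
    # Same azure_ai/ prefix as the synchronous examples
    response = await acompletion(
        model="azure_ai/command-r-plus",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        max_tokens=20,
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```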
@@ -103,9 +205,6 @@ response = litellm.completion(

</Tabs>
-
-</TabItem>
-</Tabs>

## Function Calling

<Tabs>

@@ -115,8 +214,8 @@ response = litellm.completion(
from litellm import completion

# set env
-os.environ["AZURE_MISTRAL_API_KEY"] = "your-api-key"
-os.environ["AZURE_MISTRAL_API_BASE"] = "your-api-base"
+os.environ["AZURE_AI_API_KEY"] = "your-api-key"
+os.environ["AZURE_AI_API_BASE"] = "your-api-base"

tools = [
    {

@@ -141,9 +240,7 @@ tools = [
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]

response = completion(
-    model="azure/mistral-large-latest",
-    api_base=os.getenv("AZURE_MISTRAL_API_BASE"),
-    api_key=os.getenv("AZURE_MISTRAL_API_KEY"),
+    model="azure_ai/mistral-large-latest",
    messages=messages,
    tools=tools,
    tool_choice="auto",

@@ -206,10 +303,12 @@ curl http://0.0.0.0:4000/v1/chat/completions \

## Supported Models

+LiteLLM supports **ALL** Azure AI models. Here are a few examples:
+
| Model Name             | Function Call                                               |
|------------------------|-------------------------------------------------------------|
| Cohere command-r-plus  | `completion(model="azure/command-r-plus", messages)`        |
-| Cohere ommand-r        | `completion(model="azure/command-r", messages)`             |
+| Cohere command-r       | `completion(model="azure/command-r", messages)`             |
| mistral-large-latest   | `completion(model="azure/mistral-large-latest", messages)`  |
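Because the route now accepts any Azure AI Studio model, deployments outside this table follow the same pattern: prefix the deployment name with `azure_ai/`. A minimal sketch, where the model name and endpoint URL are hypothetical:

```python
# The deployment name and endpoint below are hypothetical illustrations only.
import os

from litellm import completion

response = completion(
    model="azure_ai/Phi-3-medium-128k-instruct",                      # hypothetical serverless deployment name
    api_base="https://my-endpoint.eastus2.inference.ai.azure.com/",   # hypothetical endpoint URL
    api_key=os.environ["AZURE_AI_API_KEY"],
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```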


@@ -1,4 +1,4 @@
-# 🆕 Clarifai
+# Clarifai
Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are Supported on Clarifai.

## Pre-Requisites


@@ -125,11 +125,12 @@ See all litellm.completion supported params [here](../completion/input.md#translated-openai-params)
from litellm import completion
import os

## set ENV variables
-os.environ["PREDIBASE_API_KEY"] = "predibase key"
+os.environ["DATABRICKS_API_KEY"] = "databricks key"
+os.environ["DATABRICKS_API_BASE"] = "databricks api base"

-# predibase llama-3 call
+# databricks dbrx call
response = completion(
-    model="predibase/llama3-8b-instruct",
+    model="databricks/databricks-dbrx-instruct",
    messages = [{ "content": "Hello, how are you?","role": "user"}],
    max_tokens=20,
    temperature=0.5


@@ -62,6 +62,7 @@ post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
+log_raw_request_response: bool = False
store_audit_logs = False  # Enterprise feature, allow users to see audit logs
## end of callbacks #############

@@ -406,6 +407,7 @@ openai_compatible_providers: List = [
    "xinference",
    "together_ai",
    "fireworks_ai",
+    "azure_ai",
]

@@ -610,6 +612,7 @@ provider_list: List = [
    "baseten",
    "azure",
    "azure_text",
+    "azure_ai",
    "sagemaker",
    "bedrock",
    "vllm",


@@ -11,6 +11,7 @@ import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union, BinaryIO
from typing_extensions import overload
from functools import partial

import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy

import httpx


@@ -114,6 +114,27 @@ def test_null_role_response():
    assert response.choices[0].message.role == "assistant"


+def test_completion_azure_ai_command_r():
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
+
+        response: litellm.ModelResponse = completion(
+            model="azure_ai/command-r-plus",
+            messages=[{"role": "user", "content": "What is the meaning of life?"}],
+        )  # type: ignore
+
+        assert "azure_ai" in response.model
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
def test_completion_azure_command_r():
    try:
        litellm.set_verbose = True

@@ -721,7 +742,11 @@ def test_completion_claude_3_function_plus_image():
    print(response)


-def test_completion_azure_mistral_large_function_calling():
+@pytest.mark.parametrize(
+    "provider",
+    ["azure", "azure_ai"],
+)
+def test_completion_azure_mistral_large_function_calling(provider):
    """
    This primarily tests if the 'Function()' pydantic object correctly handles argument param passed in as a dict vs. string
    """

@@ -752,8 +777,9 @@ def test_completion_azure_mistral_large_function_calling():
            "content": "What's the weather like in Boston today in Fahrenheit?",
        }
    ]
    response = completion(
-        model="azure/mistral-large-latest",
+        model="{}/mistral-large-latest".format(provider),
        api_base=os.getenv("AZURE_MISTRAL_API_BASE"),
        api_key=os.getenv("AZURE_MISTRAL_API_KEY"),
        messages=messages,


@@ -34,14 +34,15 @@ class MyCustomHandler(CustomLogger):
        self.response_cost = 0

    def log_pre_api_call(self, model, messages, kwargs):
-        print(f"Pre-API Call")
+        print("Pre-API Call")
+        traceback.print_stack()
        self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
-        print(f"Post-API Call")
+        print("Post-API Call")

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
-        print(f"On Stream")
+        print("On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")

@@ -372,6 +373,7 @@ async def test_async_custom_handler_embedding_optional_param():
    Tests if the openai optional params for embedding - user + encoding_format,
    are logged
    """
+    litellm.set_verbose = True
    customHandler_optional_params = MyCustomHandler()
    litellm.callbacks = [customHandler_optional_params]
    response = await litellm.aembedding(
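The handler above shows the general callback pattern: grab `complete_input_dict` in `log_pre_api_call`. A minimal sketch (not part of this diff) of reusing that pattern to inspect the raw request an `azure_ai/` call sends, assuming the standard `CustomLogger` import path:

```python
# Minimal sketch: capture the raw request payload for an azure_ai call,
# using the same hook exercised by the test above.
import litellm
from litellm import completion
from litellm.integrations.custom_logger import CustomLogger


class RequestCapture(CustomLogger):
    def __init__(self):
        super().__init__()
        self.data_sent_to_api = {}

    def log_pre_api_call(self, model, messages, kwargs):
        # Same field the test inspects: the fully-built request dict
        self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})


capture = RequestCapture()
litellm.callbacks = [capture]

response = completion(
    model="azure_ai/command-r-plus",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(capture.data_sent_to_api)
```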


@@ -938,7 +938,6 @@ class TextCompletionResponse(OpenAIObject):
        object=None,
        **params,
    ):
-
        if stream:
            object = "text_completion.chunk"
            choices = [TextChoices()]

@@ -947,7 +946,6 @@ class TextCompletionResponse(OpenAIObject):
        if choices is not None and isinstance(choices, list):
            new_choices = []
            for choice in choices:
-
                if isinstance(choice, TextChoices):
                    _new_choice = choice
                elif isinstance(choice, dict):

@@ -1023,7 +1021,6 @@ class ImageObject(OpenAIObject):
    revised_prompt: Optional[str] = None

    def __init__(self, b64_json=None, url=None, revised_prompt=None):
        super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
-
    def __contains__(self, key):

@@ -1627,7 +1624,6 @@ class Logging:
                        end_time=end_time,
                    )
                except Exception as e:
-
                    complete_streaming_response = None
            else:
                self.sync_streaming_chunks.append(result)

@@ -2397,7 +2393,6 @@ class Logging:
                    "async_complete_streaming_response"
                    in self.model_call_details
                ):
-
                    await customLogger.async_log_event(
                        kwargs=self.model_call_details,
                        response_obj=self.model_call_details[

@@ -6172,14 +6167,17 @@ def get_api_base(
    if litellm.model_alias_map and model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]
    try:
-        model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
-            get_llm_provider(
-                model=model,
-                custom_llm_provider=_optional_params.custom_llm_provider,
-                api_base=_optional_params.api_base,
-                api_key=_optional_params.api_key,
-            )
-        )
+        (
+            model,
+            custom_llm_provider,
+            dynamic_api_key,
+            dynamic_api_base,
+        ) = get_llm_provider(
+            model=model,
+            custom_llm_provider=_optional_params.custom_llm_provider,
+            api_base=_optional_params.api_base,
+            api_key=_optional_params.api_key,
+        )
    except Exception as e:
        verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
        custom_llm_provider = None
@@ -6583,6 +6581,9 @@ def get_llm_provider(
                or get_secret("FIREWORKSAI_API_KEY")
                or get_secret("FIREWORKS_AI_TOKEN")
            )
+        elif custom_llm_provider == "azure_ai":
+            api_base = api_base or get_secret("AZURE_AI_API_BASE")  # type: ignore
+            dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
        elif custom_llm_provider == "mistral":
            # mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
            api_base = (
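The effect of the new branch: when no key or base is passed explicitly, `azure_ai/` models fall back to `AZURE_AI_API_KEY` / `AZURE_AI_API_BASE`. A small sketch using the `get_llm_provider` signature visible in the `get_api_base` hunk above, assuming the helper is importable from the top-level package:

```python
# Sketch of the env-var fallback added in this hunk; endpoint value is hypothetical.
import os

import litellm

os.environ["AZURE_AI_API_KEY"] = "azure ai key"
os.environ["AZURE_AI_API_BASE"] = "https://example.inference.ai.azure.com/"

model, provider, key, base = litellm.get_llm_provider(model="azure_ai/command-r-plus")
print(provider)  # "azure_ai"
print(base)      # value of AZURE_AI_API_BASE, since no api_base was passed explicitly
```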
@@ -7454,7 +7455,6 @@ def validate_environment(model: Optional[str] = None) -> dict:

def set_callbacks(callback_list, function_id=None):
    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
    try: