diff --git a/docs/my-website/docs/providers/azure_ai.md b/docs/my-website/docs/providers/azure_ai.md
index ed13c5664..87b8041ef 100644
--- a/docs/my-website/docs/providers/azure_ai.md
+++ b/docs/my-website/docs/providers/azure_ai.md
@@ -3,53 +3,155 @@ import TabItem from '@theme/TabItem';
# Azure AI Studio
-**Ensure the following:**
-1. The API Base passed ends in the `/v1/` prefix
- example:
- ```python
- api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
- ```
+LiteLLM supports all models on Azure AI Studio
-2. The `model` passed is listed in [supported models](#supported-models). You **DO NOT** Need to pass your deployment name to litellm. Example `model=azure/Mistral-large-nmefg`
## Usage
+### ENV VAR
```python
-import litellm
-response = litellm.completion(
- model="azure/command-r-plus",
- api_base="/v1/"
- api_key="eskk******"
- messages=[{"role": "user", "content": "What is the meaning of life?"}],
+import os
+os.environ["AZURE_API_API_KEY"] = ""
+os.environ["AZURE_AI_API_BASE"] = ""
+```
+
+### Example Call
+
+```python
+from litellm import completion
+import os
+## set ENV variables
+os.environ["AZURE_API_API_KEY"] = "azure ai key"
+os.environ["AZURE_AI_API_BASE"] = "azure ai base url" # e.g.: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/
+
+# azure ai command r plus call
+response = completion(
+ model="azure_ai/command-r-plus",
+ messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```
-## Sample Usage - LiteLLM Proxy
-
+## Usage - LiteLLM Proxy Server
+
1. Add models to your config.yaml
```yaml
model_list:
- - model_name: mistral
- litellm_params:
- model: azure/mistral-large-latest
- api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
- api_key: JGbKodRcTp****
- model_name: command-r-plus
litellm_params:
- model: azure/command-r-plus
- api_key: os.environ/AZURE_COHERE_API_KEY
- api_base: os.environ/AZURE_COHERE_API_BASE
+ model: azure_ai/command-r-plus
+ api_key: os.environ/AZURE_AI_API_KEY
+ api_base: os.environ/AZURE_AI_API_BASE
```
+2. Start the proxy
+
+ ```bash
+ $ litellm --config /path/to/config.yaml --debug
+ ```
+
+3. Send Request to LiteLLM Proxy Server
+
+ ```python
+ import openai
+ client = openai.OpenAI(
+ api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
+ base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+ )
+
+ response = client.chat.completions.create(
+ model="command-r-plus",
+ messages = [
+ {
+ "role": "system",
+ "content": "Be a good human!"
+ },
+ {
+ "role": "user",
+ "content": "What do you know about earth?"
+ }
+ ]
+ )
+
+ print(response)
+ ```
+
+ ```shell
+ curl --location 'http://0.0.0.0:4000/chat/completions' \
+ --header 'Authorization: Bearer sk-1234' \
+ --header 'Content-Type: application/json' \
+ --data '{
+ "model": "command-r-plus",
+ "messages": [
+ {
+ "role": "system",
+ "content": "Be a good human!"
+ },
+ {
+ "role": "user",
+ "content": "What do you know about earth?"
+ }
+    ]
+ }'
+ ```
+
+## Passing additional params - max_tokens, temperature
+See all litellm.completion supported params [here](../completion/input.md#translated-openai-params)
+
+```python
+# !pip install litellm
+from litellm import completion
+import os
+## set ENV variables
+os.environ["AZURE_AI_API_KEY"] = "azure ai api key"
+os.environ["AZURE_AI_API_BASE"] = "azure ai api base"
+
+# command r plus call
+response = completion(
+ model="azure_ai/command-r-plus",
+ messages = [{ "content": "Hello, how are you?","role": "user"}],
+ max_tokens=20,
+ temperature=0.5
+)
+```
+
+**proxy**
+
+```yaml
+ model_list:
+ - model_name: command-r-plus
+ litellm_params:
+ model: azure_ai/command-r-plus
+ api_key: os.environ/AZURE_AI_API_KEY
+ api_base: os.environ/AZURE_AI_API_BASE
+ max_tokens: 20
+ temperature: 0.5
+```
+
2. Start the proxy
```bash
@@ -103,9 +205,6 @@ response = litellm.completion(
-
-
-
## Function Calling
@@ -115,8 +214,8 @@ response = litellm.completion(
from litellm import completion
# set env
-os.environ["AZURE_MISTRAL_API_KEY"] = "your-api-key"
-os.environ["AZURE_MISTRAL_API_BASE"] = "your-api-base"
+os.environ["AZURE_AI_API_KEY"] = "your-api-key"
+os.environ["AZURE_AI_API_BASE"] = "your-api-base"
tools = [
{
@@ -141,9 +240,7 @@ tools = [
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
- model="azure/mistral-large-latest",
- api_base=os.getenv("AZURE_MISTRAL_API_BASE")
- api_key=os.getenv("AZURE_MISTRAL_API_KEY")
+ model="azure_ai/mistral-large-latest",
messages=messages,
tools=tools,
tool_choice="auto",
@@ -206,10 +303,12 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## Supported Models
+LiteLLM supports **ALL** Azure AI Studio models. Here are a few examples:
+
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
-| Cohere ommand-r | `completion(model="azure/command-r", messages)` |
+| Cohere command-r | `completion(model="azure/command-r", messages)` |
| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |
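+
+Since `azure_ai` is handled as an OpenAI-compatible provider, any other Azure AI Studio serverless deployment should work the same way: just prefix the model name with `azure_ai/`. A minimal sketch (the `Meta-Llama-3-70B-Instruct` name below is a hypothetical placeholder; use whatever model your endpoint actually serves):
+
+```python
+from litellm import completion
+import os
+
+## set ENV variables
+os.environ["AZURE_AI_API_KEY"] = "azure ai api key"
+os.environ["AZURE_AI_API_BASE"] = "azure ai api base"
+
+response = completion(
+    model="azure_ai/Meta-Llama-3-70B-Instruct",  # hypothetical deployment name
+    messages=[{"role": "user", "content": "Hello, how are you?"}],
+)
+print(response)
+```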
diff --git a/docs/my-website/docs/providers/clarifai.md b/docs/my-website/docs/providers/clarifai.md
index 85ee8fa26..085ab8ed9 100644
--- a/docs/my-website/docs/providers/clarifai.md
+++ b/docs/my-website/docs/providers/clarifai.md
@@ -1,4 +1,4 @@
-# 🆕 Clarifai
+# Clarifai
Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are Supported on Clarifai.
## Pre-Requisites
diff --git a/docs/my-website/docs/providers/databricks.md b/docs/my-website/docs/providers/databricks.md
index 08a3e4f76..24c7c40cf 100644
--- a/docs/my-website/docs/providers/databricks.md
+++ b/docs/my-website/docs/providers/databricks.md
@@ -125,11 +125,12 @@ See all litellm.completion supported params [here](../completion/input.md#transl
from litellm import completion
import os
## set ENV variables
-os.environ["PREDIBASE_API_KEY"] = "predibase key"
+os.environ["DATABRICKS_API_KEY"] = "databricks key"
+os.environ["DATABRICKS_API_BASE"] = "databricks api base"
-# predibae llama-3 call
+# databricks dbrx call
response = completion(
- model="predibase/llama3-8b-instruct",
+ model="databricks/databricks-dbrx-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}],
max_tokens=20,
temperature=0.5
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 4ddc4552c..82bf85e3f 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -62,6 +62,7 @@ post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
store_audit_logs = False # Enterprise feature, allow users to see audit logs
## end of callbacks #############
@@ -406,6 +407,7 @@ openai_compatible_providers: List = [
"xinference",
"together_ai",
"fireworks_ai",
+ "azure_ai",
]
@@ -610,6 +612,7 @@ provider_list: List = [
"baseten",
"azure",
"azure_text",
+ "azure_ai",
"sagemaker",
"bedrock",
"vllm",
diff --git a/litellm/main.py b/litellm/main.py
index 2c906e990..bf4931168 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -11,6 +11,7 @@ import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union, BinaryIO
from typing_extensions import overload
from functools import partial
+
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 2428cbf48..8a7557c35 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -114,6 +114,27 @@ def test_null_role_response():
assert response.choices[0].message.role == "assistant"
+def test_completion_azure_ai_command_r():
+ try:
+ import os
+
+ litellm.set_verbose = True
+
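+        # reuse the existing Azure Cohere serverless endpoint credentials for the new azure_ai provider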
+ os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
+ os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
+
+ response: litellm.ModelResponse = completion(
+ model="azure_ai/command-r-plus",
+ messages=[{"role": "user", "content": "What is the meaning of life?"}],
+ ) # type: ignore
+
+ assert "azure_ai" in response.model
+ except litellm.Timeout as e:
+ pass
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
+
+
def test_completion_azure_command_r():
try:
litellm.set_verbose = True
@@ -721,7 +742,11 @@ def test_completion_claude_3_function_plus_image():
print(response)
-def test_completion_azure_mistral_large_function_calling():
+@pytest.mark.parametrize(
+ "provider",
+ ["azure", "azure_ai"],
+)
+def test_completion_azure_mistral_large_function_calling(provider):
"""
This primarily tests if the 'Function()' pydantic object correctly handles argument param passed in as a dict vs. string
"""
@@ -752,8 +777,9 @@ def test_completion_azure_mistral_large_function_calling():
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
+
response = completion(
- model="azure/mistral-large-latest",
+ model="{}/mistral-large-latest".format(provider),
api_base=os.getenv("AZURE_MISTRAL_API_BASE"),
api_key=os.getenv("AZURE_MISTRAL_API_KEY"),
messages=messages,
diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py
index c7df31214..e3407c9e1 100644
--- a/litellm/tests/test_custom_logger.py
+++ b/litellm/tests/test_custom_logger.py
@@ -34,14 +34,15 @@ class MyCustomHandler(CustomLogger):
self.response_cost = 0
def log_pre_api_call(self, model, messages, kwargs):
- print(f"Pre-API Call")
+ print("Pre-API Call")
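+        # print the current call stack (debug aid showing where the pre-call hook is triggered from)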
+ traceback.print_stack()
self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
- print(f"Post-API Call")
+ print("Post-API Call")
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
- print(f"On Stream")
+ print("On Stream")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Success")
@@ -372,6 +373,7 @@ async def test_async_custom_handler_embedding_optional_param():
Tests if the openai optional params for embedding - user + encoding_format,
are logged
"""
+ litellm.set_verbose = True
customHandler_optional_params = MyCustomHandler()
litellm.callbacks = [customHandler_optional_params]
response = await litellm.aembedding(
diff --git a/litellm/utils.py b/litellm/utils.py
index 98461d58b..329f3185d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -938,7 +938,6 @@ class TextCompletionResponse(OpenAIObject):
object=None,
**params,
):
-
if stream:
object = "text_completion.chunk"
choices = [TextChoices()]
@@ -947,7 +946,6 @@ class TextCompletionResponse(OpenAIObject):
if choices is not None and isinstance(choices, list):
new_choices = []
for choice in choices:
-
if isinstance(choice, TextChoices):
_new_choice = choice
elif isinstance(choice, dict):
@@ -1023,7 +1021,6 @@ class ImageObject(OpenAIObject):
revised_prompt: Optional[str] = None
def __init__(self, b64_json=None, url=None, revised_prompt=None):
-
super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
def __contains__(self, key):
@@ -1627,7 +1624,6 @@ class Logging:
end_time=end_time,
)
except Exception as e:
-
complete_streaming_response = None
else:
self.sync_streaming_chunks.append(result)
@@ -2397,7 +2393,6 @@ class Logging:
"async_complete_streaming_response"
in self.model_call_details
):
-
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=self.model_call_details[
@@ -6172,13 +6167,16 @@ def get_api_base(
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
try:
- model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
- get_llm_provider(
- model=model,
- custom_llm_provider=_optional_params.custom_llm_provider,
- api_base=_optional_params.api_base,
- api_key=_optional_params.api_key,
- )
+ (
+ model,
+ custom_llm_provider,
+ dynamic_api_key,
+ dynamic_api_base,
+ ) = get_llm_provider(
+ model=model,
+ custom_llm_provider=_optional_params.custom_llm_provider,
+ api_base=_optional_params.api_base,
+ api_key=_optional_params.api_key,
)
except Exception as e:
verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
@@ -6583,6 +6581,9 @@ def get_llm_provider(
or get_secret("FIREWORKSAI_API_KEY")
or get_secret("FIREWORKS_AI_TOKEN")
)
+ elif custom_llm_provider == "azure_ai":
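+        # azure_ai is openai-compatible (listed in litellm.openai_compatible_providers), so we just resolve the serverless endpoint's base url + key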
+ api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
+ dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
elif custom_llm_provider == "mistral":
# mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
api_base = (
@@ -7454,7 +7455,6 @@ def validate_environment(model: Optional[str] = None) -> dict:
def set_callbacks(callback_list, function_id=None):
-
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
try: