litellm-mirror/tests/llm_translation/test_azure_openai.py
Krish Dholakia 0c0498dd60
Litellm dev 12 07 2024 (#7086)
* fix(main.py): support passing max retries to azure/openai embedding integrations

Fixes https://github.com/BerriAI/litellm/issues/7003
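
A minimal sketch of the fixed call path (credentials and endpoint below are placeholders):

    import litellm

    response = litellm.embedding(
        model="azure/azure-embedding-model",
        input=["Hello world"],
        api_key="my-azure-api-key",  # placeholder
        api_base="https://my-endpoint.openai.azure.com",  # placeholder
        max_retries=3,  # previously dropped for embedding calls; now forwarded to the client
    )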

* feat(team_endpoints.py): allow updating team model aliases

Closes https://github.com/BerriAI/litellm/issues/6956

* feat(router.py): allow specifying model id as fallback - skips any cooldown check (sketch below)

Allows a default model to be checked if all models are in cooldown

s/o @micahjsmith
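
A hedged sketch of the intended usage; the deployment names and id are illustrative, and this assumes fallbacks may now reference a specific deployment's model id:

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "azure/chatgpt-v-2"},
                "model_info": {"id": "my-default-deployment-id"},
            },
        ],
        # pointing a fallback at a model id skips the cooldown check
        fallbacks=[{"gpt-4": ["my-default-deployment-id"]}],
    )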

* docs(reliability.md): add fallback to specific model to docs

* fix(utils.py): new 'is_prompt_caching_valid_prompt' helper util

Allows the user to identify whether messages/tools are marked for prompt caching

Related issue: https://github.com/BerriAI/litellm/issues/6784
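
Illustrative call only; this assumes the helper is importable from litellm.utils and takes roughly these keyword arguments:

    from litellm.utils import is_prompt_caching_valid_prompt

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "<long reusable context>",
                    "cache_control": {"type": "ephemeral"},  # anthropic-style cache marker
                }
            ],
        }
    ]
    if is_prompt_caching_valid_prompt(
        model="anthropic/claude-3-5-sonnet-20240620", messages=messages, tools=None
    ):
        print("prompt qualifies for prompt caching")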

* feat(router.py): store model id for prompt caching valid prompt

Allows routing to that model id on subsequent requests

* fix(router.py): only cache if the prompt is a valid prompt-caching prompt

Prevents storing unnecessary items in the cache

* feat(router.py): support routing prompt-caching-enabled models to previous deployments

Closes https://github.com/BerriAI/litellm/issues/6784
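
Conceptually (a sketch under assumptions, not the router's actual internals): the deployment id that served a caching-eligible prompt is remembered, and later requests with the same prompt prefix are routed back to it:

    import hashlib
    import json

    deployment_cache: dict = {}  # prompt key -> deployment model id (illustrative)

    def pick_deployment(messages, healthy_deployment_ids):
        # key on the cached prefix (here: the first message) of the prompt
        key = hashlib.sha256(json.dumps(messages[:1]).encode()).hexdigest()
        cached_id = deployment_cache.get(key)
        if cached_id in healthy_deployment_ids:
            return cached_id  # reuse the deployment that holds the prompt cache
        chosen = healthy_deployment_ids[0]
        deployment_cache[key] = chosen
        return chosen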

* test: fix linting errors

* feat(databricks/): convert basemodel to dict and exclude none values

Allows passing a Pydantic message to Databricks
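
The conversion amounts to Pydantic v2's model_dump(exclude_none=True); the Message model below is illustrative:

    from typing import Optional
    from pydantic import BaseModel

    class Message(BaseModel):
        role: str
        content: str
        name: Optional[str] = None

    msg = Message(role="user", content="Hello world")
    payload = msg.model_dump(exclude_none=True)
    assert payload == {"role": "user", "content": "Hello world"}  # 'name' dropped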

* fix(utils.py): ensure all chat completion messages are dict

* (feat) Track `custom_llm_provider` in LiteLLMSpendLogs (#7081) (payload sketch below)

* add custom_llm_provider to SpendLogsPayload

* add custom_llm_provider to SpendLogs

* add custom llm provider to SpendLogs payload

* test_spend_logs_payload
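
Illustrative payload shape only; every field except custom_llm_provider is a placeholder:

    spend_log_payload = {
        "request_id": "chatcmpl-123",    # placeholder
        "model": "gpt-4o",               # placeholder
        "custom_llm_provider": "azure",  # newly tracked field
        "spend": 0.0042,                 # placeholder
    }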

* Add MLflow to the side bar (#7031)

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* (bug fix) SpendLogs DB update: catch all possible DB errors for retrying (#7082) (retry sketch below)

* catch DB_CONNECTION_ERROR_TYPES

* fix DB retry mechanism for SpendLog updates

* use DB_CONNECTION_ERROR_TYPES in auth checks

* fix exp back off for writing SpendLogs

* use _raise_failed_update_spend_exception to ensure errors print as non-blocking

* test_update_spend_logs_multiple_batches_with_failure
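
A generic sketch of the described mechanism (assumptions, not this repo's code): retry on a broad tuple of connection errors with exponential backoff, and surface the final failure as non-blocking:

    import asyncio

    # illustrative stand-in for the repo's DB_CONNECTION_ERROR_TYPES tuple
    DB_CONNECTION_ERROR_TYPES = (ConnectionError, TimeoutError)

    async def update_spend_logs_with_retry(write_batch, max_retries=3, base_delay=1.0):
        for attempt in range(max_retries):
            try:
                return await write_batch()
            except DB_CONNECTION_ERROR_TYPES as e:
                if attempt == max_retries - 1:
                    # non-blocking: log the failure instead of raising
                    print(f"SpendLogs update failed after {max_retries} retries: {e}")
                    return None
                await asyncio.sleep(base_delay * (2**attempt))  # exponential backoff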

* (Feat) Add StructuredOutputs support for Fireworks.AI (#7085) (example below)

* fix model cost map for Fireworks AI: set "supports_response_schema": true

* fix supports_response_schema

* fix map openai params fireworks ai

* test_map_response_format

* test_map_response_format
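
A hedged example of what this enables; the model name is illustrative, and litellm is assumed to map OpenAI-style response_format for Fireworks AI:

    import litellm

    response = litellm.completion(
        model="fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct",
        messages=[{"role": "user", "content": "Name a city and its country."}],
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "city_info",
                "schema": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "country": {"type": "string"},
                    },
                    "required": ["city", "country"],
                },
            },
        },
    )
    print(response.choices[0].message.content)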

* added deepinfra/Meta-Llama-3.1-405B-Instruct (#7084)

* bump: version 1.53.9 → 1.54.0

* fix deepinfra

* litellm db fixes LiteLLM_UserTable (#7089)

* ci/cd queue new release

* fix llama-3.3-70b-versatile

* refactor - use consistent file naming convention `AI21/` -> `ai21`  (#7090)

* fix refactor - use consistent file naming convention

* ci/cd run again

* fix naming structure

* fix use consistent naming (#7092)

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com>
Co-authored-by: ali sayyah <ali.sayyah2@gmail.com>
2024-12-08 00:30:33 -08:00


import sys
import os

sys.path.insert(
    0, os.path.abspath("../../")
)  # Adds the parent directory to the system path
import pytest
from litellm.llms.azure.common_utils import process_azure_headers
from httpx import Headers
from base_embedding_unit_tests import BaseLLMEmbeddingTest


def test_process_azure_headers_empty():
    result = process_azure_headers({})
    assert result == {}, "Expected empty dictionary for no input"


def test_process_azure_headers_with_all_headers():
    input_headers = Headers(
        {
            "x-ratelimit-limit-requests": "100",
            "x-ratelimit-remaining-requests": "90",
            "x-ratelimit-limit-tokens": "10000",
            "x-ratelimit-remaining-tokens": "9000",
            "other-header": "value",
        }
    )

    expected_output = {
        "x-ratelimit-limit-requests": "100",
        "x-ratelimit-remaining-requests": "90",
        "x-ratelimit-limit-tokens": "10000",
        "x-ratelimit-remaining-tokens": "9000",
        "llm_provider-x-ratelimit-limit-requests": "100",
        "llm_provider-x-ratelimit-remaining-requests": "90",
        "llm_provider-x-ratelimit-limit-tokens": "10000",
        "llm_provider-x-ratelimit-remaining-tokens": "9000",
        "llm_provider-other-header": "value",
    }

    result = process_azure_headers(input_headers)
    assert result == expected_output, "Unexpected output for all Azure headers"


def test_process_azure_headers_with_partial_headers():
    input_headers = Headers(
        {
            "x-ratelimit-limit-requests": "100",
            "x-ratelimit-remaining-tokens": "9000",
            "other-header": "value",
        }
    )

    expected_output = {
        "x-ratelimit-limit-requests": "100",
        "x-ratelimit-remaining-tokens": "9000",
        "llm_provider-x-ratelimit-limit-requests": "100",
        "llm_provider-x-ratelimit-remaining-tokens": "9000",
        "llm_provider-other-header": "value",
    }

    result = process_azure_headers(input_headers)
    assert result == expected_output, "Unexpected output for partial Azure headers"


def test_process_azure_headers_with_no_matching_headers():
    input_headers = Headers(
        {"unrelated-header-1": "value1", "unrelated-header-2": "value2"}
    )

    expected_output = {
        "llm_provider-unrelated-header-1": "value1",
        "llm_provider-unrelated-header-2": "value2",
    }

    result = process_azure_headers(input_headers)
    assert result == expected_output, "Unexpected output for non-matching headers"


def test_process_azure_headers_with_dict_input():
    input_headers = {
        "x-ratelimit-limit-requests": "100",
        "x-ratelimit-remaining-requests": "90",
        "other-header": "value",
    }

    expected_output = {
        "x-ratelimit-limit-requests": "100",
        "x-ratelimit-remaining-requests": "90",
        "llm_provider-x-ratelimit-limit-requests": "100",
        "llm_provider-x-ratelimit-remaining-requests": "90",
        "llm_provider-other-header": "value",
    }

    result = process_azure_headers(input_headers)
    assert result == expected_output, "Unexpected output for dict input"


from httpx import Client
from unittest.mock import MagicMock, patch
from openai import AzureOpenAI
import litellm
from litellm import completion
import os


@pytest.mark.parametrize(
    "input, call_type",
    [
        ({"messages": [{"role": "user", "content": "Hello world"}]}, "completion"),
        ({"input": "Hello world"}, "embedding"),
        ({"prompt": "Hello world"}, "image_generation"),
    ],
)
def test_azure_extra_headers(input, call_type):
    from litellm import embedding, image_generation

    http_client = Client()
    messages = [{"role": "user", "content": "Hello world"}]
    with patch.object(http_client, "send", new=MagicMock()) as mock_client:
        litellm.client_session = http_client
        try:
            if call_type == "completion":
                func = completion
            elif call_type == "embedding":
                func = embedding
            elif call_type == "image_generation":
                func = image_generation
            response = func(
                model="azure/chatgpt-v-2",
                api_base="https://openai-gpt-4-test-v-1.openai.azure.com",
                api_version="2023-07-01-preview",
                api_key="my-azure-api-key",
                extra_headers={
                    "Authorization": "my-bad-key",
                    "Ocp-Apim-Subscription-Key": "hello-world-testing",
                },
                **input,
            )
            print(response)
        except Exception as e:
            print(e)

        mock_client.assert_called()
        print(f"mock_client.call_args: {mock_client.call_args}")
        request = mock_client.call_args[0][0]
        print(request.method)  # This will print 'POST'
        print(request.url)  # This will print the full URL
        print(request.headers)  # This will print the request headers

        auth_header = request.headers.get("Authorization")
        apim_key = request.headers.get("Ocp-Apim-Subscription-Key")
        print(auth_header)
        assert auth_header == "my-bad-key"
        assert apim_key == "hello-world-testing"


@pytest.mark.parametrize(
    "api_base, model, expected_endpoint",
    [
        (
            "https://my-endpoint-sweden-berri992.openai.azure.com",
            "dall-e-3-test",
            "https://my-endpoint-sweden-berri992.openai.azure.com/openai/deployments/dall-e-3-test/images/generations?api-version=2023-12-01-preview",
        ),
        (
            "https://my-endpoint-sweden-berri992.openai.azure.com/openai/deployments/my-custom-deployment",
            "dall-e-3",
            "https://my-endpoint-sweden-berri992.openai.azure.com/openai/deployments/my-custom-deployment/images/generations?api-version=2023-12-01-preview",
        ),
    ],
)
def test_process_azure_endpoint_url(api_base, model, expected_endpoint):
    from litellm.llms.azure.azure import AzureChatCompletion

    azure_chat_completion = AzureChatCompletion()
    input_args = {
        "azure_client_params": {
            "api_version": "2023-12-01-preview",
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "max_retries": 2,
            "timeout": 600,
            "api_key": "f28ab7b695af4154bc53498e5bdccb07",
        },
        "model": model,
    }
    result = azure_chat_completion.create_azure_base_url(**input_args)
    assert result == expected_endpoint, "Unexpected endpoint"


class TestAzureEmbedding(BaseLLMEmbeddingTest):
    def get_base_embedding_call_args(self) -> dict:
        return {
            "model": "azure/azure-embedding-model",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
        }

    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        return litellm.LlmProviders.AZURE