test: refactor testing to handle routing correctly

Krrish Dholakia 2025-03-18 12:24:12 -07:00
parent 267084a1af
commit 00b338cb9c
3 changed files with 45 additions and 74 deletions

@@ -20,7 +20,8 @@ from litellm.caching.redis_cache import RedisCache
 
 @pytest.mark.parametrize("namespace", [None, "test"])
 @pytest.mark.asyncio
-async def test_redis_cache_async_increment(namespace):
+async def test_redis_cache_async_increment(namespace, monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
     redis_cache = RedisCache(namespace=namespace)
     # Create an AsyncMock for the Redis client
     mock_redis_instance = AsyncMock()
@@ -46,7 +47,8 @@ async def test_redis_cache_async_increment(namespace):
 
 @pytest.mark.asyncio
-async def test_redis_client_init_with_socket_timeout():
+async def test_redis_client_init_with_socket_timeout(monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "my-fake-host")
     redis_cache = RedisCache(socket_timeout=1.0)
     assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
     client = redis_cache.init_async_client()
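Both Redis tests now take pytest's built-in monkeypatch fixture and pin REDIS_HOST before constructing the cache, presumably so RedisCache() can resolve a host without a real Redis server or ambient CI environment variables leaking into the test. Assembled from the hunk above, the start of the patched socket-timeout test reads:

import pytest

from litellm.caching.redis_cache import RedisCache


@pytest.mark.asyncio
async def test_redis_client_init_with_socket_timeout(monkeypatch):
    # setenv is scoped to this test; pytest restores the environment afterwards.
    monkeypatch.setenv("REDIS_HOST", "my-fake-host")
    redis_cache = RedisCache(socket_timeout=1.0)
    assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
    client = redis_cache.init_async_client()

The assumption here is that RedisCache falls back to the REDIS_HOST environment variable when no host argument is given; the fake value never has to accept a connection, because the client is only constructed, not used.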

@@ -8,84 +8,24 @@ import pytest
 sys.path.insert(
     0, os.path.abspath("../../../../..")
 )  # Adds the parent directory to the system path
-import litellm
+from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig
 
 
-@pytest.mark.parametrize("is_async", [True, False])
 @pytest.mark.asyncio
-async def test_azure_ai_request_format(is_async):
+async def test_get_openai_compatible_provider_info():
     """
     Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
     for both synchronous and asynchronous calls
     """
-    litellm._turn_on_debug()
-
-    # Set up the test parameters
-    api_key = "00xxx"
-    api_base = "https://my-endpoint-europe-berri-992.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview"
-    model = "azure_ai/gpt-4o-mini"
-    messages = [
-        {"role": "user", "content": "hi"},
-        {"role": "assistant", "content": "Hello! How can I assist you today?"},
-        {"role": "user", "content": "hi"},
-    ]
-
-    if is_async:
-        # Mock AsyncHTTPHandler.post method for async test
-        with patch(
-            "litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post"
-        ) as mock_post:
-            # Set up mock response
-            mock_post.return_value = AsyncMock()
-
-            # Call the acompletion function
-            try:
-                await litellm.acompletion(
-                    custom_llm_provider="azure_ai",
-                    api_key=api_key,
-                    api_base=api_base,
-                    model=model,
-                    messages=messages,
-                )
-            except Exception as e:
-                # We expect an exception since we're mocking the response
-                pass
-    else:
-        # Mock HTTPHandler.post method for sync test
-        with patch(
-            "litellm.llms.custom_httpx.llm_http_handler.HTTPHandler.post"
-        ) as mock_post:
-            # Set up mock response
-            mock_post.return_value = MagicMock()
-
-            # Call the completion function
-            try:
-                litellm.completion(
-                    custom_llm_provider="azure_ai",
-                    api_key=api_key,
-                    api_base=api_base,
-                    model=model,
-                    messages=messages,
-                )
-            except Exception as e:
-                # We expect an exception since we're mocking the response
-                pass
-
-        # Verify the request was made with the correct parameters
-        mock_post.assert_called_once()
-        call_args = mock_post.call_args
-
-        print("sync request call=", json.dumps(call_args.kwargs, indent=4, default=str))
-
-        # Check URL
-        assert call_args.kwargs["url"] == api_base
-
-        # Check headers
-        assert call_args.kwargs["headers"]["api-key"] == api_key
-
-        # Check request body
-        request_body = json.loads(call_args.kwargs["data"])
-        assert (
-            request_body["model"] == "gpt-4o-mini"
-        )  # Model name should be stripped of provider prefix
-        assert request_body["messages"] == messages
+    config = AzureAIStudioConfig()
+    api_base, dynamic_api_key, custom_llm_provider = (
+        config._get_openai_compatible_provider_info(
+            model="azure_ai/gpt-4o-mini",
+            api_base="https://my-base",
+            api_key="my-key",
+            custom_llm_provider="azure_ai",
+        )
+    )
+    assert custom_llm_provider == "azure"
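This file swaps the HTTP-mocking request-format test for a direct unit test of the routing decision itself. Assembled from the hunk above (names and signature taken from the diff; only the layout is mine), the new check reduces to:

from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig

# Resolve which provider an "azure_ai/..." call should actually route through.
config = AzureAIStudioConfig()
api_base, dynamic_api_key, custom_llm_provider = (
    config._get_openai_compatible_provider_info(
        model="azure_ai/gpt-4o-mini",
        api_base="https://my-base",
        api_key="my-key",
        custom_llm_provider="azure_ai",
    )
)

# The commit pins down that this input is re-routed to the plain "azure"
# provider; api_base and dynamic_api_key are returned but not asserted on.
assert custom_llm_provider == "azure"

Because _get_openai_compatible_provider_info is exercised directly, none of the HTTPHandler patching or expected-exception plumbing from the old test is needed.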

@@ -265,3 +265,32 @@ class TestAzureAIRerank(BaseLLMRerankTest):
             "api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
             "api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
         }
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_request_format():
+    """
+    Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
+    for both synchronous and asynchronous calls
+    """
+    from openai import AsyncAzureOpenAI, AzureOpenAI
+
+    litellm._turn_on_debug()
+
+    # Set up the test parameters
+    api_key = os.getenv("AZURE_API_KEY")
+    api_base = f"{os.getenv('AZURE_API_BASE')}/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
+    model = "azure_ai/gpt-4o"
+    messages = [
+        {"role": "user", "content": "hi"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
+        {"role": "user", "content": "hi"},
+    ]
+
+    await litellm.acompletion(
+        custom_llm_provider="azure_ai",
+        api_key=api_key,
+        api_base=api_base,
+        model=model,
+        messages=messages,
+    )
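The request-format test itself moves here as a live integration test: it builds api_base from AZURE_API_BASE and drives litellm.acompletion against a real deployment instead of mocked HTTP handlers. One hypothetical hardening step, not part of this commit, would be to skip the test when those credentials are absent so local runs without secrets stay green:

import os

import pytest

# Hypothetical guard, not in the commit: skip the live Azure test when
# AZURE_API_KEY / AZURE_API_BASE are not configured in the environment.
requires_azure_credentials = pytest.mark.skipif(
    not (os.getenv("AZURE_API_KEY") and os.getenv("AZURE_API_BASE")),
    reason="Azure credentials not configured",
)

Applied as @requires_azure_credentials above the test, this keeps the suite passing in environments without Azure access while still running the real call in CI.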