mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
test: refactor testing to handle routing correctly
This commit is contained in:
parent 267084a1af
commit 00b338cb9c
3 changed files with 45 additions and 74 deletions
@@ -20,7 +20,8 @@ from litellm.caching.redis_cache import RedisCache
 @pytest.mark.parametrize("namespace", [None, "test"])
 @pytest.mark.asyncio
-async def test_redis_cache_async_increment(namespace):
+async def test_redis_cache_async_increment(namespace, monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
     redis_cache = RedisCache(namespace=namespace)
     # Create an AsyncMock for the Redis client
     mock_redis_instance = AsyncMock()

@@ -46,7 +47,8 @@ async def test_redis_cache_async_increment(namespace):
 
 @pytest.mark.asyncio
-async def test_redis_client_init_with_socket_timeout():
+async def test_redis_client_init_with_socket_timeout(monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "my-fake-host")
     redis_cache = RedisCache(socket_timeout=1.0)
     assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
     client = redis_cache.init_async_client()
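Both hunks apply the same pattern: RedisCache appears to resolve its connection host from the environment, so the tests now pin REDIS_HOST with pytest's monkeypatch fixture instead of depending on whatever the CI environment exports. A minimal, self-contained sketch of that pattern (FakeCache below is a hypothetical stand-in for RedisCache, not litellm code):

import os


class FakeCache:
    def __init__(self):
        # Hypothetical stand-in: like RedisCache appears to, it resolves
        # its host from the environment at construction time.
        self.host = os.environ["REDIS_HOST"]


def test_fake_cache_uses_env_host(monkeypatch):
    # monkeypatch scopes the env var to this test and restores the original
    # environment afterwards, so the fake host never leaks into other tests.
    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
    cache = FakeCache()
    assert cache.host == "https://my-test-host"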
@@ -8,84 +8,24 @@ import pytest
 sys.path.insert(
     0, os.path.abspath("../../../../..")
 )  # Adds the parent directory to the system path
 import litellm
 from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig
 
 
-@pytest.mark.parametrize("is_async", [True, False])
 @pytest.mark.asyncio
-async def test_azure_ai_request_format(is_async):
+async def test_get_openai_compatible_provider_info():
-    """
-    Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
-    for both synchronous and asynchronous calls
-    """
-    litellm._turn_on_debug()
     config = AzureAIStudioConfig()
 
-    # Set up the test parameters
-    api_key = "00xxx"
-    api_base = "https://my-endpoint-europe-berri-992.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview"
-    model = "azure_ai/gpt-4o-mini"
-    messages = [
-        {"role": "user", "content": "hi"},
-        {"role": "assistant", "content": "Hello! How can I assist you today?"},
-        {"role": "user", "content": "hi"},
-    ]
+    api_base, dynamic_api_key, custom_llm_provider = (
+        config._get_openai_compatible_provider_info(
+            model="azure_ai/gpt-4o-mini",
+            api_base="https://my-base",
+            api_key="my-key",
+            custom_llm_provider="azure_ai",
+        )
+    )
 
-    if is_async:
-        # Mock AsyncHTTPHandler.post method for async test
-        with patch(
-            "litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post"
-        ) as mock_post:
-            # Set up mock response
-            mock_post.return_value = AsyncMock()
-
-            # Call the acompletion function
-            try:
-                await litellm.acompletion(
-                    custom_llm_provider="azure_ai",
-                    api_key=api_key,
-                    api_base=api_base,
-                    model=model,
-                    messages=messages,
-                )
-            except Exception as e:
-                # We expect an exception since we're mocking the response
-                pass
-
-    else:
-        # Mock HTTPHandler.post method for sync test
-        with patch(
-            "litellm.llms.custom_httpx.llm_http_handler.HTTPHandler.post"
-        ) as mock_post:
-            # Set up mock response
-            mock_post.return_value = MagicMock()
-
-            # Call the completion function
-            try:
-                litellm.completion(
-                    custom_llm_provider="azure_ai",
-                    api_key=api_key,
-                    api_base=api_base,
-                    model=model,
-                    messages=messages,
-                )
-            except Exception as e:
-                # We expect an exception since we're mocking the response
-                pass
-
-    # Verify the request was made with the correct parameters
-    mock_post.assert_called_once()
-    call_args = mock_post.call_args
-    print("sync request call=", json.dumps(call_args.kwargs, indent=4, default=str))
-
-    # Check URL
-    assert call_args.kwargs["url"] == api_base
-
-    # Check headers
-    assert call_args.kwargs["headers"]["api-key"] == api_key
-
-    # Check request body
-    request_body = json.loads(call_args.kwargs["data"])
-    assert (
-        request_body["model"] == "gpt-4o-mini"
-    )  # Model name should be stripped of provider prefix
-    assert request_body["messages"] == messages
+    assert custom_llm_provider == "azure"
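The replacement no longer exercises the HTTP layer; it pins down a single routing decision: _get_openai_compatible_provider_info re-routes an azure_ai request for an Azure OpenAI model to the azure provider, even when the api_base is not an Azure OpenAI endpoint. A hedged sketch of that rule (the model set and function below are illustrative, not litellm's actual implementation):

# Illustrative sketch of the routing rule the unit test asserts; the real
# check lives in AzureAIStudioConfig._get_openai_compatible_provider_info.
AZURE_OPENAI_MODELS = {"gpt-4o", "gpt-4o-mini"}  # assumed subset, for illustration


def route_provider(model: str, custom_llm_provider: str) -> str:
    model_name = model.split("/", 1)[-1]  # strip the "azure_ai/" prefix
    if custom_llm_provider == "azure_ai" and model_name in AZURE_OPENAI_MODELS:
        return "azure"
    return custom_llm_provider


assert route_provider("azure_ai/gpt-4o-mini", "azure_ai") == "azure"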
@@ -265,3 +265,32 @@ class TestAzureAIRerank(BaseLLMRerankTest):
             "api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
             "api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
         }
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_request_format():
+    """
+    Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
+    for both synchronous and asynchronous calls
+    """
+    from openai import AsyncAzureOpenAI, AzureOpenAI
+
+    litellm._turn_on_debug()
+
+    # Set up the test parameters
+    api_key = os.getenv("AZURE_API_KEY")
+    api_base = f"{os.getenv('AZURE_API_BASE')}/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
+    model = "azure_ai/gpt-4o"
+    messages = [
+        {"role": "user", "content": "hi"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
+        {"role": "user", "content": "hi"},
+    ]
+
+    await litellm.acompletion(
+        custom_llm_provider="azure_ai",
+        api_key=api_key,
+        api_base=api_base,
+        model=model,
+        messages=messages,
+    )
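The relocated test_azure_ai_request_format is now a live integration test: it reads AZURE_API_KEY and AZURE_API_BASE from the environment and lets acompletion hit a real gpt-4o deployment instead of a mocked HTTP handler. When those variables are unset, the call will presumably fail at request time; a common pytest guard for such tests (not part of this commit) looks like:

import os

import pytest

# Skip live Azure tests when credentials are not configured; this assumes
# the test reads AZURE_API_KEY / AZURE_API_BASE exactly as shown above.
requires_azure_creds = pytest.mark.skipif(
    not (os.getenv("AZURE_API_KEY") and os.getenv("AZURE_API_BASE")),
    reason="AZURE_API_KEY / AZURE_API_BASE not set",
)

# usage: place @requires_azure_creds above the test's @pytest.mark.asyncio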