mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
test: refactor testing to handle routing correctly
This commit is contained in:
parent
267084a1af
commit
00b338cb9c
3 changed files with 45 additions and 74 deletions
|
@ -20,7 +20,8 @@ from litellm.caching.redis_cache import RedisCache
|
||||||
|
|
||||||
@pytest.mark.parametrize("namespace", [None, "test"])
|
@pytest.mark.parametrize("namespace", [None, "test"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redis_cache_async_increment(namespace):
|
async def test_redis_cache_async_increment(namespace, monkeypatch):
|
||||||
|
monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
|
||||||
redis_cache = RedisCache(namespace=namespace)
|
redis_cache = RedisCache(namespace=namespace)
|
||||||
# Create an AsyncMock for the Redis client
|
# Create an AsyncMock for the Redis client
|
||||||
mock_redis_instance = AsyncMock()
|
mock_redis_instance = AsyncMock()
|
||||||
|
@ -46,7 +47,8 @@ async def test_redis_cache_async_increment(namespace):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redis_client_init_with_socket_timeout():
|
async def test_redis_client_init_with_socket_timeout(monkeypatch):
|
||||||
|
monkeypatch.setenv("REDIS_HOST", "my-fake-host")
|
||||||
redis_cache = RedisCache(socket_timeout=1.0)
|
redis_cache = RedisCache(socket_timeout=1.0)
|
||||||
assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
|
assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
|
||||||
client = redis_cache.init_async_client()
|
client = redis_cache.init_async_client()
|
||||||
|
|
|
@ -8,84 +8,24 @@ import pytest
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../../../../..")
|
0, os.path.abspath("../../../../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import litellm
|
from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("is_async", [True, False])
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_azure_ai_request_format(is_async):
|
async def test_get_openai_compatible_provider_info():
|
||||||
"""
|
"""
|
||||||
Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
|
Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
|
||||||
for both synchronous and asynchronous calls
|
for both synchronous and asynchronous calls
|
||||||
"""
|
"""
|
||||||
litellm._turn_on_debug()
|
config = AzureAIStudioConfig()
|
||||||
|
|
||||||
# Set up the test parameters
|
api_base, dynamic_api_key, custom_llm_provider = (
|
||||||
api_key = "00xxx"
|
config._get_openai_compatible_provider_info(
|
||||||
api_base = "https://my-endpoint-europe-berri-992.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview"
|
model="azure_ai/gpt-4o-mini",
|
||||||
model = "azure_ai/gpt-4o-mini"
|
api_base="https://my-base",
|
||||||
messages = [
|
api_key="my-key",
|
||||||
{"role": "user", "content": "hi"},
|
|
||||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
|
||||||
{"role": "user", "content": "hi"},
|
|
||||||
]
|
|
||||||
|
|
||||||
if is_async:
|
|
||||||
# Mock AsyncHTTPHandler.post method for async test
|
|
||||||
with patch(
|
|
||||||
"litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post"
|
|
||||||
) as mock_post:
|
|
||||||
# Set up mock response
|
|
||||||
mock_post.return_value = AsyncMock()
|
|
||||||
|
|
||||||
# Call the acompletion function
|
|
||||||
try:
|
|
||||||
await litellm.acompletion(
|
|
||||||
custom_llm_provider="azure_ai",
|
custom_llm_provider="azure_ai",
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
# We expect an exception since we're mocking the response
|
|
||||||
pass
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Mock HTTPHandler.post method for sync test
|
|
||||||
with patch(
|
|
||||||
"litellm.llms.custom_httpx.llm_http_handler.HTTPHandler.post"
|
|
||||||
) as mock_post:
|
|
||||||
# Set up mock response
|
|
||||||
mock_post.return_value = MagicMock()
|
|
||||||
|
|
||||||
# Call the completion function
|
|
||||||
try:
|
|
||||||
litellm.completion(
|
|
||||||
custom_llm_provider="azure_ai",
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
# We expect an exception since we're mocking the response
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Verify the request was made with the correct parameters
|
assert custom_llm_provider == "azure"
|
||||||
mock_post.assert_called_once()
|
|
||||||
call_args = mock_post.call_args
|
|
||||||
print("sync request call=", json.dumps(call_args.kwargs, indent=4, default=str))
|
|
||||||
|
|
||||||
# Check URL
|
|
||||||
assert call_args.kwargs["url"] == api_base
|
|
||||||
|
|
||||||
# Check headers
|
|
||||||
assert call_args.kwargs["headers"]["api-key"] == api_key
|
|
||||||
|
|
||||||
# Check request body
|
|
||||||
request_body = json.loads(call_args.kwargs["data"])
|
|
||||||
assert (
|
|
||||||
request_body["model"] == "gpt-4o-mini"
|
|
||||||
) # Model name should be stripped of provider prefix
|
|
||||||
assert request_body["messages"] == messages
|
|
||||||
|
|
|
@ -265,3 +265,32 @@ class TestAzureAIRerank(BaseLLMRerankTest):
|
||||||
"api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
|
"api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
|
||||||
"api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
|
"api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_azure_ai_request_format():
|
||||||
|
"""
|
||||||
|
Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
|
||||||
|
for both synchronous and asynchronous calls
|
||||||
|
"""
|
||||||
|
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||||
|
|
||||||
|
litellm._turn_on_debug()
|
||||||
|
|
||||||
|
# Set up the test parameters
|
||||||
|
api_key = os.getenv("AZURE_API_KEY")
|
||||||
|
api_base = f"{os.getenv('AZURE_API_BASE')}/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
|
||||||
|
model = "azure_ai/gpt-4o"
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
]
|
||||||
|
|
||||||
|
await litellm.acompletion(
|
||||||
|
custom_llm_provider="azure_ai",
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue