# litellm-mirror/tests/litellm/llms/openai/test_openai_common_utils.py
import os
import sys
from unittest.mock import MagicMock, call, patch
import pytest
sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.llms.openai.common_utils import BaseOpenAILLM
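
# BaseOpenAILLM provides the get_cached_openai_client / set_cached_openai_client
# cache hooks that the test below patches to observe client reuse.
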
# Test parameters for different API functions
API_FUNCTION_PARAMS = [
    # (function_name, is_async, args)
    (
        "completion",
        False,
        {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
        },
    ),
    (
        "completion",
        True,
        {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
        },
    ),
    (
        "completion",
        True,
        {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
            "stream": True,
        },
    ),
    ("embedding", False, {"model": "text-embedding-ada-002", "input": "Hello world"}),
    ("embedding", True, {"model": "text-embedding-ada-002", "input": "Hello world"}),
    (
        "image_generation",
        False,
        {"model": "dall-e-3", "prompt": "A beautiful sunset over mountains"},
    ),
    (
        "image_generation",
        True,
        {"model": "dall-e-3", "prompt": "A beautiful sunset over mountains"},
    ),
    (
        "speech",
        False,
        {
            "model": "tts-1",
            "input": "Hello, this is a test of text to speech",
            "voice": "alloy",
        },
    ),
    (
        "speech",
        True,
        {
            "model": "tts-1",
            "input": "Hello, this is a test of text to speech",
            "voice": "alloy",
        },
    ),
    ("transcription", False, {"model": "whisper-1", "file": MagicMock()}),
    ("transcription", True, {"model": "whisper-1", "file": MagicMock()}),
]
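

# Illustrative sketch (hypothetical helper, not litellm's actual implementation)
# of the client-reuse pattern this test exercises: clients are cached under a
# key derived from the settings that identify them, so repeated calls with the
# same settings share one instance instead of constructing a new client.
_sketch_client_cache = {}


def _get_or_create_client_sketch(api_key, base_url, is_async):
    """Return the cached stand-in client for these settings, creating it once."""
    cache_key = (api_key, base_url, is_async)
    if cache_key not in _sketch_client_cache:
        # Cache miss: the real code would construct OpenAI(...)/AsyncOpenAI(...) here
        _sketch_client_cache[cache_key] = MagicMock(name=f"client-{cache_key}")
    return _sketch_client_cache[cache_key]
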
@pytest.mark.parametrize("function_name,is_async,args", API_FUNCTION_PARAMS)
@pytest.mark.asyncio
async def test_openai_client_reuse(function_name, is_async, args):
    """
    Test that multiple API calls reuse the same OpenAI client.
    """
    litellm.set_verbose = True

    # Determine which client class to mock based on whether the test is async
    client_path = (
        "litellm.llms.openai.openai.AsyncOpenAI"
        if is_async
        else "litellm.llms.openai.openai.OpenAI"
    )

    # Create the appropriate patches
    with patch(client_path) as mock_client_class, patch.object(
        BaseOpenAILLM, "set_cached_openai_client"
    ) as mock_set_cache, patch.object(
        BaseOpenAILLM, "get_cached_openai_client"
    ) as mock_get_cache:
        # Set up the mock to return None the first time (cache miss), then the
        # mock client for the nine subsequent calls (cache hits)
        mock_client = MagicMock()
        mock_get_cache.side_effect = [None] + [mock_client] * 9

        # Make 10 API calls
        for _ in range(10):
            try:
                # Call the appropriate function based on parameters
                if is_async:
                    # Async variants are exposed with an 'a' prefix (e.g. acompletion)
                    func = getattr(litellm, f"a{function_name}")
                    await func(**args)
                else:
                    func = getattr(litellm, function_name)
                    func(**args)
            except Exception:
                # Exceptions are expected since the client is mocked
                pass

        # Verify the client class was instantiated only once
        assert (
            mock_client_class.call_count == 1
        ), f"{'Async' if is_async else ''}OpenAI client should be created only once"

        # Verify the client was cached exactly once
        assert mock_set_cache.call_count == 1, "Client should be cached once"

        # Verify the cache was checked once per request (10 times)
        assert mock_get_cache.call_count == 10, "Should check cache for each request"
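

# To run just this file (path assumed to mirror the repo layout):
#   pytest tests/litellm/llms/openai/test_openai_common_utils.py -v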