import os
import sys
from unittest.mock import MagicMock, call, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.llms.openai.common_utils import BaseOpenAILLM
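
# BaseOpenAILLM provides the client-cache helpers patched in the test below
# (get_cached_openai_client / set_cached_openai_client), which the test
# expects the OpenAI handlers to consult before constructing a new SDK client.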

# Test parameters for different API functions
API_FUNCTION_PARAMS = [
    # (function_name, is_async, args)
    (
        "completion",
        False,
        {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
        },
    ),
    (
        "completion",
        True,
        {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
        },
    ),
    (
        "completion",
        True,
        {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
            "stream": True,
        },
    ),
    ("embedding", False, {"model": "text-embedding-ada-002", "input": "Hello world"}),
    ("embedding", True, {"model": "text-embedding-ada-002", "input": "Hello world"}),
    (
        "image_generation",
        False,
        {"model": "dall-e-3", "prompt": "A beautiful sunset over mountains"},
    ),
    (
        "image_generation",
        True,
        {"model": "dall-e-3", "prompt": "A beautiful sunset over mountains"},
    ),
    (
        "speech",
        False,
        {
            "model": "tts-1",
            "input": "Hello, this is a test of text to speech",
            "voice": "alloy",
        },
    ),
    (
        "speech",
        True,
        {
            "model": "tts-1",
            "input": "Hello, this is a test of text to speech",
            "voice": "alloy",
        },
    ),
    ("transcription", False, {"model": "whisper-1", "file": MagicMock()}),
    ("transcription", True, {"model": "whisper-1", "file": MagicMock()}),
]
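

# The reuse pattern exercised below looks roughly like this (a sketch of the
# expected control flow, not litellm's literal implementation):
#
#     client = get_cached_openai_client(...)       # cache lookup per request
#     if client is None:                           # miss -> construct once
#         client = OpenAI(...) / AsyncOpenAI(...)
#         set_cached_openai_client(...)            # store for reuse
#
# so across 10 calls we expect 1 construction, 1 cache write, 10 cache reads.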
@pytest.mark.parametrize("function_name,is_async,args", API_FUNCTION_PARAMS)
@pytest.mark.asyncio
async def test_openai_client_reuse(function_name, is_async, args):
    """
    Test that multiple API calls reuse the same (cached) OpenAI client.
    """
    litellm.set_verbose = True

    # Determine which client class to mock based on whether the test is async
    client_path = (
        "litellm.llms.openai.openai.AsyncOpenAI"
        if is_async
        else "litellm.llms.openai.openai.OpenAI"
    )

    # Create the appropriate patches
    with patch(client_path) as mock_client_class, patch.object(
        BaseOpenAILLM, "set_cached_openai_client"
    ) as mock_set_cache, patch.object(
        BaseOpenAILLM, "get_cached_openai_client"
    ) as mock_get_cache:
        # Set up the mock to return None the first time (cache miss), then a
        # client for all subsequent calls
        mock_client = MagicMock()
        mock_get_cache.side_effect = [None] + [
            mock_client
        ] * 9  # First call returns None, the rest return the mock client

        # Make 10 API calls
        for _ in range(10):
            try:
                # Call the appropriate function based on parameters
                if is_async:
                    # Add the 'a' prefix for async functions
                    func = getattr(litellm, f"a{function_name}")
                    await func(**args)
                else:
                    func = getattr(litellm, function_name)
                    func(**args)
            except Exception:
                # We expect exceptions since we're mocking the client
                pass

        # Verify the client was created only once
        assert (
            mock_client_class.call_count == 1
        ), f"{'Async' if is_async else ''}OpenAI client should be created only once"

        # Verify the client was cached
        assert mock_set_cache.call_count == 1, "Client should be cached once"

        # Verify we tried to get from cache 10 times (once per request)
        assert mock_get_cache.call_count == 10, "Should check cache for each request"
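
# To run just this test (pytest-asyncio must be installed for the asyncio
# marker to take effect):
#   pytest -k test_openai_client_reuse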