From ce7fade15e6af7517dc3dfc4b287de13d7ce1843 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Tue, 2 Jul 2024 12:35:15 -0700
Subject: [PATCH] test whisper re-using openai/azure clients

---
 tests/test_whisper.py | 44 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/tests/test_whisper.py b/tests/test_whisper.py
index 1debbbc1d..09819f796 100644
--- a/tests/test_whisper.py
+++ b/tests/test_whisper.py
@@ -8,6 +8,9 @@ from openai import AsyncOpenAI
 import sys, os, dotenv
 from typing import Optional
 from dotenv import load_dotenv
+from litellm.integrations.custom_logger import CustomLogger
+import litellm
+import logging
 
 # Get the current directory of the file being run
 pwd = os.path.dirname(os.path.realpath(__file__))
@@ -84,9 +87,32 @@ async def test_transcription_async_openai():
     assert isinstance(transcript.text, str)
 
 
+# This file includes the custom callbacks for LiteLLM Proxy
+# Once defined, these can be passed in proxy_config.yaml
+class MyCustomHandler(CustomLogger):
+    def __init__(self):
+        self.openai_client = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            # init logging config
+            print("logging a transcript kwargs: ", kwargs)
+            print("openai client=", kwargs.get("client"))
+            self.openai_client = kwargs.get("client")
+
+        except:
+            pass
+
+
+proxy_handler_instance = MyCustomHandler()
+
+
+# Set litellm.callbacks = [proxy_handler_instance] on the proxy
+# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy
 @pytest.mark.asyncio
 async def test_transcription_on_router():
     litellm.set_verbose = True
+    litellm.callbacks = [proxy_handler_instance]
     print("\n Testing async transcription on router\n")
     try:
         model_list = [
@@ -108,11 +134,29 @@
         ]
         router = Router(model_list=model_list)
+
+        router_level_clients = []
+        for deployment in router.model_list:
+            _deployment_openai_client = router._get_client(
+                deployment=deployment,
+                kwargs={"model": "whisper-1"},
+                client_type="async",
+            )
+
+            router_level_clients.append(str(_deployment_openai_client))
+
         response = await router.atranscription(
             model="whisper",
             file=audio_file,
         )
         print(response)
+
+        # PROD Test
+        # Ensure we ONLY use OpenAI/Azure client initialized on the router level
+        await asyncio.sleep(5)
+        print("OpenAI Client used= ", proxy_handler_instance.openai_client)
+        print("all router level clients= ", router_level_clients)
+        assert proxy_handler_instance.openai_client in router_level_clients
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
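
The pattern this test exercises can also be reproduced outside pytest. Below is a minimal standalone sketch: it registers a CustomLogger, records the client litellm passes to the success callback, and asserts that client is one of the clients the Router initialized up front. The names ClientCaptureLogger and main, the OPENAI_API_KEY environment variable, and the local gettysburg.wav path are illustrative assumptions, not part of the patch; router._get_client is the same internal helper the test itself calls, so its signature here mirrors the test rather than any documented API.

    import asyncio
    import os

    import litellm
    from litellm import Router
    from litellm.integrations.custom_logger import CustomLogger


    class ClientCaptureLogger(CustomLogger):
        # Records the OpenAI/Azure client litellm reports for a completed call.
        def __init__(self):
            self.client_used = None

        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            self.client_used = kwargs.get("client")


    async def main():
        capture = ClientCaptureLogger()
        litellm.callbacks = [capture]

        router = Router(
            model_list=[
                {
                    "model_name": "whisper",
                    "litellm_params": {
                        "model": "whisper-1",
                        "api_key": os.environ["OPENAI_API_KEY"],  # assumed to be set
                    },
                }
            ]
        )

        # Snapshot the async clients the router initialized per deployment,
        # via the same internal helper the test uses (router._get_client).
        router_clients = [
            str(
                router._get_client(
                    deployment=deployment,
                    kwargs={"model": "whisper-1"},
                    client_type="async",
                )
            )
            for deployment in router.model_list
        ]

        with open("gettysburg.wav", "rb") as audio_file:  # hypothetical local file
            await router.atranscription(model="whisper", file=audio_file)

        await asyncio.sleep(5)  # let the async success callback fire, as the test does
        # The call must have reused a router-level client, not created a new one.
        assert str(capture.client_used) in router_clients


    if __name__ == "__main__":
        asyncio.run(main())

Comparing stringified clients works here because str() on an OpenAI client includes its memory address, so the check passes only when the callback saw the exact client object the router created, which is the property the patch is asserting.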