test(test_whisper.py): add testing for load balancing whisper endpoints on router

This commit is contained in:
Krrish Dholakia 2024-03-08 14:19:37 -08:00
parent ae54b398d2
commit fe125a5131
2 changed files with 66 additions and 3 deletions

View file

@ -634,6 +634,28 @@ class Router:
raise e
async def atranscription(self, file: BinaryIO, model: str, **kwargs):
"""
Example Usage:
```
from litellm import Router
client = Router(model_list = [
{
"model_name": "whisper",
"litellm_params": {
"model": "whisper-1",
},
},
])
audio_file = open("speech.mp3", "rb")
transcript = await client.atranscription(
model="whisper",
file=audio_file
)
```
"""
try:
kwargs["model"] = model
kwargs["file"] = file

View file

@ -1,14 +1,20 @@
# What is this?
## Tests `litellm.transcription` endpoint
## Tests `litellm.transcription` endpoint. Kept outside the litellm module because the audio file used in testing is ~700 KB.
import pytest
import asyncio, time
import aiohttp
import aiohttp, traceback
from openai import AsyncOpenAI
import sys, os, dotenv
from typing import Optional
from dotenv import load_dotenv
# Module-level fixture setup: open the sample audio used as `audio_file` by the tests.
# NOTE(review): this first open() resolves "./gettysburg.wav" against the current
# working directory and its handle is rebound a few lines below; it looks like the
# pre-change line of the diff -- confirm it is not meant to remain.
audio_file = open("./gettysburg.wav", "rb")
# Directory that contains this test file (symlinks resolved by realpath), so the
# audio file is located independently of where the test runner is invoked from.
pwd = os.path.dirname(os.path.realpath(__file__))
print(pwd)
file_path = os.path.join(pwd, "gettysburg.wav")
# Binary handle shared at module scope; it is never explicitly closed here.
audio_file = open(file_path, "rb")
# presumably loads credentials from a local .env file so the os.getenv() lookups
# in the tests resolve -- verify against the dotenv configuration in use.
load_dotenv()
@ -16,6 +22,7 @@ sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path
import litellm
from litellm import Router
def test_transcription():
@ -71,3 +78,37 @@ async def test_transcription_async_openai():
assert transcript.text is not None
assert isinstance(transcript.text, str)
@pytest.mark.asyncio
async def test_transcription_on_router():
    """Exercise `Router.atranscription` against a shared "whisper" model group.

    Two deployments are registered under the same `model_name` ("whisper"):
    the `whisper-1` model and the `azure/azure-whisper` deployment. The router
    is then asked to transcribe the module-level `audio_file` using the group
    name, and the response is printed.

    Any exception raised while building the router or performing the call is
    printed with its traceback and reported as a test failure.
    """
    litellm.set_verbose = True
    print("\n Testing async transcription on router\n")
    try:
        model_list = [
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "whisper-1",
                },
            },
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "azure/azure-whisper",
                    "api_base": os.getenv("AZURE_EUROPE_API_BASE"),
                    "api_key": os.getenv("AZURE_EUROPE_API_KEY"),
                    # The API version is a literal value, not the name of an
                    # environment variable: os.getenv("2024-02-15-preview")
                    # would look up a variable with that name and return None
                    # when it is unset.
                    "api_version": "2024-02-15-preview",
                },
            },
        ]

        router = Router(model_list=model_list)

        response = await router.atranscription(
            model="whisper",
            file=audio_file,
        )
        print(response)
    except Exception as e:
        # Surface the full traceback in the test log before failing.
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")