test(test_whisper.py): add testing for load balancing whisper endpoints on router

This commit is contained in:
Krrish Dholakia 2024-03-08 14:19:37 -08:00
parent 93e9781d37
commit aca37d3bc5
2 changed files with 66 additions and 3 deletions

View file

@ -634,6 +634,28 @@ class Router:
raise e raise e
async def atranscription(self, file: BinaryIO, model: str, **kwargs): async def atranscription(self, file: BinaryIO, model: str, **kwargs):
"""
Example Usage:
```
from litellm import Router
client = Router(model_list = [
{
"model_name": "whisper",
"litellm_params": {
"model": "whisper-1",
},
},
])
audio_file = open("speech.mp3", "rb")
transcript = await client.atranscription(
model="whisper",
file=audio_file
)
```
"""
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["file"] = file kwargs["file"] = file

View file

@ -1,14 +1,20 @@
# What is this? # What is this?
## Tests `litellm.transcription` endpoint ## Tests `litellm.transcription` endpoint. Outside litellm module b/c of audio file used in testing (it's ~700kb).
import pytest import pytest
import asyncio, time import asyncio, time
import aiohttp import aiohttp, traceback
from openai import AsyncOpenAI from openai import AsyncOpenAI
import sys, os, dotenv import sys, os, dotenv
from typing import Optional from typing import Optional
from dotenv import load_dotenv from dotenv import load_dotenv
audio_file = open("./gettysburg.wav", "rb") pwd = os.path.dirname(os.path.realpath(__file__))
print(pwd)
file_path = os.path.join(pwd, "gettysburg.wav")
audio_file = open(file_path, "rb")
load_dotenv() load_dotenv()
@ -16,6 +22,7 @@ sys.path.insert(
0, os.path.abspath("../") 0, os.path.abspath("../")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import litellm import litellm
from litellm import Router
def test_transcription(): def test_transcription():
@ -71,3 +78,37 @@ async def test_transcription_async_openai():
assert transcript.text is not None assert transcript.text is not None
assert isinstance(transcript.text, str) assert isinstance(transcript.text, str)
@pytest.mark.asyncio
async def test_transcription_on_router():
    """Exercise `Router.atranscription` with two deployments behind one alias.

    Both deployments are registered under the same `model_name` ("whisper"):
    one pointing at `whisper-1` and one at an Azure whisper deployment whose
    endpoint and key are read from the environment. The test requests a
    transcription of the module-level `audio_file` through the "whisper"
    alias and fails if the call raises.
    """
    litellm.set_verbose = True
    print("\n Testing async transcription on router\n")
    try:
        model_list = [
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "whisper-1",
                },
            },
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "azure/azure-whisper",
                    "api_base": os.getenv("AZURE_EUROPE_API_BASE"),
                    "api_key": os.getenv("AZURE_EUROPE_API_KEY"),
                    # The API version is a literal value, not the name of an
                    # environment variable: wrapping it in os.getenv() looked
                    # up a variable called "2024-02-15-preview" and produced
                    # None.
                    "api_version": "2024-02-15-preview",
                },
            },
        ]

        router = Router(model_list=model_list)

        response = await router.atranscription(
            model="whisper",
            file=audio_file,
        )
        print(response)
    except Exception as e:
        # Print the full traceback before failing so the underlying provider
        # error is visible in the test output.
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")