diff --git a/litellm/router.py b/litellm/router.py index 10f7058b3e..71339aa36c 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -634,6 +634,28 @@ class Router: raise e async def atranscription(self, file: BinaryIO, model: str, **kwargs): + """ + Example Usage: + + ``` + from litellm import Router + client = Router(model_list = [ + { + "model_name": "whisper", + "litellm_params": { + "model": "whisper-1", + }, + }, + ]) + + audio_file = open("speech.mp3", "rb") + transcript = await client.atranscription( + model="whisper", + file=audio_file + ) + + ``` + """ try: kwargs["model"] = model kwargs["file"] = file diff --git a/tests/test_whisper.py b/tests/test_whisper.py index 9d8f038c29..bb09717325 100644 --- a/tests/test_whisper.py +++ b/tests/test_whisper.py @@ -1,14 +1,20 @@ # What is this? -## Tests `litellm.transcription` endpoint +## Tests `litellm.transcription` endpoint. Outside litellm module b/c of audio file used in testing (it's ~700kb). + import pytest import asyncio, time -import aiohttp +import aiohttp, traceback from openai import AsyncOpenAI import sys, os, dotenv from typing import Optional from dotenv import load_dotenv -audio_file = open("./gettysburg.wav", "rb") +pwd = os.path.dirname(os.path.realpath(__file__)) +print(pwd) + +file_path = os.path.join(pwd, "gettysburg.wav") + +audio_file = open(file_path, "rb") load_dotenv() @@ -16,6 +22,7 @@ sys.path.insert( 0, os.path.abspath("../") ) # Adds the parent directory to the system path import litellm +from litellm import Router def test_transcription(): @@ -71,3 +78,37 @@ async def test_transcription_async_openai(): assert transcript.text is not None assert isinstance(transcript.text, str) + + +@pytest.mark.asyncio +async def test_transcription_on_router(): + litellm.set_verbose = True + print("\n Testing async transcription on router\n") + try: + model_list = [ + { + "model_name": "whisper", + "litellm_params": { + "model": "whisper-1", + }, + }, + { + "model_name": "whisper", + "litellm_params": { + "model": "azure/azure-whisper", + "api_base": os.getenv("AZURE_EUROPE_API_BASE"), + "api_key": os.getenv("AZURE_EUROPE_API_KEY"), + "api_version": os.getenv("2024-02-15-preview"), + }, + }, + ] + + router = Router(model_list=model_list) + response = await router.atranscription( + model="whisper", + file=audio_file, + ) + print(response) + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}")