diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 0f6786902..1d49f9a0e 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -7,13 +7,15 @@ from litellm.utils import (
     Message,
     CustomStreamWrapper,
     convert_to_model_response_object,
+    TranscriptionResponse,
 )
-from typing import Callable, Optional
+from typing import Callable, Optional, BinaryIO
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
 from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
+import uuid


 class AzureOpenAIError(Exception):
@@ -780,6 +782,142 @@ class AzureChatCompletion(BaseLLM):
             else:
                 raise AzureOpenAIError(status_code=500, message=str(e))

+    def audio_transcriptions(
+        self,
+        model: str,
+        audio_file: BinaryIO,
+        optional_params: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        client=None,
+        azure_ad_token: Optional[str] = None,
+        max_retries=None,
+        logging_obj=None,
+        atranscriptions: bool = False,
+    ):
+        data = {"model": model, "file": audio_file, **optional_params}
+
+        # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "max_retries": max_retries,
+            "timeout": timeout,
+        }
+        azure_client_params = select_azure_base_url_or_endpoint(
+            azure_client_params=azure_client_params
+        )
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            azure_client_params["azure_ad_token"] = azure_ad_token
+
+        if atranscriptions == True:
+            return self.async_audio_transcriptions(
+                audio_file=audio_file,
+                data=data,
+                model_response=model_response,
+                timeout=timeout,
+                api_key=api_key,
+                api_base=api_base,
+                client=client,
+                azure_client_params=azure_client_params,
+                max_retries=max_retries,
+                logging_obj=logging_obj,
+            )
+        if client is None:
+            azure_client = AzureOpenAI(http_client=litellm.client_session, **azure_client_params)  # type: ignore
+        else:
+            azure_client = client
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=f"audio_file_{uuid.uuid4()}",
+            api_key=azure_client.api_key,
+            additional_args={
+                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                "api_base": azure_client._base_url._uri_reference,
+                "atranscription": True,
+                "complete_input_dict": data,
+            },
+        )
+
+        response = azure_client.audio.transcriptions.create(
+            **data, timeout=timeout  # type: ignore
+        )
+        stringified_response = response.model_dump()
+        ## LOGGING
+        logging_obj.post_call(
+            input=audio_file.name,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=stringified_response,
+        )
+        final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+        return final_response
+
+    async def async_audio_transcriptions(
+        self,
+        audio_file: BinaryIO,
+        data: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client=None,
+        azure_client_params=None,
+        max_retries=None,
+        logging_obj=None,
+    ):
+        response = None
+        try:
+            if client is None:
+                async_azure_client = AsyncAzureOpenAI(
+                    **azure_client_params,
+                    http_client=litellm.aclient_session,
+                )
+            else:
+                async_azure_client = client
+
+            ## LOGGING
+            logging_obj.pre_call(
input=f"audio_file_{uuid.uuid4()}", + api_key=async_azure_client.api_key, + additional_args={ + "headers": { + "Authorization": f"Bearer {async_azure_client.api_key}" + }, + "api_base": async_azure_client._base_url._uri_reference, + "atranscription": True, + "complete_input_dict": data, + }, + ) + + response = await async_azure_client.audio.transcriptions.create( + **data, timeout=timeout + ) # type: ignore + stringified_response = response.model_dump() + ## LOGGING + logging_obj.post_call( + input=audio_file.name, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation") # type: ignore + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + original_response=str(e), + ) + raise e + async def ahealth_check( self, model: Optional[str], diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index de605edff..4357063e8 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Any +from typing import Optional, Union, Any, BinaryIO import types, time, json, traceback import httpx from .base import BaseLLM @@ -9,6 +9,7 @@ from litellm.utils import ( CustomStreamWrapper, convert_to_model_response_object, Usage, + TranscriptionResponse, ) from typing import Callable, Optional import aiohttp, requests @@ -774,6 +775,103 @@ class OpenAIChatCompletion(BaseLLM): else: raise OpenAIError(status_code=500, message=str(e)) + def audio_transcriptions( + self, + model: str, + audio_file: BinaryIO, + optional_params: dict, + model_response: TranscriptionResponse, + timeout: float, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client=None, + max_retries=None, + logging_obj=None, + atranscriptions: bool = False, + ): + data = {"model": model, "file": audio_file, **optional_params} + if atranscriptions == True: + return self.async_audio_transcriptions( + audio_file=audio_file, + data=data, + model_response=model_response, + timeout=timeout, + api_key=api_key, + api_base=api_base, + client=client, + max_retries=max_retries, + logging_obj=logging_obj, + ) + if client is None: + openai_client = OpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.client_session, + timeout=timeout, + max_retries=max_retries, + ) + else: + openai_client = client + response = openai_client.audio.transcriptions.create( + **data, timeout=timeout # type: ignore + ) + + stringified_response = response.model_dump() + ## LOGGING + logging_obj.post_call( + input=audio_file.name, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore + return final_response + + async def async_audio_transcriptions( + self, + audio_file: BinaryIO, + data: dict, + model_response: TranscriptionResponse, + timeout: float, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client=None, + max_retries=None, + logging_obj=None, + ): + response = None + try: + if client is None: + openai_aclient = AsyncOpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.aclient_session, + timeout=timeout, + max_retries=max_retries, + ) + else: + openai_aclient = 
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index de605edff..4357063e8 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union, Any
+from typing import Optional, Union, Any, BinaryIO
 import types, time, json, traceback
 import httpx
 from .base import BaseLLM
@@ -9,6 +9,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     convert_to_model_response_object,
     Usage,
+    TranscriptionResponse,
 )
 from typing import Callable, Optional
 import aiohttp, requests
@@ -774,6 +775,103 @@ class OpenAIChatCompletion(BaseLLM):
             else:
                 raise OpenAIError(status_code=500, message=str(e))

+    def audio_transcriptions(
+        self,
+        model: str,
+        audio_file: BinaryIO,
+        optional_params: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client=None,
+        max_retries=None,
+        logging_obj=None,
+        atranscriptions: bool = False,
+    ):
+        data = {"model": model, "file": audio_file, **optional_params}
+        if atranscriptions == True:
+            return self.async_audio_transcriptions(
+                audio_file=audio_file,
+                data=data,
+                model_response=model_response,
+                timeout=timeout,
+                api_key=api_key,
+                api_base=api_base,
+                client=client,
+                max_retries=max_retries,
+                logging_obj=logging_obj,
+            )
+        if client is None:
+            openai_client = OpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.client_session,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+        else:
+            openai_client = client
+        response = openai_client.audio.transcriptions.create(
+            **data, timeout=timeout  # type: ignore
+        )
+
+        stringified_response = response.model_dump()
+        ## LOGGING
+        logging_obj.post_call(
+            input=audio_file.name,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=stringified_response,
+        )
+        final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+        return final_response
+
+    async def async_audio_transcriptions(
+        self,
+        audio_file: BinaryIO,
+        data: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client=None,
+        max_retries=None,
+        logging_obj=None,
+    ):
+        response = None
+        try:
+            if client is None:
+                openai_aclient = AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                )
+            else:
+                openai_aclient = client
+            response = await openai_aclient.audio.transcriptions.create(
+                **data, timeout=timeout
+            )  # type: ignore
+            stringified_response = response.model_dump()
+            ## LOGGING
+            logging_obj.post_call(
+                input=audio_file.name,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+        except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=audio_file.name,
+                api_key=api_key,
+                original_response=str(e),
+            )
+            raise e
+
     async def ahealth_check(
         self,
         model: Optional[str],
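Like the Azure handler, the OpenAI path wraps the stock `openai` SDK transcription call; what litellm adds is client reuse, logging callbacks, and conversion into `TranscriptionResponse`. A minimal sketch of the underlying call (assumes `OPENAI_API_KEY` is set; the file name mirrors the test fixture):

```python
# Sketch: the raw SDK call that OpenAIChatCompletion.audio_transcriptions wraps.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

with open("gettysburg.wav", "rb") as audio_file:
    response = client.audio.transcriptions.create(
        model="whisper-1", file=audio_file
    )

print(response.text)  # raw transcription text
```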
diff --git a/litellm/main.py b/litellm/main.py
index c8a8dbc1c..0447370fc 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -8,7 +8,7 @@
 #  Thank you ! We ❤️ you! - Krrish & Ishaan

 import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any, Literal, Union
+from typing import Any, Literal, Union, BinaryIO
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
@@ -88,6 +88,7 @@ from litellm.utils import (
     read_config_args,
     Choices,
     Message,
+    TranscriptionResponse,
 )

 ####### ENVIRONMENT VARIABLES ###################
@@ -3048,7 +3049,6 @@ def moderation(
     return response


-##### Moderation #######################
 @client
 async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
     # only supports open ai for now
@@ -3071,11 +3071,11 @@ async def aimage_generation(*args, **kwargs):
     Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.

     Parameters:
-    - `args` (tuple): Positional arguments to be passed to the `embedding` function.
-    - `kwargs` (dict): Keyword arguments to be passed to the `embedding` function.
+    - `args` (tuple): Positional arguments to be passed to the `image_generation` function.
+    - `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function.

     Returns:
-    - `response` (Any): The response returned by the `embedding` function.
+    - `response` (Any): The response returned by the `image_generation` function.
     """
     loop = asyncio.get_event_loop()
     model = args[0] if len(args) > 0 else kwargs["model"]
@@ -3097,7 +3097,7 @@ async def aimage_generation(*args, **kwargs):
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
         if isinstance(init_response, dict) or isinstance(
-            init_response, ModelResponse
+            init_response, ImageResponse
         ):  ## CACHING SCENARIO
             response = init_response
         elif asyncio.iscoroutine(init_response):
@@ -3315,6 +3315,142 @@ def image_generation(
     )


+##### Transcription #######################
+
+
+async def atranscription(*args, **kwargs):
+    """
+    Calls openai + azure whisper endpoints.
+
+    Allows router to load balance between them
+    """
+    loop = asyncio.get_event_loop()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    ### PASS ARGS TO TRANSCRIPTION ###
+    kwargs["atranscriptions"] = True
+    custom_llm_provider = None
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(transcription, *args, **kwargs)
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, api_base=kwargs.get("api_base", None)
+        )
+
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if isinstance(init_response, dict) or isinstance(
+            init_response, TranscriptionResponse
+        ):  ## CACHING SCENARIO
+            response = init_response
+        elif asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        return response
+    except Exception as e:
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs=args,
+        )
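`atranscription` reuses the async-wrapper pattern litellm applies to `acompletion` and `aimage_generation`: flag the call as async via kwargs, run the sync entrypoint in an executor so its blocking setup stays off the event loop, then await the coroutine the provider handler returns. A self-contained sketch of just that pattern (all names here are illustrative, not from the PR):

```python
import asyncio, contextvars
from functools import partial


def do_work(**kwargs):
    # Returns an awaitable when async mode is flagged, a plain value otherwise --
    # the same contract transcription() honors via its atranscriptions kwarg.
    if kwargs.get("async_mode", False):
        async def _inner():
            return "async result"

        return _inner()
    return "sync result"


async def ado_work(**kwargs):
    kwargs["async_mode"] = True
    loop = asyncio.get_event_loop()
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, partial(do_work, **kwargs))
    # Blocking setup (client init, param mapping) runs in the executor...
    init_response = await loop.run_in_executor(None, func_with_context)
    # ...and the actual network I/O is awaited on the running event loop.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response


print(asyncio.run(ado_work()))  # -> async result
```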
+
+
+@client
+def transcription(
+    model: str,
+    file: BinaryIO,
+    ## OPTIONAL OPENAI PARAMS ##
+    language: Optional[str] = None,
+    prompt: Optional[str] = None,
+    response_format: Optional[
+        Literal["json", "text", "srt", "verbose_json", "vtt"]
+    ] = None,
+    temperature: Optional[int] = None,  # openai defaults this to 0
+    ## LITELLM PARAMS ##
+    user: Optional[str] = None,
+    timeout=600,  # default to 10 minutes
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+    api_version: Optional[str] = None,
+    litellm_logging_obj=None,
+    custom_llm_provider=None,
+    **kwargs,
+):
+    """
+    Calls openai + azure whisper endpoints.
+
+    Allows router to load balance between them
+    """
+    atranscriptions = kwargs.get("atranscriptions", False)
+    litellm_call_id = kwargs.get("litellm_call_id", None)
+    logger_fn = kwargs.get("logger_fn", None)
+    proxy_server_request = kwargs.get("proxy_server_request", None)
+    model_info = kwargs.get("model_info", None)
+    metadata = kwargs.get("metadata", {})
+    max_retries = kwargs.get("max_retries", openai.DEFAULT_MAX_RETRIES)
+
+    model_response = litellm.utils.TranscriptionResponse()
+
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+
+    optional_params = {
+        "language": language,
+        "prompt": prompt,
+        "response_format": response_format,
+        "temperature": temperature,  # openai defaults this to 0
+    }
+
+    if custom_llm_provider == "azure":
+        # azure configs
+        api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+
+        api_version = (
+            api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
+        )
+
+        azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
+            "AZURE_AD_TOKEN"
+        )
+
+        api_key = (
+            api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret("AZURE_API_KEY")
+        )
+        response = azure_chat_completions.audio_transcriptions(
+            model=model,
+            audio_file=file,
+            optional_params=optional_params,
+            model_response=model_response,
+            atranscriptions=atranscriptions,
+            timeout=timeout,
+            logging_obj=litellm_logging_obj,
+            api_base=api_base,
+            api_key=api_key,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            max_retries=max_retries,
+        )
+    elif custom_llm_provider == "openai":
+        response = openai_chat_completions.audio_transcriptions(
+            model=model,
+            audio_file=file,
+            optional_params=optional_params,
+            model_response=model_response,
+            atranscriptions=atranscriptions,
+            timeout=timeout,
+            logging_obj=litellm_logging_obj,
+            max_retries=max_retries,
+        )
+    return response
+
+
 ##### Health Endpoints #######################
diff --git a/litellm/utils.py b/litellm/utils.py
index 44921f72e..263730313 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10,7 +10,6 @@ import sys, re, binascii, struct
 import litellm
 import dotenv, json, traceback, threading, base64, ast
-
 import subprocess, os
 from os.path import abspath, join, dirname
 import litellm, openai
@@ -98,7 +97,7 @@ try:
 except Exception as e:
     verbose_logger.debug(f"Exception import enterprise features {str(e)}")

-from typing import cast, List, Dict, Union, Optional, Literal, Any
+from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO
 from .caching import Cache
 from concurrent.futures import ThreadPoolExecutor
@@ -790,6 +789,38 @@ class ImageResponse(OpenAIObject):
             return self.dict()


+class TranscriptionResponse(OpenAIObject):
+    text: Optional[str] = None
+
+    _hidden_params: dict = {}
+
+    def __init__(self, text=None):
+        super().__init__(text=text)
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
 ############################################################
 def print_verbose(print_statement, logger_only: bool = False):
     try:
@@ -815,6 +846,8 @@ class CallTypes(Enum):
     aimage_generation = "aimage_generation"
     moderation = "moderation"
     amoderation = "amoderation"
+    atranscription = "atranscription"
+    transcription = "transcription"


 # Logging function -> log the exact model details + what's being sent | Non-Blocking
@@ -948,6 +981,7 @@ class Logging:
                 curl_command = self.model_call_details

             # only print verbose if verbose logger is not set
+
             if verbose_logger.level == 0:
                 # this means verbose logger was not switched on - user is in litellm.set_verbose=True
                 print_verbose(f"\033[92m{curl_command}\033[0m\n")
@@ -2293,6 +2327,12 @@ def client(original_function):
                 or call_type == CallTypes.text_completion.value
             ):
                 messages = args[0] if len(args) > 0 else kwargs["prompt"]
+            elif (
+                call_type == CallTypes.atranscription.value
+                or call_type == CallTypes.transcription.value
+            ):
+                _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
+                messages = _file_name.name
             stream = True if "stream" in kwargs and kwargs["stream"] == True else False
             logging_obj = Logging(
                 model=model,
@@ -6264,10 +6304,10 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
 def convert_to_model_response_object(
     response_object: Optional[dict] = None,
     model_response_object: Optional[
-        Union[ModelResponse, EmbeddingResponse, ImageResponse]
+        Union[ModelResponse, EmbeddingResponse, ImageResponse, TranscriptionResponse]
     ] = None,
     response_type: Literal[
-        "completion", "embedding", "image_generation"
+        "completion", "embedding", "image_generation", "audio_transcription"
     ] = "completion",
     stream=False,
     start_time=None,
@@ -6378,6 +6418,19 @@ def convert_to_model_response_object(
             model_response_object.data = response_object["data"]
             return model_response_object
+        elif response_type == "audio_transcription" and (
+            model_response_object is None
+            or isinstance(model_response_object, TranscriptionResponse)
+        ):
+            if response_object is None:
+                raise Exception("Error in response object format")
+
+            if model_response_object is None:
+                model_response_object = TranscriptionResponse()
+
+            if "text" in response_object:
+                model_response_object.text = response_object["text"]
+            return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {traceback.format_exc()}")
diff --git a/tests/gettysburg.wav b/tests/gettysburg.wav
new file mode 100644
index 000000000..9690f521e
Binary files /dev/null and b/tests/gettysburg.wav differ
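For reference, the dunder methods on the new `TranscriptionResponse` give it dict-like ergonomics on top of the pydantic model. A standalone illustration based on the class definition above (not part of the patch):

```python
from litellm.utils import TranscriptionResponse

resp = TranscriptionResponse(text="Four score and seven years ago...")

# Attribute, dict-style, and .get() access all resolve to the same field
assert resp.text == resp["text"] == resp.get("text")
assert "text" in resp  # __contains__ checks attributes
resp["text"] = "updated"  # __setitem__ assigns attributes
print(resp.json())  # model_dump() on pydantic v2, dict() on v1
```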
diff --git a/tests/test_whisper.py b/tests/test_whisper.py
new file mode 100644
index 000000000..dfeebb161
--- /dev/null
+++ b/tests/test_whisper.py
@@ -0,0 +1,79 @@
+# What is this?
+## Tests `litellm.transcription` endpoint
+import pytest
+import asyncio, time
+import aiohttp
+from openai import AsyncOpenAI
+import sys, os, dotenv
+from typing import Optional
+from dotenv import load_dotenv
+
+# Get the current directory of the file being run
+pwd = os.path.dirname(os.path.realpath(__file__))
+print(pwd)
+
+file_path = os.path.join(pwd, "gettysburg.wav")
+audio_file = open(file_path, "rb")
+
+load_dotenv()
+
+sys.path.insert(
+    0, os.path.abspath("../")
+)  # Adds the parent directory to the system path
+import litellm
+
+
+def test_transcription():
+    transcript = litellm.transcription(
+        model="whisper-1",
+        file=audio_file,
+    )
+    print(f"transcript: {transcript}")
+
+
+# test_transcription()
+
+
+def test_transcription_azure():
+    litellm.set_verbose = True
+    transcript = litellm.transcription(
+        model="azure/azure-whisper",
+        file=audio_file,
+        api_base="https://my-endpoint-europe-berri-992.openai.azure.com/",
+        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
+        api_version="2024-02-15-preview",
+    )
+
+    assert transcript.text is not None
+    assert isinstance(transcript.text, str)
+
+
+# test_transcription_azure()
+
+
+@pytest.mark.asyncio
+async def test_transcription_async_azure():
+    transcript = await litellm.atranscription(
+        model="azure/azure-whisper",
+        file=audio_file,
+        api_base="https://my-endpoint-europe-berri-992.openai.azure.com/",
+        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
+        api_version="2024-02-15-preview",
+    )
+
+    assert transcript.text is not None
+    assert isinstance(transcript.text, str)
+
+
+# asyncio.run(test_transcription_async_azure())
+
+
+@pytest.mark.asyncio
+async def test_transcription_async_openai():
+    transcript = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file,
+    )
+
+    assert transcript.text is not None
+    assert isinstance(transcript.text, str)
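Putting it together, a sketch of end-to-end usage of the new endpoint with the optional Whisper parameters that `transcription()` accepts (the file path mirrors the test fixture; for Azure you would instead pass `model="azure/<deployment-name>"` plus `api_base`/`api_key`/`api_version` as the tests do):

```python
import litellm

with open("gettysburg.wav", "rb") as audio_file:
    transcript = litellm.transcription(
        model="whisper-1",  # OpenAI Whisper
        file=audio_file,
        language="en",  # optional ISO-639-1 hint
        prompt="Gettysburg Address",  # optional decoding hint
        response_format="verbose_json",  # json | text | srt | verbose_json | vtt
        temperature=0,  # OpenAI's default
    )

print(transcript.text)  # TranscriptionResponse exposes the transcribed text
```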