diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 90846b627..fca950d31 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union, Any
+from typing import Optional, Union, Any, BinaryIO
 import types, time, json, traceback
 import httpx
 from .base import BaseLLM
@@ -9,6 +9,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     convert_to_model_response_object,
     Usage,
+    TranscriptionResponse,
 )
 from typing import Callable, Optional
 import aiohttp, requests
@@ -766,6 +767,105 @@ class OpenAIChatCompletion(BaseLLM):
         else:
             raise OpenAIError(status_code=500, message=str(e))
 
+    def audio_transcriptions(
+        self,
+        model: str,
+        audio_file: BinaryIO,
+        optional_params: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client=None,
+        max_retries=None,
+        logging_obj=None,
+        atranscriptions: bool = False,
+    ):
+        data = {"model": model, "file": audio_file, **optional_params}
+        if max_retries is None:
+            max_retries = 2  # mirror the OpenAI SDK default
+        if atranscriptions:
+            return self.async_audio_transcriptions(
+                audio_file=audio_file,
+                data=data,
+                model_response=model_response,
+                timeout=timeout,
+                api_key=api_key,
+                api_base=api_base,
+                client=client,
+                max_retries=max_retries,
+                logging_obj=logging_obj,
+            )
+        if client is None:
+            openai_client = OpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.client_session,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+        else:
+            openai_client = client
+        response = openai_client.audio.transcriptions.create(
+            **data, timeout=timeout  # type: ignore
+        )
+
+        stringified_response = response.model_dump()
+        ## LOGGING
+        logging_obj.post_call(
+            input=audio_file.name,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=stringified_response,
+        )
+        final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+        return final_response
+
+    async def async_audio_transcriptions(
+        self,
+        audio_file: BinaryIO,
+        data: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client=None,
+        max_retries=None,
+        logging_obj=None,
+    ):
+        response = None
+        try:
+            if max_retries is None:
+                max_retries = 2  # mirror the OpenAI SDK default
+            if client is None:
+                openai_aclient = AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                )
+            else:
+                openai_aclient = client
+            response = await openai_aclient.audio.transcriptions.create(
+                **data, timeout=timeout
+            )  # type: ignore
+            stringified_response = response.model_dump()
+            ## LOGGING
+            logging_obj.post_call(
+                input=audio_file.name,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+        except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=audio_file.name,
+                api_key=api_key,
+                original_response=str(e),
+            )
+            raise e
+
     async def ahealth_check(
         self,
         model: Optional[str],
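A note on the control flow above: audio_transcriptions doubles as the async entrypoint. When atranscriptions is set it does not await anything itself; it returns the coroutine built by async_audio_transcriptions for the caller to await. A minimal, self-contained sketch of that dispatch pattern (the fetch/_afetch names are illustrative only, not part of this diff):

import asyncio


async def _afetch(x: int) -> int:
    # async worker; stands in for the awaited AsyncOpenAI call in the real code
    await asyncio.sleep(0)
    return x * 2


def fetch(x: int, use_async: bool = False):
    # single entrypoint: hand back a coroutine when the async path is
    # requested, otherwise do the work inline. This is the same shape as
    # audio_transcriptions returning self.async_audio_transcriptions(...).
    if use_async:
        return _afetch(x)
    return x * 2


print(fetch(3))                               # sync path -> 6
print(asyncio.run(fetch(3, use_async=True)))  # async path -> 6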
diff --git a/litellm/main.py b/litellm/main.py
index 63649844a..2df9686fe 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -8,7 +8,7 @@
 # Thank you ! We ❤️ you! - Krrish & Ishaan
 
 import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any, Literal, Union
+from typing import Any, Literal, Union, BinaryIO
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
@@ -3043,7 +3043,6 @@ def moderation(
     return response
 
 
-##### Moderation #######################
 @client
 async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
     # only supports open ai for now
@@ -3310,6 +3309,75 @@ def image_generation(
     )
 
 
+##### Transcription #######################
+
+
+async def atranscription(*args, **kwargs):
+    """
+    Calls openai + azure whisper endpoints.
+
+    Allows router to load balance between them
+    """
+    pass
+
+
+@client
+def transcription(
+    model: str,
+    file: BinaryIO,
+    ## OPTIONAL OPENAI PARAMS ##
+    language: Optional[str] = None,
+    prompt: Optional[str] = None,
+    response_format: Optional[
+        Literal["json", "text", "srt", "verbose_json", "vtt"]
+    ] = None,
+    temperature: Optional[float] = None,  # openai defaults this to 0
+    ## LITELLM PARAMS ##
+    user: Optional[str] = None,
+    timeout=600,  # default to 10 minutes
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+    api_version: Optional[str] = None,
+    litellm_logging_obj=None,
+    custom_llm_provider=None,
+    **kwargs,
+):
+    """
+    Calls openai + azure whisper endpoints.
+
+    Allows router to load balance between them
+    """
+    atranscriptions = kwargs.get("atranscriptions", False)
+    litellm_call_id = kwargs.get("litellm_call_id", None)
+    logger_fn = kwargs.get("logger_fn", None)
+    proxy_server_request = kwargs.get("proxy_server_request", None)
+    model_info = kwargs.get("model_info", None)
+    metadata = kwargs.get("metadata", {})
+
+    model_response = litellm.utils.TranscriptionResponse()
+
+    # model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    custom_llm_provider = "openai"
+
+    optional_params = {
+        "language": language,
+        "prompt": prompt,
+        "response_format": response_format,
+        "temperature": temperature,  # openai defaults this to 0
+    }
+    if custom_llm_provider == "openai":
+        return openai_chat_completions.audio_transcriptions(
+            model=model,
+            audio_file=file,
+            optional_params=optional_params,
+            model_response=model_response,
+            atranscriptions=atranscriptions,
+            timeout=timeout,
+            logging_obj=litellm_logging_obj,
+        )
+    return
+
+
 ##### Health Endpoints #######################
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index d56450b23..18c4b0d9a 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -293,6 +293,18 @@
         "output_cost_per_pixel": 0.0,
         "litellm_provider": "openai"
     },
+    "whisper-1": {
+        "mode": "audio_transcription",
+        "input_cost_per_second": 0,
+        "output_cost_per_second": 0.0001,
+        "litellm_provider": "openai"
+    },
+    "azure/whisper-1": {
+        "mode": "audio_transcription",
+        "input_cost_per_second": 0,
+        "output_cost_per_second": 0.0001,
+        "litellm_provider": "azure"
+    },
     "azure/gpt-4-0125-preview": {
         "max_tokens": 128000,
         "max_input_tokens": 128000,
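A quick sanity check on the new pricing entries: with output_cost_per_second set to 0.0001, a 60-second clip costs $0.006, matching OpenAI's published $0.006/minute rate for whisper-1. The helper below is hypothetical (litellm's real cost tracking lives elsewhere, not in this diff) and only illustrates the per-second arithmetic:

def transcription_cost(
    duration_seconds: float,
    input_cost_per_second: float = 0.0,
    output_cost_per_second: float = 0.0001,
) -> float:
    # per-second billing: cost scales linearly with audio duration
    return duration_seconds * (input_cost_per_second + output_cost_per_second)


assert abs(transcription_cost(60) - 0.006) < 1e-9  # one minute of whisper-1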
diff --git a/litellm/utils.py b/litellm/utils.py
index 38836a4bc..330903f5a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10,7 +10,6 @@
 import sys, re, binascii, struct
 import litellm
 import dotenv, json, traceback, threading, base64, ast
-
 import subprocess, os
 from os.path import abspath, join, dirname
 import litellm, openai
@@ -98,7 +97,7 @@ try:
 except Exception as e:
     verbose_logger.debug(f"Exception import enterprise features {str(e)}")
 
-from typing import cast, List, Dict, Union, Optional, Literal, Any
+from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO
 from .caching import Cache
 from concurrent.futures import ThreadPoolExecutor
@@ -790,6 +789,38 @@ class ImageResponse(OpenAIObject):
             return self.dict()
 
 
+class TranscriptionResponse(OpenAIObject):
+    text: Optional[str] = None
+
+    _hidden_params: dict = {}
+
+    def __init__(self, text=None):
+        super().__init__(text=text)
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except Exception:
+            # if using pydantic v1
+            return self.dict()
+
+
 ############################################################
 def print_verbose(print_statement, logger_only: bool = False):
     try:
@@ -815,6 +846,8 @@ class CallTypes(Enum):
     aimage_generation = "aimage_generation"
     moderation = "moderation"
     amoderation = "amoderation"
+    atranscription = "atranscription"
+    transcription = "transcription"
 
 
 # Logging function -> log the exact model details + what's being sent | Non-Blocking
@@ -2271,6 +2304,12 @@ def client(original_function):
                 or call_type == CallTypes.text_completion.value
             ):
                 messages = args[0] if len(args) > 0 else kwargs["prompt"]
+            elif (
+                call_type == CallTypes.atranscription.value
+                or call_type == CallTypes.transcription.value
+            ):
+                _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
+                messages = _file_name.name
             stream = True if "stream" in kwargs and kwargs["stream"] == True else False
             logging_obj = Logging(
                 model=model,
@@ -6135,10 +6174,10 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
 def convert_to_model_response_object(
     response_object: Optional[dict] = None,
     model_response_object: Optional[
-        Union[ModelResponse, EmbeddingResponse, ImageResponse]
+        Union[ModelResponse, EmbeddingResponse, ImageResponse, TranscriptionResponse]
     ] = None,
     response_type: Literal[
-        "completion", "embedding", "image_generation"
+        "completion", "embedding", "image_generation", "audio_transcription"
     ] = "completion",
     stream=False,
     start_time=None,
@@ -6249,6 +6288,19 @@ def convert_to_model_response_object(
             model_response_object.data = response_object["data"]
 
             return model_response_object
+        elif response_type == "audio_transcription" and (
+            model_response_object is None
+            or isinstance(model_response_object, TranscriptionResponse)
+        ):
+            if response_object is None:
+                raise Exception("Error in response object format")
+
+            if model_response_object is None:
+                model_response_object = TranscriptionResponse()
+
+            if "text" in response_object:
+                model_response_object.text = response_object["text"]
+            return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {traceback.format_exc()}")
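The dunder methods give TranscriptionResponse both attribute-style and dict-style access, and the new audio_transcription branch in convert_to_model_response_object only has to copy the text field across. A short round-trip illustration, assuming the imports resolve as in this diff:

from litellm.utils import TranscriptionResponse, convert_to_model_response_object

raw = {"text": "Four score and seven years ago..."}  # shape of response.model_dump()
resp = convert_to_model_response_object(
    response_object=raw,
    model_response_object=TranscriptionResponse(),
    response_type="audio_transcription",
)
print(resp.text)                    # attribute access
print(resp["text"])                 # dict-style access via __getitem__
print("text" in resp)               # True, via __contains__
print(resp.get("segments", "n/a"))  # default via .get() when a field is absent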
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index d56450b23..18c4b0d9a 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -293,6 +293,18 @@
         "output_cost_per_pixel": 0.0,
         "litellm_provider": "openai"
     },
+    "whisper-1": {
+        "mode": "audio_transcription",
+        "input_cost_per_second": 0,
+        "output_cost_per_second": 0.0001,
+        "litellm_provider": "openai"
+    },
+    "azure/whisper-1": {
+        "mode": "audio_transcription",
+        "input_cost_per_second": 0,
+        "output_cost_per_second": 0.0001,
+        "litellm_provider": "azure"
+    },
     "azure/gpt-4-0125-preview": {
         "max_tokens": 128000,
         "max_input_tokens": 128000,
diff --git a/tests/gettysburg.wav b/tests/gettysburg.wav
new file mode 100644
index 000000000..9690f521e
Binary files /dev/null and b/tests/gettysburg.wav differ
diff --git a/tests/test_whisper.py b/tests/test_whisper.py
new file mode 100644
index 000000000..8ee3b428c
--- /dev/null
+++ b/tests/test_whisper.py
@@ -0,0 +1,23 @@
+# What is this?
+## Tests `litellm.transcription` endpoint
+import pytest
+import asyncio, time
+import aiohttp
+from openai import AsyncOpenAI
+import sys, os, dotenv
+from typing import Optional
+from dotenv import load_dotenv
+
+load_dotenv()
+
+sys.path.insert(
+    0, os.path.abspath("../")
+)  # Adds the parent directory to the system path
+import litellm
+
+
+def test_transcription():
+    # open the fixture inside the test so each run gets a fresh file handle
+    with open("./gettysburg.wav", "rb") as audio_file:
+        transcript = litellm.transcription(model="whisper-1", file=audio_file)
+    print(f"transcript: {transcript}")
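The new test only covers the default call. Since transcription in main.py also exposes language, prompt, response_format, and temperature, a follow-up test could exercise them. The sketch below is hypothetical and not part of this diff, but uses only parameters declared in the transcription signature:

def test_transcription_with_params():
    # exercises the optional OpenAI params that transcription() forwards
    # through optional_params to the OpenAI SDK
    with open("./gettysburg.wav", "rb") as audio_file:
        transcript = litellm.transcription(
            model="whisper-1",
            file=audio_file,
            language="en",
            prompt="The Gettysburg Address",
            response_format="verbose_json",
        )
    assert transcript.get("text") is not None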