Merge pull request #2401 from BerriAI/litellm_transcription_endpoints

feat(main.py): support openai transcription endpoints
Krish Dholakia, 2024-03-08 23:07:48 -08:00, committed by GitHub
commit e245b1c98a
6 changed files with 516 additions and 12 deletions
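
For context, a minimal usage sketch of the two endpoints this commit adds. It assumes OPENAI_API_KEY is set in the environment and that a local gettysburg.wav exists (mirroring the test fixture added below); it is illustration, not part of the diff.

import asyncio
import litellm

def sync_example():
    with open("gettysburg.wav", "rb") as audio_file:
        # transcription() mirrors OpenAI's /v1/audio/transcriptions endpoint
        transcript = litellm.transcription(model="whisper-1", file=audio_file)
        print(transcript.text)  # TranscriptionResponse exposes .text

async def async_example():
    with open("gettysburg.wav", "rb") as audio_file:
        # atranscription() wraps transcription() and awaits it off the event loop
        transcript = await litellm.atranscription(model="whisper-1", file=audio_file)
        print(transcript.text)

if __name__ == "__main__":
    sync_example()
    asyncio.run(async_example())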


@@ -7,13 +7,15 @@ from litellm.utils import (
Message,
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
)
-from typing import Callable, Optional
+from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig
import litellm, json
import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid
class AzureOpenAIError(Exception):
@@ -780,6 +782,142 @@ class AzureChatCompletion(BaseLLM):
else:
raise AzureOpenAIError(status_code=500, message=str(e))
def audio_transcriptions(
self,
model: str,
audio_file: BinaryIO,
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
client=None,
azure_ad_token: Optional[str] = None,
max_retries=None,
logging_obj=None,
atranscriptions: bool = False,
):
data = {"model": model, "file": audio_file, **optional_params}
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if atranscriptions == True:
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
model_response=model_response,
timeout=timeout,
api_key=api_key,
api_base=api_base,
client=client,
azure_client_params=azure_client_params,
max_retries=max_retries,
logging_obj=logging_obj,
)
if client is None:
azure_client = AzureOpenAI(http_client=litellm.client_session, **azure_client_params) # type: ignore
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
input=f"audio_file_{uuid.uuid4()}",
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"atranscription": True,
"complete_input_dict": data,
},
)
response = azure_client.audio.transcriptions.create(
**data, timeout=timeout # type: ignore
)
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
self,
audio_file: BinaryIO,
data: dict,
model_response: TranscriptionResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
azure_client_params=None,
max_retries=None,
logging_obj=None,
):
response = None
try:
if client is None:
async_azure_client = AsyncAzureOpenAI(
**azure_client_params,
http_client=litellm.aclient_session,
)
else:
async_azure_client = client
## LOGGING
logging_obj.pre_call(
input=f"audio_file_{uuid.uuid4()}",
api_key=async_azure_client.api_key,
additional_args={
"headers": {
"Authorization": f"Bearer {async_azure_client.api_key}"
},
"api_base": async_azure_client._base_url._uri_reference,
"atranscription": True,
"complete_input_dict": data,
},
)
response = await async_azure_client.audio.transcriptions.create(
**data, timeout=timeout
) # type: ignore
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
original_response=str(e),
)
raise e
async def ahealth_check(
self,
model: Optional[str],

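Reviewer note: the Azure handler above assembles its client from azure_client_params, preferring api_key and falling back to azure_ad_token. A standalone sketch of that construction, with placeholder endpoint, deployment, token, and file values:

from openai import AzureOpenAI

azure_client_params = {
    "api_version": "2024-02-15-preview",
    "azure_endpoint": "https://my-endpoint.openai.azure.com/",  # placeholder
    "azure_deployment": "azure-whisper",  # the handler passes the model as the deployment
    "max_retries": 2,
    "timeout": 600.0,
}
api_key = None
azure_ad_token = "placeholder-aad-token"

# Mirrors the branch in audio_transcriptions(): api_key wins, AD token is the fallback.
if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    azure_client_params["azure_ad_token"] = azure_ad_token

client = AzureOpenAI(**azure_client_params)
with open("gettysburg.wav", "rb") as f:  # placeholder audio file
    transcript = client.audio.transcriptions.create(model="whisper-1", file=f)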

@@ -1,4 +1,4 @@
-from typing import Optional, Union, Any
+from typing import Optional, Union, Any, BinaryIO
import types, time, json, traceback
import httpx
from .base import BaseLLM
@@ -9,6 +9,7 @@ from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
Usage,
TranscriptionResponse,
)
from typing import Callable, Optional
import aiohttp, requests
@@ -774,6 +775,103 @@ class OpenAIChatCompletion(BaseLLM):
else:
raise OpenAIError(status_code=500, message=str(e))
def audio_transcriptions(
self,
model: str,
audio_file: BinaryIO,
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
atranscriptions: bool = False,
):
data = {"model": model, "file": audio_file, **optional_params}
if atranscriptions == True:
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
model_response=model_response,
timeout=timeout,
api_key=api_key,
api_base=api_base,
client=client,
max_retries=max_retries,
logging_obj=logging_obj,
)
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_client = client
response = openai_client.audio.transcriptions.create(
**data, timeout=timeout # type: ignore
)
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
self,
audio_file: BinaryIO,
data: dict,
model_response: TranscriptionResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
):
response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_aclient = client
response = await openai_aclient.audio.transcriptions.create(
**data, timeout=timeout
) # type: ignore
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
original_response=str(e),
)
raise e
async def ahealth_check(
self,
model: Optional[str],

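Reviewer note: both handlers serialize the SDK response with model_dump() before handing it to convert_to_model_response_object. A small sketch of that round-trip; constructing the SDK's Transcription model directly is done here only for illustration:

from openai.types.audio import Transcription

response = Transcription(text="Four score and seven years ago...")
stringified_response = response.model_dump()  # plain dict, e.g. {"text": "..."}
print(stringified_response["text"])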

@@ -8,7 +8,7 @@
# Thank you ! We ❤️ you! - Krrish & Ishaan
import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any, Literal, Union
+from typing import Any, Literal, Union, BinaryIO
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
@@ -88,6 +88,7 @@ from litellm.utils import (
read_config_args,
Choices,
Message,
TranscriptionResponse,
)
####### ENVIRONMENT VARIABLES ###################
@@ -3048,7 +3049,6 @@ def moderation(
return response
##### Moderation #######################
@client
async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
# only supports open ai for now
@@ -3071,11 +3071,11 @@ async def aimage_generation(*args, **kwargs):
Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.
Parameters:
-- `args` (tuple): Positional arguments to be passed to the `embedding` function.
-- `kwargs` (dict): Keyword arguments to be passed to the `embedding` function.
+- `args` (tuple): Positional arguments to be passed to the `image_generation` function.
+- `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function.
Returns:
-- `response` (Any): The response returned by the `embedding` function.
+- `response` (Any): The response returned by the `image_generation` function.
"""
loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
@@ -3097,7 +3097,7 @@ async def aimage_generation(*args, **kwargs):
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
-init_response, ModelResponse
+init_response, ImageResponse
): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
@@ -3315,6 +3315,142 @@ def image_generation(
)
##### Transcription #######################
async def atranscription(*args, **kwargs):
"""
Calls openai + azure whisper endpoints.
Allows router to load balance between them
"""
loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
### PASS ARGS TO TRANSCRIPTION ###
kwargs["atranscriptions"] = True  # must match the key transcription() reads via kwargs.get("atranscriptions")
custom_llm_provider = None
try:
# Use a partial function to pass your keyword arguments
func = partial(transcription, *args, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=kwargs.get("api_base", None)
)
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
init_response, TranscriptionResponse
): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
)
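Reviewer note: atranscription() reuses litellm's established async dispatch: run the sync entrypoint in an executor, and if the atranscriptions flag made it return a coroutine, await that; a dict or TranscriptionResponse result means a cache hit. A self-contained toy version of the pattern (all names here are illustrative):

import asyncio
from functools import partial

def do_work(x, async_mode=False):
    async def _async_impl():
        return x * 2
    if async_mode:
        # the sync function hands back a coroutine for the async caller to await
        return _async_impl()
    return x * 2

async def ado_work(x):
    loop = asyncio.get_running_loop()
    func = partial(do_work, x, async_mode=True)
    init_response = await loop.run_in_executor(None, func)
    if asyncio.iscoroutine(init_response):
        return await init_response  # async path: handler returned a coroutine
    return init_response  # cached/sync path: already a concrete result

print(asyncio.run(ado_work(21)))  # 42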
@client
def transcription(
model: str,
file: BinaryIO,
## OPTIONAL OPENAI PARAMS ##
language: Optional[str] = None,
prompt: Optional[str] = None,
response_format: Optional[
Literal["json", "text", "srt", "verbose_json", "vtt"]
] = None,
temperature: Optional[float] = None,  # openai defaults this to 0
## LITELLM PARAMS ##
user: Optional[str] = None,
timeout=600, # default to 10 minutes
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
litellm_logging_obj=None,
custom_llm_provider=None,
**kwargs,
):
"""
Calls openai + azure whisper endpoints.
Allows router to load balance between them
"""
atranscriptions = kwargs.get("atranscriptions", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
model_response = litellm.utils.TranscriptionResponse()
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
optional_params = {
"language": language,
"prompt": prompt,
"response_format": response_format,
"temperature": None, # openai defaults this to 0
}
if custom_llm_provider == "azure":
# azure configs
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
)
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
"AZURE_AD_TOKEN"
)
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_API_KEY")
)
response = azure_chat_completions.audio_transcriptions(
model=model,
audio_file=file,
optional_params=optional_params,
model_response=model_response,
atranscriptions=atranscriptions,
timeout=timeout,
logging_obj=litellm_logging_obj,
api_base=api_base,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
)
elif custom_llm_provider == "openai":
response = openai_chat_completions.audio_transcriptions(
model=model,
audio_file=file,
optional_params=optional_params,
model_response=model_response,
atranscriptions=atranscriptions,
timeout=timeout,
logging_obj=litellm_logging_obj,
)
else:
raise ValueError(f"custom_llm_provider={custom_llm_provider} is not supported for transcription")
return response
##### Health Endpoints #######################

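Reviewer note: the optional OpenAI params collected in optional_params above can all be passed through the public API. A hedged example exercising them (the audio file and API key are assumed to exist):

import litellm

with open("gettysburg.wav", "rb") as audio_file:
    transcript = litellm.transcription(
        model="whisper-1",
        file=audio_file,
        language="en",                # ISO-639-1 hint for the decoder
        prompt="Gettysburg Address",  # optional context to bias the transcript
        response_format="verbose_json",
        temperature=0,                # openai defaults this to 0
    )
print(transcript.text)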

@@ -10,7 +10,6 @@
import sys, re, binascii, struct
import litellm
import dotenv, json, traceback, threading, base64, ast
import subprocess, os
from os.path import abspath, join, dirname
import litellm, openai
@@ -98,7 +97,7 @@ try:
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
-from typing import cast, List, Dict, Union, Optional, Literal, Any
+from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor
@@ -790,6 +789,38 @@ class ImageResponse(OpenAIObject):
return self.dict()
class TranscriptionResponse(OpenAIObject):
text: Optional[str] = None
_hidden_params: dict = {}
def __init__(self, text=None):
super().__init__(text=text)
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
def json(self, **kwargs):
try:
return self.model_dump() # noqa
except Exception:
# if using pydantic v1
return self.dict()
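Reviewer note: TranscriptionResponse deliberately supports both attribute and dict-style access, matching litellm's other response objects. A quick demonstration:

from litellm.utils import TranscriptionResponse

resp = TranscriptionResponse(text="hello world")
assert resp.text == "hello world"            # attribute access
assert resp["text"] == "hello world"         # __getitem__
assert "text" in resp                        # __contains__ via hasattr
assert resp.get("missing", "n/a") == "n/a"   # .get() with a default
resp["text"] = "updated"                     # __setitem__
print(resp.json())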
############################################################
def print_verbose(print_statement, logger_only: bool = False):
try:
@@ -815,6 +846,8 @@ class CallTypes(Enum):
aimage_generation = "aimage_generation"
moderation = "moderation"
amoderation = "amoderation"
atranscription = "atranscription"
transcription = "transcription"
# Logging function -> log the exact model details + what's being sent | Non-Blocking
@@ -948,6 +981,7 @@ class Logging:
curl_command = self.model_call_details
# only print verbose if verbose logger is not set
if verbose_logger.level == 0:
# this means verbose logger was not switched on - user is in litellm.set_verbose=True
print_verbose(f"\033[92m{curl_command}\033[0m\n")
@@ -2293,6 +2327,12 @@ def client(original_function):
or call_type == CallTypes.text_completion.value
):
messages = args[0] if len(args) > 0 else kwargs["prompt"]
elif (
call_type == CallTypes.atranscription.value
or call_type == CallTypes.transcription.value
):
_file: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
messages = _file.name
stream = True if "stream" in kwargs and kwargs["stream"] == True else False
logging_obj = Logging(
model=model,
@@ -6264,10 +6304,10 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
def convert_to_model_response_object(
response_object: Optional[dict] = None,
model_response_object: Optional[
-Union[ModelResponse, EmbeddingResponse, ImageResponse]
+Union[ModelResponse, EmbeddingResponse, ImageResponse, TranscriptionResponse]
] = None,
response_type: Literal[
"completion", "embedding", "image_generation"
"completion", "embedding", "image_generation", "audio_transcription"
] = "completion",
stream=False,
start_time=None,
@@ -6378,6 +6418,19 @@ def convert_to_model_response_object(
model_response_object.data = response_object["data"]
return model_response_object
elif response_type == "audio_transcription" and (
model_response_object is None
or isinstance(model_response_object, TranscriptionResponse)
):
if response_object is None:
raise Exception("Error in response object format")
if model_response_object is None:
model_response_object = TranscriptionResponse()
if "text" in response_object:
model_response_object.text = response_object["text"]
return model_response_object
except Exception as e:
raise Exception(f"Invalid response object {traceback.format_exc()}")

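Reviewer note: the new audio_transcription branch of convert_to_model_response_object is easy to exercise in isolation:

from litellm.utils import TranscriptionResponse, convert_to_model_response_object

raw = {"text": "Four score and seven years ago..."}
resp = convert_to_model_response_object(
    response_object=raw,
    model_response_object=TranscriptionResponse(),
    response_type="audio_transcription",
)
assert isinstance(resp, TranscriptionResponse)
assert resp.text == raw["text"]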
tests/gettysburg.wav (new binary file, contents not shown)

tests/test_whisper.py (new file, 79 lines)

@@ -0,0 +1,79 @@
# What is this?
## Tests `litellm.transcription` endpoint
import pytest
import asyncio, time
import aiohttp
from openai import AsyncOpenAI
import sys, os, dotenv
from typing import Optional
from dotenv import load_dotenv
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
print(pwd)
file_path = os.path.join(pwd, "gettysburg.wav")
audio_file = open(file_path, "rb")  # note: module-level handle shared across tests; reopen per test if runs interfere
load_dotenv()
sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path
import litellm
def test_transcription():
transcript = litellm.transcription(
model="whisper-1",
file=audio_file,
)
print(f"transcript: {transcript}")
assert transcript.text is not None
assert isinstance(transcript.text, str)
# test_transcription()
def test_transcription_azure():
litellm.set_verbose = True
transcript = litellm.transcription(
model="azure/azure-whisper",
file=audio_file,
api_base="https://my-endpoint-europe-berri-992.openai.azure.com/",
api_key=os.getenv("AZURE_EUROPE_API_KEY"),
api_version="2024-02-15-preview",
)
assert transcript.text is not None
assert isinstance(transcript.text, str)
# test_transcription_azure()
@pytest.mark.asyncio
async def test_transcription_async_azure():
transcript = await litellm.atranscription(
model="azure/azure-whisper",
file=audio_file,
api_base="https://my-endpoint-europe-berri-992.openai.azure.com/",
api_key=os.getenv("AZURE_EUROPE_API_KEY"),
api_version="2024-02-15-preview",
)
assert transcript.text is not None
assert isinstance(transcript.text, str)
# asyncio.run(test_transcription_async_azure())
@pytest.mark.asyncio
async def test_transcription_async_openai():
transcript = await litellm.atranscription(
model="whisper-1",
file=audio_file,
)
assert transcript.text is not None
assert isinstance(transcript.text, str)
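
Reviewer note: the async tests use pytest.mark.asyncio, so running this file assumes the pytest-asyncio plugin is installed. A programmatic run, equivalent to `pytest tests/test_whisper.py -v` from the repo root:

import pytest

pytest.main(["-v", "tests/test_whisper.py"])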