feat(main.py): support openai transcription endpoints

enables users to load balance between OpenAI and Azure transcription endpoints
Krrish Dholakia 2024-03-08 10:25:19 -08:00
parent f8f01e5224
commit 696eb54455
7 changed files with 275 additions and 7 deletions


@@ -1,4 +1,4 @@
from typing import Optional, Union, Any
from typing import Optional, Union, Any, BinaryIO
import types, time, json, traceback
import httpx
from .base import BaseLLM
@@ -9,6 +9,7 @@ from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
Usage,
TranscriptionResponse,
)
from typing import Callable, Optional
import aiohttp, requests
@@ -766,6 +767,103 @@ class OpenAIChatCompletion(BaseLLM):
else:
raise OpenAIError(status_code=500, message=str(e))
def audio_transcriptions(
self,
model: str,
audio_file: BinaryIO,
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
atranscriptions: bool = False,
):
data = {"model": model, "file": audio_file, **optional_params}
if atranscriptions is True:
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
model_response=model_response,
timeout=timeout,
api_key=api_key,
api_base=api_base,
client=client,
max_retries=max_retries,
logging_obj=logging_obj,
)
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_client = client
response = openai_client.audio.transcriptions.create(
**data, timeout=timeout # type: ignore
)
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
self,
audio_file: BinaryIO,
data: dict,
model_response: TranscriptionResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
):
response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_aclient = client
response = await openai_aclient.audio.transcriptions.create(
**data, timeout=timeout
) # type: ignore
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
original_response=str(e),
)
raise e
async def ahealth_check(
self,
model: Optional[str],

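For reference, the new handler above is a thin wrapper around the OpenAI Python SDK's transcription call. A minimal sketch of the request it ends up issuing (API key and file path are placeholders, not values from this commit):

```python
from openai import OpenAI

# Placeholder credentials/path for illustration only.
client = OpenAI(api_key="sk-...")

with open("gettysburg.wav", "rb") as audio_file:
    # Same call the handler makes after merging optional_params into `data`.
    response = client.audio.transcriptions.create(model="whisper-1", file=audio_file)

print(response.text)  # plain transcript text
```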

@@ -8,7 +8,7 @@
# Thank you ! We ❤️ you! - Krrish & Ishaan
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union
from typing import Any, Literal, Union, BinaryIO
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
@@ -3043,7 +3043,6 @@ def moderation(
return response
##### Moderation #######################
@client
async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
# only supports open ai for now
@@ -3310,6 +3309,75 @@ def image_generation(
)
##### Transcription #######################
async def atranscription(*args, **kwargs):
"""
Calls openai + azure whisper endpoints.
Allows router to load balance between them
"""
pass
@client
def transcription(
model: str,
file: BinaryIO,
## OPTIONAL OPENAI PARAMS ##
language: Optional[str] = None,
prompt: Optional[str] = None,
response_format: Optional[
Literal["json", "text", "srt", "verbose_json", "vtt"]
] = None,
temperature: Optional[float] = None, # openai defaults this to 0
## LITELLM PARAMS ##
user: Optional[str] = None,
timeout=600, # default to 10 minutes
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
litellm_logging_obj=None,
custom_llm_provider=None,
**kwargs,
):
"""
Calls openai + azure whisper endpoints.
Allows router to load balance between them
"""
atranscriptions = kwargs.get("atranscriptions", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
model_response = litellm.utils.TranscriptionResponse()
# model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
custom_llm_provider = "openai"
optional_params = {
"language": language,
"prompt": prompt,
"response_format": response_format,
"temperature": None, # openai defaults this to 0
}
if custom_llm_provider == "openai":
return openai_chat_completions.audio_transcriptions(
model=model,
audio_file=file,
optional_params=optional_params,
model_response=model_response,
atranscriptions=atranscriptions,
timeout=timeout,
logging_obj=litellm_logging_obj,
)
return
##### Health Endpoints #######################

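With the pieces above in place, the new `litellm.transcription` entry point can be called much like the underlying SDK. A rough usage sketch (file path is a placeholder; only the OpenAI provider path is wired up in this commit, and the async `atranscription` is still a stub):

```python
import litellm

with open("gettysburg.wav", "rb") as audio_file:
    # Dispatches to OpenAIChatCompletion.audio_transcriptions under the hood.
    transcript = litellm.transcription(
        model="whisper-1",
        file=audio_file,
        response_format="json",  # optional OpenAI param forwarded via optional_params
    )

print(transcript.text)
```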

@@ -293,6 +293,18 @@
"output_cost_per_pixel": 0.0,
"litellm_provider": "openai"
},
"whisper-1": {
"mode": "audio_transcription",
"input_cost_per_second": 0,
"output_cost_per_second": 0.0001,
"litellm_provider": "openai"
},
"azure/whisper-1": {
"mode": "audio_transcription",
"input_cost_per_second": 0,
"output_cost_per_second": 0.0001,
"litellm_provider": "azure"
},
"azure/gpt-4-0125-preview": {
"max_tokens": 128000,
"max_input_tokens": 128000,

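The whisper-1 entries above price transcription per second of audio rather than per token. A back-of-the-envelope cost estimate under that assumption (billing only on the output-seconds rate, since the input rate is zero; the clip length is hypothetical):

```python
# Rough cost estimate for a whisper-1 call, using the rates registered above.
output_cost_per_second = 0.0001  # USD per second, from the "whisper-1" entry

audio_duration_seconds = 150  # e.g. a 2.5-minute clip
estimated_cost = audio_duration_seconds * output_cost_per_second
print(f"~${estimated_cost:.4f}")  # ~$0.0150
```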

@@ -10,7 +10,6 @@
import sys, re, binascii, struct
import litellm
import dotenv, json, traceback, threading, base64, ast
import subprocess, os
from os.path import abspath, join, dirname
import litellm, openai
@@ -98,7 +97,7 @@ try:
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
from typing import cast, List, Dict, Union, Optional, Literal, Any
from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor
@@ -790,6 +789,38 @@ class ImageResponse(OpenAIObject):
return self.dict()
class TranscriptionResponse(OpenAIObject):
text: Optional[str] = None
_hidden_params: dict = {}
def __init__(self, text=None):
super().__init__(text=text)
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
def json(self, **kwargs):
try:
return self.model_dump() # noqa
except:
# if using pydantic v1
return self.dict()
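`TranscriptionResponse` mirrors the dict-style helpers on the other response classes, so a result can be read either as attributes or like a mapping. A small illustration:

```python
from litellm.utils import TranscriptionResponse

resp = TranscriptionResponse(text="Four score and seven years ago...")

print(resp.text)              # attribute access
print(resp["text"])           # dict-style access via __getitem__
print(resp.get("usage", {}))  # .get() with a default for fields that aren't set
print("text" in resp)         # True, via __contains__
```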
############################################################
def print_verbose(print_statement, logger_only: bool = False):
try:
@@ -815,6 +846,8 @@ class CallTypes(Enum):
aimage_generation = "aimage_generation"
moderation = "moderation"
amoderation = "amoderation"
atranscription = "atranscription"
transcription = "transcription"
# Logging function -> log the exact model details + what's being sent | Non-Blocking
@@ -2271,6 +2304,12 @@ def client(original_function):
or call_type == CallTypes.text_completion.value
):
messages = args[0] if len(args) > 0 else kwargs["prompt"]
elif (
call_type == CallTypes.atranscription.value
or call_type == CallTypes.transcription.value
):
_file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
messages = _file_name.name
stream = True if "stream" in kwargs and kwargs["stream"] == True else False
logging_obj = Logging(
model=model,
@@ -6135,10 +6174,10 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
def convert_to_model_response_object(
response_object: Optional[dict] = None,
model_response_object: Optional[
Union[ModelResponse, EmbeddingResponse, ImageResponse]
Union[ModelResponse, EmbeddingResponse, ImageResponse, TranscriptionResponse]
] = None,
response_type: Literal[
"completion", "embedding", "image_generation"
"completion", "embedding", "image_generation", "audio_transcription"
] = "completion",
stream=False,
start_time=None,
@@ -6249,6 +6288,19 @@ def convert_to_model_response_object(
model_response_object.data = response_object["data"]
return model_response_object
elif response_type == "audio_transcription" and (
model_response_object is None
or isinstance(model_response_object, TranscriptionResponse)
):
if response_object is None:
raise Exception("Error in response object format")
if model_response_object is None:
model_response_object = TranscriptionResponse()
if "text" in response_object:
model_response_object.text = response_object["text"]
return model_response_object
except Exception as e:
raise Exception(f"Invalid response object {traceback.format_exc()}")

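To make the new `audio_transcription` branch concrete, converting a raw provider payload into a `TranscriptionResponse` looks roughly like this (the payload dict is a made-up example of what `response.model_dump()` returns):

```python
from litellm.utils import TranscriptionResponse, convert_to_model_response_object

raw = {"text": "Four score and seven years ago..."}  # shape of response.model_dump()

resp = convert_to_model_response_object(
    response_object=raw,
    model_response_object=TranscriptionResponse(),
    response_type="audio_transcription",
)
print(resp.text)
```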

@@ -293,6 +293,18 @@
"output_cost_per_pixel": 0.0,
"litellm_provider": "openai"
},
"whisper-1": {
"mode": "audio_transcription",
"input_cost_per_second": 0,
"output_cost_per_second": 0.0001,
"litellm_provider": "openai"
},
"azure/whisper-1": {
"mode": "audio_transcription",
"input_cost_per_second": 0,
"output_cost_per_second": 0.0001,
"litellm_provider": "azure"
},
"azure/gpt-4-0125-preview": {
"max_tokens": 128000,
"max_input_tokens": 128000,

tests/gettysburg.wav Normal file (BIN)

Binary file not shown.

tests/test_whisper.py Normal file (26 lines)

@@ -0,0 +1,26 @@
# What is this?
## Tests `litellm.transcription` endpoint
import pytest
import asyncio, time
import aiohttp
from openai import AsyncOpenAI
import sys, os, dotenv
from typing import Optional
from dotenv import load_dotenv
audio_file = open("./gettysburg.wav", "rb")
load_dotenv()
sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path
import litellm
def test_transcription():
transcript = litellm.transcription(model="whisper-1", file=audio_file)
print(f"transcript: {transcript}")
test_transcription()
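A slightly more self-contained variant of the same check, shown only as a sketch (it assumes OPENAI_API_KEY is set in the environment and that the .wav fixture sits next to the test; the test name is hypothetical):

```python
import os
import litellm

def test_transcription_basic():
    # Open and close the fixture explicitly instead of at module import time.
    audio_path = os.path.join(os.path.dirname(__file__), "gettysburg.wav")
    with open(audio_path, "rb") as audio_file:
        transcript = litellm.transcription(model="whisper-1", file=audio_file)
    assert transcript.text is not None
```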