Merge pull request #2401 from BerriAI/litellm_transcription_endpoints

feat(main.py): support openai transcription endpoints
This commit is contained in:
Krish Dholakia 2024-03-08 23:07:48 -08:00 committed by GitHub
commit e245b1c98a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 516 additions and 12 deletions

View file

@ -10,7 +10,6 @@
import sys, re, binascii, struct
import litellm
import dotenv, json, traceback, threading, base64, ast
import subprocess, os
from os.path import abspath, join, dirname
import litellm, openai
@ -98,7 +97,7 @@ try:
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
from typing import cast, List, Dict, Union, Optional, Literal, Any
from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor
@ -790,6 +789,38 @@ class ImageResponse(OpenAIObject):
return self.dict()
class TranscriptionResponse(OpenAIObject):
text: Optional[str] = None
_hidden_params: dict = {}
def __init__(self, text=None):
super().__init__(text=text)
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
def json(self, **kwargs):
try:
return self.model_dump() # noqa
except:
# if using pydantic v1
return self.dict()
############################################################
def print_verbose(print_statement, logger_only: bool = False):
try:
@ -815,6 +846,8 @@ class CallTypes(Enum):
aimage_generation = "aimage_generation"
moderation = "moderation"
amoderation = "amoderation"
atranscription = "atranscription"
transcription = "transcription"
# Logging function -> log the exact model details + what's being sent | Non-BlockingP
@ -948,6 +981,7 @@ class Logging:
curl_command = self.model_call_details
# only print verbose if verbose logger is not set
if verbose_logger.level == 0:
# this means verbose logger was not switched on - user is in litellm.set_verbose=True
print_verbose(f"\033[92m{curl_command}\033[0m\n")
@ -2293,6 +2327,12 @@ def client(original_function):
or call_type == CallTypes.text_completion.value
):
messages = args[0] if len(args) > 0 else kwargs["prompt"]
elif (
call_type == CallTypes.atranscription.value
or call_type == CallTypes.transcription.value
):
_file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
messages = _file_name.name
stream = True if "stream" in kwargs and kwargs["stream"] == True else False
logging_obj = Logging(
model=model,
@ -6264,10 +6304,10 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
def convert_to_model_response_object(
response_object: Optional[dict] = None,
model_response_object: Optional[
Union[ModelResponse, EmbeddingResponse, ImageResponse]
Union[ModelResponse, EmbeddingResponse, ImageResponse, TranscriptionResponse]
] = None,
response_type: Literal[
"completion", "embedding", "image_generation"
"completion", "embedding", "image_generation", "audio_transcription"
] = "completion",
stream=False,
start_time=None,
@ -6378,6 +6418,19 @@ def convert_to_model_response_object(
model_response_object.data = response_object["data"]
return model_response_object
elif response_type == "audio_transcription" and (
model_response_object is None
or isinstance(model_response_object, TranscriptionResponse)
):
if response_object is None:
raise Exception("Error in response object format")
if model_response_object is None:
model_response_object = TranscriptionResponse()
if "text" in response_object:
model_response_object.text = response_object["text"]
return model_response_object
except Exception as e:
raise Exception(f"Invalid response object {traceback.format_exc()}")