Merge pull request #2401 from BerriAI/litellm_transcription_endpoints
feat(main.py): support openai transcription endpoints
Commit e245b1c98a
6 changed files with 516 additions and 12 deletions
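For orientation, here is a minimal usage sketch of the feature this PR adds. The call shape follows the OpenAI-style transcription entry point named in the commit message; the model name and file path are illustrative placeholders:

import litellm

# Hedged sketch: transcribe a local audio file via the new endpoint.
# "speech.mp3" is a placeholder path; "whisper-1" is OpenAI's hosted model.
audio_file = open("speech.mp3", "rb")

response = litellm.transcription(model="whisper-1", file=audio_file)
print(response.text)  # TranscriptionResponse exposes the transcript as .text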
@@ -10,7 +10,6 @@
 import sys, re, binascii, struct
 import litellm
 import dotenv, json, traceback, threading, base64, ast

 import subprocess, os
 from os.path import abspath, join, dirname
 import litellm, openai
@@ -98,7 +97,7 @@ try:
 except Exception as e:
     verbose_logger.debug(f"Exception import enterprise features {str(e)}")

-from typing import cast, List, Dict, Union, Optional, Literal, Any
+from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO
 from .caching import Cache
 from concurrent.futures import ThreadPoolExecutor
@@ -790,6 +789,38 @@ class ImageResponse(OpenAIObject):
             return self.dict()


+class TranscriptionResponse(OpenAIObject):
+    text: Optional[str] = None
+
+    _hidden_params: dict = {}
+
+    def __init__(self, text=None):
+        super().__init__(text=text)
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
 ############################################################
 def print_verbose(print_statement, logger_only: bool = False):
     try:
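The dunder methods above let TranscriptionResponse behave both as a pydantic object and as a plain mapping. A small sketch of the access patterns they enable, assuming the class exactly as added in this hunk:

resp = TranscriptionResponse(text="hello world")

# Attribute-style and dictionary-style access are interchangeable,
# because __getitem__/__setitem__ delegate to getattr/setattr.
assert resp.text == "hello world"
assert resp["text"] == "hello world"

# __contains__ delegates to hasattr, so 'in' tests attribute presence.
assert "text" in resp

# .get() mirrors dict.get(), returning the default for missing attributes.
assert resp.get("language", "unknown") == "unknown"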
@@ -815,6 +846,8 @@ class CallTypes(Enum):
     aimage_generation = "aimage_generation"
     moderation = "moderation"
     amoderation = "amoderation"
+    atranscription = "atranscription"
+    transcription = "transcription"


 # Logging function -> log the exact model details + what's being sent | Non-Blocking
@@ -948,6 +981,7 @@ class Logging:
         curl_command = self.model_call_details

         # only print verbose if verbose logger is not set
         if verbose_logger.level == 0:
             # this means verbose logger was not switched on - user is in litellm.set_verbose=True
             print_verbose(f"\033[92m{curl_command}\033[0m\n")
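For context, this branch only pretty-prints the request when the structured verbose_logger is left unconfigured, i.e. when a user relies on the older package-level flag; a one-line sketch of that mode:

import litellm

# Legacy debug switch: with verbose_logger unconfigured (level == 0),
# the wrapped call's curl command is echoed via print_verbose instead.
litellm.set_verbose = True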
@@ -2293,6 +2327,12 @@ def client(original_function):
             or call_type == CallTypes.text_completion.value
         ):
             messages = args[0] if len(args) > 0 else kwargs["prompt"]
+        elif (
+            call_type == CallTypes.atranscription.value
+            or call_type == CallTypes.transcription.value
+        ):
+            _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
+            messages = _file_name.name
         stream = True if "stream" in kwargs and kwargs["stream"] == True else False
         logging_obj = Logging(
             model=model,
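The new branch treats transcription calls specially: instead of message content, it records the uploaded file's name for logging. A short sketch of why .name is available on the BinaryIO value (the path is a placeholder):

# A file handle opened in binary mode carries its path in .name;
# that string is what the wrapper stores into `messages` for logging.
audio_file = open("speech.mp3", "rb")
print(audio_file.name)  # -> "speech.mp3"

# The same value is resolved whether the file arrives positionally
# (args[1]) or as a keyword argument (kwargs["file"]).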
@@ -6264,10 +6304,10 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
 def convert_to_model_response_object(
     response_object: Optional[dict] = None,
     model_response_object: Optional[
-        Union[ModelResponse, EmbeddingResponse, ImageResponse]
+        Union[ModelResponse, EmbeddingResponse, ImageResponse, TranscriptionResponse]
     ] = None,
     response_type: Literal[
-        "completion", "embedding", "image_generation"
+        "completion", "embedding", "image_generation", "audio_transcription"
     ] = "completion",
     stream=False,
     start_time=None,
@@ -6378,6 +6418,19 @@ def convert_to_model_response_object(
             model_response_object.data = response_object["data"]

             return model_response_object
+        elif response_type == "audio_transcription" and (
+            model_response_object is None
+            or isinstance(model_response_object, TranscriptionResponse)
+        ):
+            if response_object is None:
+                raise Exception("Error in response object format")
+
+            if model_response_object is None:
+                model_response_object = TranscriptionResponse()
+
+            if "text" in response_object:
+                model_response_object.text = response_object["text"]
+            return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {traceback.format_exc()}")
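Together with the widened signature in the previous hunk, this branch maps a raw provider payload onto a TranscriptionResponse. A minimal sketch of the conversion, with an illustrative payload dict:

raw = {"text": "Hello from the audio file."}  # illustrative provider payload

response = convert_to_model_response_object(
    response_object=raw,
    response_type="audio_transcription",
)
# model_response_object was omitted, so the branch builds a fresh
# TranscriptionResponse and copies the "text" field onto it.
assert isinstance(response, TranscriptionResponse)
assert response.text == "Hello from the audio file."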