Merge pull request #2401 from BerriAI/litellm_transcription_endpoints

feat(main.py): support openai transcription endpoints
This commit is contained in:
Krish Dholakia 2024-03-08 23:07:48 -08:00 committed by GitHub
commit e245b1c98a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 516 additions and 12 deletions

View file

@ -1,4 +1,4 @@
from typing import Optional, Union, Any
from typing import Optional, Union, Any, BinaryIO
import types, time, json, traceback
import httpx
from .base import BaseLLM
@ -9,6 +9,7 @@ from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
Usage,
TranscriptionResponse,
)
from typing import Callable, Optional
import aiohttp, requests
@ -774,6 +775,103 @@ class OpenAIChatCompletion(BaseLLM):
else:
raise OpenAIError(status_code=500, message=str(e))
def audio_transcriptions(
self,
model: str,
audio_file: BinaryIO,
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
atranscriptions: bool = False,
):
data = {"model": model, "file": audio_file, **optional_params}
if atranscriptions == True:
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
model_response=model_response,
timeout=timeout,
api_key=api_key,
api_base=api_base,
client=client,
max_retries=max_retries,
logging_obj=logging_obj,
)
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_client = client
response = openai_client.audio.transcriptions.create(
**data, timeout=timeout # type: ignore
)
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
self,
audio_file: BinaryIO,
data: dict,
model_response: TranscriptionResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
):
response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_aclient = client
response = await openai_aclient.audio.transcriptions.create(
**data, timeout=timeout
) # type: ignore
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
raise e
async def ahealth_check(
self,
model: Optional[str],