Merge pull request #2401 from BerriAI/litellm_transcription_endpoints

feat(main.py): support openai transcription endpoints
This commit is contained in:
Krish Dholakia 2024-03-08 23:07:48 -08:00 committed by GitHub
commit e245b1c98a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 516 additions and 12 deletions

View file

@ -8,7 +8,7 @@
# Thank you ! We ❤️ you! - Krrish & Ishaan
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union
from typing import Any, Literal, Union, BinaryIO
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
@ -88,6 +88,7 @@ from litellm.utils import (
read_config_args,
Choices,
Message,
TranscriptionResponse,
)
####### ENVIRONMENT VARIABLES ###################
@ -3048,7 +3049,6 @@ def moderation(
return response
##### Moderation #######################
@client
async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
# only supports open ai for now
@ -3071,11 +3071,11 @@ async def aimage_generation(*args, **kwargs):
Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.
Parameters:
- `args` (tuple): Positional arguments to be passed to the `embedding` function.
- `kwargs` (dict): Keyword arguments to be passed to the `embedding` function.
- `args` (tuple): Positional arguments to be passed to the `image_generation` function.
- `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function.
Returns:
- `response` (Any): The response returned by the `embedding` function.
- `response` (Any): The response returned by the `image_generation` function.
"""
loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
@ -3097,7 +3097,7 @@ async def aimage_generation(*args, **kwargs):
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
init_response, ModelResponse
init_response, ImageResponse
): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
@ -3315,6 +3315,142 @@ def image_generation(
)
##### Transcription #######################
async def atranscription(*args, **kwargs):
"""
Calls openai + azure whisper endpoints.
Allows router to load balance between them
"""
loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
### PASS ARGS TO Image Generation ###
kwargs["atranscription"] = True
custom_llm_provider = None
try:
# Use a partial function to pass your keyword arguments
func = partial(transcription, *args, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=kwargs.get("api_base", None)
)
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
init_response, TranscriptionResponse
): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
)
@client
def transcription(
model: str,
file: BinaryIO,
## OPTIONAL OPENAI PARAMS ##
language: Optional[str] = None,
prompt: Optional[str] = None,
response_format: Optional[
Literal["json", "text", "srt", "verbose_json", "vtt"]
] = None,
temperature: Optional[int] = None, # openai defaults this to 0
## LITELLM PARAMS ##
user: Optional[str] = None,
timeout=600, # default to 10 minutes
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
litellm_logging_obj=None,
custom_llm_provider=None,
**kwargs,
):
"""
Calls openai + azure whisper endpoints.
Allows router to load balance between them
"""
atranscriptions = kwargs.get("atranscriptions", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
model_response = litellm.utils.TranscriptionResponse()
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
optional_params = {
"language": language,
"prompt": prompt,
"response_format": response_format,
"temperature": None, # openai defaults this to 0
}
if custom_llm_provider == "azure":
# azure configs
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
)
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
"AZURE_AD_TOKEN"
)
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_API_KEY")
)
response = azure_chat_completions.audio_transcriptions(
model=model,
audio_file=file,
optional_params=optional_params,
model_response=model_response,
atranscriptions=atranscriptions,
timeout=timeout,
logging_obj=litellm_logging_obj,
api_base=api_base,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
)
elif custom_llm_provider == "openai":
response = openai_chat_completions.audio_transcriptions(
model=model,
audio_file=file,
optional_params=optional_params,
model_response=model_response,
atranscriptions=atranscriptions,
timeout=timeout,
logging_obj=litellm_logging_obj,
)
return response
##### Health Endpoints #######################