feat(azure.py): add support for calling whisper endpoints on azure

Krrish Dholakia 2024-03-08 13:48:38 -08:00
parent 696eb54455
commit 6b1049217e
3 changed files with 237 additions and 13 deletions
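
In practice this means Azure-hosted whisper deployments can be called through the same litellm.transcription / litellm.atranscription interface already used for OpenAI. A minimal usage sketch based on the tests below — the deployment name, env var names, and api_version here are placeholders, not values fixed by this commit:

import os
import litellm

# open the audio to transcribe; the path is illustrative
audio_file = open("./sample.wav", "rb")

# the "azure/" prefix routes the call to the new Azure whisper path
transcript = litellm.transcription(
    model="azure/azure-whisper",  # azure/<your-whisper-deployment>
    file=audio_file,
    api_base=os.getenv("AZURE_API_BASE"),
    api_key=os.getenv("AZURE_API_KEY"),
    api_version="2024-02-15-preview",
)
print(transcript.text)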

View file

@@ -7,8 +7,9 @@ from litellm.utils import (
     Message,
     CustomStreamWrapper,
     convert_to_model_response_object,
+    TranscriptionResponse,
 )
-from typing import Callable, Optional
+from typing import Callable, Optional, BinaryIO
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
@@ -757,6 +758,114 @@ class AzureChatCompletion(BaseLLM):
         else:
             raise AzureOpenAIError(status_code=500, message=str(e))
 
+    def audio_transcriptions(
+        self,
+        model: str,
+        audio_file: BinaryIO,
+        optional_params: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        client=None,
+        azure_ad_token: Optional[str] = None,
+        max_retries=None,
+        logging_obj=None,
+        atranscriptions: bool = False,
+    ):
+        data = {"model": model, "file": audio_file, **optional_params}
+
+        # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "max_retries": max_retries,
+            "timeout": timeout,
+        }
+        azure_client_params = select_azure_base_url_or_endpoint(
+            azure_client_params=azure_client_params
+        )
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            azure_client_params["azure_ad_token"] = azure_ad_token
+
+        if atranscriptions == True:
+            return self.async_audio_transcriptions(
+                audio_file=audio_file,
+                data=data,
+                model_response=model_response,
+                timeout=timeout,
+                api_key=api_key,
+                api_base=api_base,
+                client=client,
+                azure_client_params=azure_client_params,
+                max_retries=max_retries,
+                logging_obj=logging_obj,
+            )
+        if client is None:
+            azure_client = AzureOpenAI(http_client=litellm.client_session, **azure_client_params)  # type: ignore
+        else:
+            azure_client = client
+        response = azure_client.audio.transcriptions.create(
+            **data, timeout=timeout  # type: ignore
+        )
+        stringified_response = response.model_dump()
+        ## LOGGING
+        logging_obj.post_call(
+            input=audio_file.name,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=stringified_response,
+        )
+        final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+        return final_response
+
+    async def async_audio_transcriptions(
+        self,
+        audio_file: BinaryIO,
+        data: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client=None,
+        azure_client_params=None,
+        max_retries=None,
+        logging_obj=None,
+    ):
+        response = None
+        try:
+            if client is None:
+                async_azure_client = AsyncAzureOpenAI(
+                    **azure_client_params,
+                    http_client=litellm.aclient_session,
+                )
+            else:
+                async_azure_client = client
+            response = await async_azure_client.audio.transcriptions.create(
+                **data, timeout=timeout
+            )  # type: ignore
+            stringified_response = response.model_dump()
+            ## LOGGING
+            logging_obj.post_call(
+                input=audio_file.name,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+        except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=audio_file.name,
+                api_key=api_key,
+                original_response=str(e),
+            )
+            raise e
+
     async def ahealth_check(
         self,
         model: Optional[str],
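
Worth noting in the diff above: when atranscriptions == True, the synchronous audio_transcriptions method returns the coroutine from async_audio_transcriptions without awaiting it, and the async caller awaits it later. A standalone sketch of that dispatch pattern, with purely illustrative names:

import asyncio

class Handler:
    def work(self, x: int, use_async: bool = False):
        if use_async:
            # hand back the coroutine un-awaited; the async caller awaits it
            return self.async_work(x)
        return x * 2  # plain synchronous path

    async def async_work(self, x: int) -> int:
        await asyncio.sleep(0)  # stand-in for the awaited Azure client call
        return x * 2

async def main() -> None:
    h = Handler()
    print(h.work(2))                        # sync path: 4
    print(await h.work(2, use_async=True))  # async path: awaits the coroutine

asyncio.run(main())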

View file

@@ -88,6 +88,7 @@ from litellm.utils import (
     read_config_args,
     Choices,
     Message,
+    TranscriptionResponse,
 )
 
 ####### ENVIRONMENT VARIABLES ###################
@@ -3065,11 +3066,11 @@ async def aimage_generation(*args, **kwargs):
     Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.
 
     Parameters:
-    - `args` (tuple): Positional arguments to be passed to the `embedding` function.
-    - `kwargs` (dict): Keyword arguments to be passed to the `embedding` function.
+    - `args` (tuple): Positional arguments to be passed to the `image_generation` function.
+    - `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function.
 
     Returns:
-    - `response` (Any): The response returned by the `embedding` function.
+    - `response` (Any): The response returned by the `image_generation` function.
     """
     loop = asyncio.get_event_loop()
     model = args[0] if len(args) > 0 else kwargs["model"]
@@ -3091,7 +3092,7 @@ async def aimage_generation(*args, **kwargs):
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
         if isinstance(init_response, dict) or isinstance(
-            init_response, ModelResponse
+            init_response, ImageResponse
         ):  ## CACHING SCENARIO
             response = init_response
         elif asyncio.iscoroutine(init_response):
@@ -3318,7 +3319,43 @@ async def atranscription(*args, **kwargs):
 
     Allows router to load balance between them
     """
-    pass
+    loop = asyncio.get_event_loop()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    ### PASS ARGS TO TRANSCRIPTION ###
+    kwargs["atranscription"] = True
+    custom_llm_provider = None
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(transcription, *args, **kwargs)
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        _, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, api_base=kwargs.get("api_base", None)
+        )
+
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if isinstance(init_response, dict) or isinstance(
+            init_response, TranscriptionResponse
+        ):  ## CACHING SCENARIO
+            response = init_response
+        elif asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+
+        return response
+    except Exception as e:
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs=args,
+        )
 
 
 @client
@@ -3356,8 +3393,7 @@ def transcription(
 
     model_response = litellm.utils.TranscriptionResponse()
 
-    # model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
-    custom_llm_provider = "openai"
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
 
     optional_params = {
         "language": language,
@@ -3365,8 +3401,40 @@ def transcription(
         "response_format": response_format,
         "temperature": None,  # openai defaults this to 0
     }
-    if custom_llm_provider == "openai":
-        return openai_chat_completions.audio_transcriptions(
+
+    if custom_llm_provider == "azure":
+        # azure configs
+        api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+
+        api_version = (
+            api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
+        )
+
+        azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
+            "AZURE_AD_TOKEN"
+        )
+
+        api_key = (
+            api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret("AZURE_API_KEY")
+        )
+
+        response = azure_chat_completions.audio_transcriptions(
+            model=model,
+            audio_file=file,
+            optional_params=optional_params,
+            model_response=model_response,
+            atranscriptions=atranscriptions,
+            timeout=timeout,
+            logging_obj=litellm_logging_obj,
+            api_base=api_base,
+            api_key=api_key,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+        )
+    elif custom_llm_provider == "openai":
+        response = openai_chat_completions.audio_transcriptions(
             model=model,
             audio_file=file,
             optional_params=optional_params,
@@ -3375,7 +3443,7 @@ def transcription(
             timeout=timeout,
             logging_obj=litellm_logging_obj,
         )
-    return
+    return response
 
 
 ##### Health Endpoints #######################
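
The atranscription wrapper above follows the same pattern as aimage_generation: run the synchronous transcription in a thread executor with the current contextvars copied in, and if the sync layer hands back a coroutine (because the atranscription flag flowed down to the provider call), await it. A self-contained sketch of that mechanism, names illustrative:

import asyncio
import contextvars
from functools import partial

def sync_entrypoint(x: int, make_async: bool = False):
    if make_async:
        async def inner() -> int:
            return x + 1
        return inner()  # return the coroutine for the caller to await
    return x + 1

async def async_wrapper(x: int) -> int:
    loop = asyncio.get_event_loop()
    func = partial(sync_entrypoint, x, make_async=True)
    # copy the current context so contextvars survive the thread hop
    ctx = contextvars.copy_context()
    init_response = await loop.run_in_executor(None, partial(ctx.run, func))
    if asyncio.iscoroutine(init_response):
        return await init_response  # unwrap what the sync layer returned
    return init_response  # e.g. a cached dict / TranscriptionResponse

print(asyncio.run(async_wrapper(41)))  # 42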

View file

@@ -19,8 +19,55 @@ import litellm
 
 
 def test_transcription():
-    transcript = litellm.transcription(model="whisper-1", file=audio_file)
+    transcript = litellm.transcription(
+        model="whisper-1",
+        file=audio_file,
+    )
     print(f"transcript: {transcript}")
 
 
-test_transcription()
+# test_transcription()
+
+
+def test_transcription_azure():
+    transcript = litellm.transcription(
+        model="azure/azure-whisper",
+        file=audio_file,
+        api_base=os.getenv("AZURE_EUROPE_API_BASE"),
+        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
+        api_version="2024-02-15-preview",
+    )
+
+    assert transcript.text is not None
+    assert isinstance(transcript.text, str)
+
+
+# test_transcription_azure()
+
+
+@pytest.mark.asyncio
+async def test_transcription_async_azure():
+    transcript = await litellm.atranscription(
+        model="azure/azure-whisper",
+        file=audio_file,
+        api_base=os.getenv("AZURE_EUROPE_API_BASE"),
+        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
+        api_version="2024-02-15-preview",
+    )
+
+    assert transcript.text is not None
+    assert isinstance(transcript.text, str)
+
+
+# asyncio.run(test_transcription_async_azure())
+
+
+@pytest.mark.asyncio
+async def test_transcription_async_openai():
+    transcript = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file,
+    )
+
+    assert transcript.text is not None
+    assert isinstance(transcript.text, str)
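
The async tests run under pytest via pytest.mark.asyncio; outside pytest they can be driven directly, mirroring the commented-out asyncio.run call above:

import asyncio

# ad-hoc local run of one async test, assuming the module-level
# audio_file (and, for Azure, the AZURE_EUROPE_* env vars) are set up
asyncio.run(test_transcription_async_openai())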