diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 01b54987b..e19023b03 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -7,8 +7,9 @@ from litellm.utils import (
     Message,
     CustomStreamWrapper,
     convert_to_model_response_object,
+    TranscriptionResponse,
 )
-from typing import Callable, Optional
+from typing import Callable, Optional, BinaryIO
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
@@ -757,6 +758,114 @@ class AzureChatCompletion(BaseLLM):
             else:
                 raise AzureOpenAIError(status_code=500, message=str(e))
 
+    def audio_transcriptions(
+        self,
+        model: str,
+        audio_file: BinaryIO,
+        optional_params: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        client=None,
+        azure_ad_token: Optional[str] = None,
+        max_retries=None,
+        logging_obj=None,
+        atranscriptions: bool = False,
+    ):
+        data = {"model": model, "file": audio_file, **optional_params}
+
+        # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "max_retries": max_retries,
+            "timeout": timeout,
+        }
+        azure_client_params = select_azure_base_url_or_endpoint(
+            azure_client_params=azure_client_params
+        )
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            azure_client_params["azure_ad_token"] = azure_ad_token
+
+        if atranscriptions == True:
+            return self.async_audio_transcriptions(
+                audio_file=audio_file,
+                data=data,
+                model_response=model_response,
+                timeout=timeout,
+                api_key=api_key,
+                api_base=api_base,
+                client=client,
+                azure_client_params=azure_client_params,
+                max_retries=max_retries,
+                logging_obj=logging_obj,
+            )
+        if client is None:
+            azure_client = AzureOpenAI(http_client=litellm.client_session, **azure_client_params)  # type: ignore
+        else:
+            azure_client = client
+        response = azure_client.audio.transcriptions.create(
+            **data, timeout=timeout  # type: ignore
+        )
+        stringified_response = response.model_dump()
+        ## LOGGING
+        logging_obj.post_call(
+            input=audio_file.name,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=stringified_response,
+        )
+        final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+        return final_response
+
+    async def async_audio_transcriptions(
+        self,
+        audio_file: BinaryIO,
+        data: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client=None,
+        azure_client_params=None,
+        max_retries=None,
+        logging_obj=None,
+    ):
+        response = None
+        try:
+            if client is None:
+                async_azure_client = AsyncAzureOpenAI(
+                    **azure_client_params,
+                    http_client=litellm.aclient_session,
+                )
+            else:
+                async_azure_client = client
+            response = await async_azure_client.audio.transcriptions.create(
+                **data, timeout=timeout
+            )  # type: ignore
+            stringified_response = response.model_dump()
+            ## LOGGING
+            logging_obj.post_call(
+                input=audio_file.name,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+        except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=audio_file.name,
+                api_key=api_key,
+                original_response=str(e),
+            )
+            raise e
+
     async def ahealth_check(
         self,
         model: Optional[str],
diff --git a/litellm/main.py b/litellm/main.py
index 2df9686fe..f1a745fcc 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -88,6 +88,7 @@ from litellm.utils import (
     read_config_args,
     Choices,
     Message,
+    TranscriptionResponse,
 )
 
 ####### ENVIRONMENT VARIABLES ###################
@@ -3065,11 +3066,11 @@ async def aimage_generation(*args, **kwargs):
     Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.
 
     Parameters:
-    - `args` (tuple): Positional arguments to be passed to the `embedding` function.
-    - `kwargs` (dict): Keyword arguments to be passed to the `embedding` function.
+    - `args` (tuple): Positional arguments to be passed to the `image_generation` function.
+    - `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function.
 
     Returns:
-    - `response` (Any): The response returned by the `embedding` function.
+    - `response` (Any): The response returned by the `image_generation` function.
     """
     loop = asyncio.get_event_loop()
     model = args[0] if len(args) > 0 else kwargs["model"]
@@ -3091,7 +3092,7 @@ async def aimage_generation(*args, **kwargs):
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
         if isinstance(init_response, dict) or isinstance(
-            init_response, ModelResponse
+            init_response, ImageResponse
         ):  ## CACHING SCENARIO
             response = init_response
         elif asyncio.iscoroutine(init_response):
@@ -3318,7 +3319,43 @@ async def atranscription(*args, **kwargs):
 
     Allows router to load balance between them
     """
-    pass
+    loop = asyncio.get_event_loop()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    ### PASS ARGS TO TRANSCRIPTION ###
+    kwargs["atranscription"] = True
+    custom_llm_provider = None
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(transcription, *args, **kwargs)
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, api_base=kwargs.get("api_base", None)
+        )
+
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if isinstance(init_response, dict) or isinstance(
+            init_response, TranscriptionResponse
+        ):  ## CACHING SCENARIO
+            response = init_response
+        elif asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        return response
+    except Exception as e:
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs=args,
+        )
 
 
 @client
@@ -3356,8 +3393,7 @@ def transcription(
 
     model_response = litellm.utils.TranscriptionResponse()
 
-    # model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
-    custom_llm_provider = "openai"
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
 
     optional_params = {
         "language": language,
@@ -3365,8 +3401,40 @@ def transcription(
         "response_format": response_format,
         "temperature": None,  # openai defaults this to 0
     }
-    if custom_llm_provider == "openai":
-        return openai_chat_completions.audio_transcriptions(
+
+    if custom_llm_provider == "azure":
+        # azure configs
+        api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+
+        api_version = (
+            api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
+        )
+
+        azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
+            "AZURE_AD_TOKEN"
+        )
+
+        api_key = (
+            api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret("AZURE_API_KEY")
+        )
+        response = azure_chat_completions.audio_transcriptions(
+            model=model,
+            audio_file=file,
+            optional_params=optional_params,
+            model_response=model_response,
+            atranscriptions=atranscriptions,
+            timeout=timeout,
+            logging_obj=litellm_logging_obj,
+            api_base=api_base,
+            api_key=api_key,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+        )
+    elif custom_llm_provider == "openai":
+        response = openai_chat_completions.audio_transcriptions(
             model=model,
             audio_file=file,
             optional_params=optional_params,
@@ -3375,7 +3443,7 @@ def transcription(
             timeout=timeout,
             logging_obj=litellm_logging_obj,
         )
-    return
+    return response
 
 
 ##### Health Endpoints #######################
diff --git a/tests/test_whisper.py b/tests/test_whisper.py
index 8ee3b428c..9d8f038c2 100644
--- a/tests/test_whisper.py
+++ b/tests/test_whisper.py
@@ -19,8 +19,55 @@ import litellm
 
 
 def test_transcription():
-    transcript = litellm.transcription(model="whisper-1", file=audio_file)
+    transcript = litellm.transcription(
+        model="whisper-1",
+        file=audio_file,
+    )
     print(f"transcript: {transcript}")
 
 
-test_transcription()
+# test_transcription()
+
+
+def test_transcription_azure():
+    transcript = litellm.transcription(
+        model="azure/azure-whisper",
+        file=audio_file,
+        api_base=os.getenv("AZURE_EUROPE_API_BASE"),
+        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
+        api_version="2024-02-15-preview",
+    )
+
+    assert transcript.text is not None
+    assert isinstance(transcript.text, str)
+
+
+# test_transcription_azure()
+
+
+@pytest.mark.asyncio
+async def test_transcription_async_azure():
+    transcript = await litellm.atranscription(
+        model="azure/azure-whisper",
+        file=audio_file,
+        api_base=os.getenv("AZURE_EUROPE_API_BASE"),
+        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
+        api_version="2024-02-15-preview",
+    )
+
+    assert transcript.text is not None
+    assert isinstance(transcript.text, str)
+
+
+# asyncio.run(test_transcription_async_azure())
+
+
+@pytest.mark.asyncio
+async def test_transcription_async_openai():
+    transcript = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file,
+    )
+
+    assert transcript.text is not None
+    assert isinstance(transcript.text, str)
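
Usage note (reviewer sketch, not part of the patch): a minimal example of how the new paths are exercised, mirroring tests/test_whisper.py. The "azure/azure-whisper" deployment name, the AZURE_EUROPE_* environment variables, and the local file path are placeholders taken from or invented for illustration, not required names.

import os
import asyncio
import litellm

audio_file = open("speech.wav", "rb")  # hypothetical sample file

# Sync path: an "azure/<deployment>" model now routes through
# azure_chat_completions.audio_transcriptions instead of always hitting OpenAI.
transcript = litellm.transcription(
    model="azure/azure-whisper",
    file=audio_file,
    api_base=os.getenv("AZURE_EUROPE_API_BASE"),
    api_key=os.getenv("AZURE_EUROPE_API_KEY"),
    api_version="2024-02-15-preview",
)
print(transcript.text)

# Async path: atranscription() sets kwargs["atranscription"] = True, runs the
# sync transcription() in an executor, and awaits the coroutine it returns.
async def main():
    response = await litellm.atranscription(model="whisper-1", file=audio_file)
    print(response.text)

asyncio.run(main())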