mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
198 lines
7.1 KiB
Python
198 lines
7.1 KiB
Python
import uuid
|
|
from typing import Any, Coroutine, Optional, Union
|
|
|
|
from openai import AsyncAzureOpenAI, AzureOpenAI
|
|
from pydantic import BaseModel
|
|
|
|
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
|
|
from litellm.types.utils import FileTypes
|
|
from litellm.utils import (
|
|
TranscriptionResponse,
|
|
convert_to_model_response_object,
|
|
extract_duration_from_srt_or_vtt,
|
|
)
|
|
|
|
from .azure import AzureChatCompletion
|
|
from .common_utils import AzureOpenAIError
|
|
|
|
|
|
class AzureAudioTranscription(AzureChatCompletion):
    """Audio transcription requests against an Azure OpenAI deployment.

    `audio_transcriptions` is the entry point. It performs the request
    synchronously, or, when `atranscription` is true, returns the coroutine
    produced by `async_audio_transcriptions`.
    """

    def audio_transcriptions(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        logging_obj: Any,
        model_response: TranscriptionResponse,
        timeout: float,
        max_retries: int,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        client=None,
        azure_ad_token: Optional[str] = None,
        atranscription: bool = False,
        litellm_params: Optional[dict] = None,
    ) -> Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:
        """Transcribe `audio_file` with the given Azure deployment.

        Args:
            model: Model / deployment name sent as the request's ``model``.
            audio_file: The audio payload sent as the request's ``file``.
            optional_params: Extra request fields merged into the payload.
            logging_obj: Object exposing ``pre_call`` / ``post_call`` hooks.
            model_response: Response object handed to
                ``convert_to_model_response_object`` to be populated.
            timeout: Per-request timeout passed to the SDK call.
            max_retries: Only forwarded to the async path; unused otherwise.
            api_key, api_base, api_version, client, litellm_params: Passed to
                ``get_azure_openai_client`` to obtain the SDK client.
            azure_ad_token: Accepted for interface compatibility; not used by
                this method.
            atranscription: When true, return the coroutine from
                ``async_audio_transcriptions`` instead of calling the API here.

        Returns:
            A ``TranscriptionResponse``, or a coroutine resolving to one when
            ``atranscription`` is true.

        Raises:
            AzureOpenAIError: If the resolved client is not an ``AzureOpenAI``
                instance (sync path).
        """
        data = {"model": model, "file": audio_file, **optional_params}

        if atranscription is True:
            return self.async_audio_transcriptions(
                audio_file=audio_file,
                data=data,
                model_response=model_response,
                timeout=timeout,
                # Forward api_version so the async client is built with the
                # same API version the sync path would use.
                api_version=api_version,
                api_key=api_key,
                api_base=api_base,
                client=client,
                max_retries=max_retries,
                logging_obj=logging_obj,
                model=model,
                litellm_params=litellm_params,
            )

        azure_client = self.get_azure_openai_client(
            api_version=api_version,
            api_base=api_base,
            api_key=api_key,
            model=model,
            _is_async=False,
            client=client,
            litellm_params=litellm_params,
        )
        if not isinstance(azure_client, AzureOpenAI):
            raise AzureOpenAIError(
                status_code=500,
                message="azure_client is not an instance of AzureOpenAI",
            )

        ## LOGGING
        logging_obj.pre_call(
            input=f"audio_file_{uuid.uuid4()}",
            api_key=azure_client.api_key,
            additional_args={
                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                "api_base": azure_client._base_url._uri_reference,
                # NOTE(review): this is the synchronous path yet the flag is
                # logged as True; kept as-is in case consumers rely on it.
                "atranscription": True,
                "complete_input_dict": data,
            },
        )

        response = azure_client.audio.transcriptions.create(
            **data, timeout=timeout  # type: ignore
        )

        # The SDK returns a pydantic model for structured formats; anything
        # else is wrapped as the text of a TranscriptionResponse.
        # NOTE(review): unlike the async path, no "duration" is derived here
        # for non-model responses -- confirm whether that is intentional.
        if isinstance(response, BaseModel):
            stringified_response = response.model_dump()
        else:
            stringified_response = TranscriptionResponse(text=response).model_dump()

        ## LOGGING
        logging_obj.post_call(
            input=get_audio_file_name(audio_file),
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=stringified_response,
        )
        hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
        final_response: TranscriptionResponse = convert_to_model_response_object(  # type: ignore
            response_object=stringified_response,
            model_response_object=model_response,
            hidden_params=hidden_params,
            response_type="audio_transcription",
        )
        return final_response

    async def async_audio_transcriptions(
        self,
        audio_file: FileTypes,
        model: str,
        data: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        logging_obj: Any,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        litellm_params: Optional[dict] = None,
    ) -> TranscriptionResponse:
        """Asynchronously transcribe audio using an ``AsyncAzureOpenAI`` client.

        Args:
            audio_file: The audio payload; used here only to derive the name
                reported to ``logging_obj.post_call``.
            model: Model / deployment name, passed to the client factory.
            data: Complete request payload expanded into the SDK call.
            model_response: Response object handed to
                ``convert_to_model_response_object`` to be populated.
            timeout: Per-request timeout passed to the SDK call.
            logging_obj: Object exposing ``pre_call`` / ``post_call`` hooks.
            api_version, api_key, api_base, client, litellm_params: Passed to
                ``get_azure_openai_client`` to obtain the SDK client.
            max_retries: Accepted for interface compatibility; not used by
                this method.

        Returns:
            The populated ``TranscriptionResponse``.

        Raises:
            AzureOpenAIError: If the resolved client is not an
                ``AsyncAzureOpenAI`` instance, or the converted response is
                not a ``TranscriptionResponse``.
            Exception: Any other failure is reported via
                ``logging_obj.post_call`` and then re-raised unchanged.
        """
        try:
            async_azure_client = self.get_azure_openai_client(
                api_version=api_version,
                api_base=api_base,
                api_key=api_key,
                model=model,
                _is_async=True,
                client=client,
                litellm_params=litellm_params,
            )
            if not isinstance(async_azure_client, AsyncAzureOpenAI):
                raise AzureOpenAIError(
                    status_code=500,
                    message="async_azure_client is not an instance of AsyncAzureOpenAI",
                )

            ## LOGGING
            logging_obj.pre_call(
                input=f"audio_file_{uuid.uuid4()}",
                api_key=async_azure_client.api_key,
                additional_args={
                    "headers": {
                        "Authorization": f"Bearer {async_azure_client.api_key}"
                    },
                    "api_base": async_azure_client._base_url._uri_reference,
                    "atranscription": True,
                    "complete_input_dict": data,
                },
            )

            # Use the raw-response variant so the HTTP headers can be passed
            # along with the parsed body.
            raw_response = (
                await async_azure_client.audio.transcriptions.with_raw_response.create(
                    **data, timeout=timeout
                )
            )  # type: ignore

            headers = dict(raw_response.headers)
            response = raw_response.parse()

            if isinstance(response, BaseModel):
                stringified_response = response.model_dump()
            else:
                # Non-model response: wrap it as text and attach the duration
                # extracted by extract_duration_from_srt_or_vtt.
                stringified_response = TranscriptionResponse(text=response).model_dump()
                duration = extract_duration_from_srt_or_vtt(response)
                stringified_response["duration"] = duration

            ## LOGGING
            logging_obj.post_call(
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                additional_args={
                    "headers": {
                        "Authorization": f"Bearer {async_azure_client.api_key}"
                    },
                    "api_base": async_azure_client._base_url._uri_reference,
                    "atranscription": True,
                    "complete_input_dict": data,
                },
                original_response=stringified_response,
            )
            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
            response = convert_to_model_response_object(
                _response_headers=headers,
                response_object=stringified_response,
                model_response_object=model_response,
                hidden_params=hidden_params,
                response_type="audio_transcription",
            )
            if not isinstance(response, TranscriptionResponse):
                raise AzureOpenAIError(
                    status_code=500,
                    message="response is not an instance of TranscriptionResponse",
                )
            return response
        except Exception as e:
            ## LOGGING
            # Report the audio file name, consistent with the success path.
            # (Previously the builtin `input` function was passed by mistake.)
            logging_obj.post_call(
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                original_response=str(e),
            )
            raise
|