feat(azure.py): azure tts support

Closes https://github.com/BerriAI/litellm/issues/4002
This commit is contained in:
Krrish Dholakia 2024-06-27 16:59:25 -07:00
parent 80d8bf5d8f
commit c14cc35e52
4 changed files with 208 additions and 12 deletions

View file

@ -42,6 +42,7 @@ from ..types.llms.openai import (
AsyncAssistantEventHandler,
AsyncAssistantStreamManager,
AsyncCursorPage,
HttpxBinaryResponseContent,
MessageData,
OpenAICreateThreadParamsMessage,
OpenAIMessage,
@ -414,6 +415,49 @@ class AzureChatCompletion(BaseLLM):
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers
def _get_sync_azure_client(
self,
api_version: Optional[str],
api_base: Optional[str],
api_key: Optional[str],
azure_ad_token: Optional[str],
model: str,
max_retries: int,
timeout: Union[float, httpx.Timeout],
client: Optional[Any],
client_type: Literal["sync", "async"],
):
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
if client_type == "sync":
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
elif client_type == "async":
azure_client = AsyncAzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
if api_version is not None and isinstance(azure_client._custom_query, dict):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
return azure_client
def completion(
self,
model: str,
@ -1248,6 +1292,96 @@ class AzureChatCompletion(BaseLLM):
)
raise e
def audio_speech(
self,
model: str,
input: str,
voice: str,
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
organization: Optional[str],
max_retries: int,
timeout: Union[float, httpx.Timeout],
azure_ad_token: Optional[str] = None,
aspeech: Optional[bool] = None,
client=None,
) -> HttpxBinaryResponseContent:
max_retries = optional_params.pop("max_retries", 2)
if aspeech is not None and aspeech is True:
return self.async_audio_speech(
model=model,
input=input,
voice=voice,
optional_params=optional_params,
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
max_retries=max_retries,
timeout=timeout,
client=client,
) # type: ignore
azure_client: AzureOpenAI = self._get_sync_azure_client(
api_base=api_base,
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
model=model,
max_retries=max_retries,
timeout=timeout,
client=client,
client_type="sync",
) # type: ignore
response = azure_client.audio.speech.create(
model=model,
voice=voice, # type: ignore
input=input,
**optional_params,
)
return response
async def async_audio_speech(
self,
model: str,
input: str,
voice: str,
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
max_retries: int,
timeout: Union[float, httpx.Timeout],
client=None,
) -> HttpxBinaryResponseContent:
azure_client: AsyncAzureOpenAI = self._get_sync_azure_client(
api_base=api_base,
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
model=model,
max_retries=max_retries,
timeout=timeout,
client=client,
client_type="async",
) # type: ignore
response = await azure_client.audio.speech.create(
model=model,
voice=voice, # type: ignore
input=input,
**optional_params,
)
return response
def get_headers(
self,
model: Optional[str],

View file

@ -4410,6 +4410,7 @@ def speech(
voice: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None,
max_retries: Optional[int] = None,
@ -4483,6 +4484,45 @@ def speech(
client=client, # pass AsyncOpenAI, OpenAI client
aspeech=aspeech,
)
elif custom_llm_provider == "azure":
# azure configs
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
azure_ad_token: Optional[str] = optional_params.get("extra_body", {}).pop( # type: ignore
"azure_ad_token", None
) or get_secret(
"AZURE_AD_TOKEN"
)
headers = headers or litellm.headers
response = azure_chat_completions.audio_speech(
model=model,
input=input,
voice=voice,
optional_params=optional_params,
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
organization=organization,
max_retries=max_retries,
timeout=timeout,
client=client, # pass AsyncOpenAI, OpenAI client
aspeech=aspeech,
)
if response is None:
raise Exception(

View file

@ -18,7 +18,6 @@ model_list:
api_key: os.environ/PREDIBASE_API_KEY
tenant_id: os.environ/PREDIBASE_TENANT_ID
max_new_tokens: 256
# - litellm_params:
# api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
# api_key: os.environ/AZURE_EUROPE_API_KEY

View file

@ -1,8 +1,14 @@
# What is this?
## unit tests for openai tts endpoint
import sys, os, asyncio, time, random, uuid
import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from dotenv import load_dotenv
load_dotenv()
@ -11,23 +17,40 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm, openai
from pathlib import Path
import openai
import pytest
@pytest.mark.parametrize("sync_mode", [True, False])
import litellm
@pytest.mark.parametrize(
"sync_mode",
[True, False],
)
@pytest.mark.parametrize(
"model, api_key, api_base",
[
(
"azure/azure-tts",
os.getenv("AZURE_SWEDEN_API_KEY"),
os.getenv("AZURE_SWEDEN_API_BASE"),
),
("openai/tts-1", os.getenv("OPENAI_API_KEY"), None),
],
) # ,
@pytest.mark.asyncio
async def test_audio_speech_litellm(sync_mode):
async def test_audio_speech_litellm(sync_mode, model, api_base, api_key):
speech_file_path = Path(__file__).parent / "speech.mp3"
if sync_mode:
response = litellm.speech(
model="openai/tts-1",
model=model,
voice="alloy",
input="the quick brown fox jumped over the lazy dogs",
api_base=None,
api_key=None,
api_base=api_base,
api_key=api_key,
organization=None,
project=None,
max_retries=1,
@ -41,11 +64,11 @@ async def test_audio_speech_litellm(sync_mode):
assert isinstance(response, HttpxBinaryResponseContent)
else:
response = await litellm.aspeech(
model="openai/tts-1",
model=model,
voice="alloy",
input="the quick brown fox jumped over the lazy dogs",
api_base=None,
api_key=None,
api_base=api_base,
api_key=api_key,
organization=None,
project=None,
max_retries=1,