feat(main.py): support openai tts endpoint

Closes https://github.com/BerriAI/litellm/issues/3094
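
This adds litellm.speech() and litellm.aspeech(), which route text-to-speech requests to
OpenAI's /v1/audio/speech endpoint through the OpenAI SDK's streaming-response helpers.
A minimal usage sketch, mirroring the tests added in this commit (assumes OPENAI_API_KEY
is set in the environment):

    from pathlib import Path
    import litellm

    speech_file_path = Path(__file__).parent / "speech.mp3"

    # speech() returns the OpenAI SDK's streaming-response context manager,
    # so the audio can be written to disk without buffering it in memory.
    with litellm.speech(
        model="openai/tts-1",
        voice="alloy",
        input="the quick brown fox jumped over the lazy dogs",
        api_base=None,
        api_key=None,
        organization=None,
        project=None,
        max_retries=1,
        timeout=600,
        client=None,
        optional_params={},
    ) as response:
        response.stream_to_file(speech_file_path)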
Krrish Dholakia 2024-05-30 14:28:28 -07:00
parent 3167bee25a
commit a67cbf47f6
5 changed files with 322 additions and 3 deletions


@@ -227,7 +227,7 @@ default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
max_end_user_budget: Optional[float] = None
#### RELIABILITY ####
-request_timeout: Optional[float] = 6000
+request_timeout: float = 6000
num_retries: Optional[int] = None  # per model endpoint
default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None
@@ -304,6 +304,7 @@ api_base = None
headers = None
api_version = None
organization = None
project = None
config_path = None
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []


@@ -26,6 +26,7 @@ import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import *
import openai

class OpenAIError(Exception):

@@ -1180,6 +1181,94 @@ class OpenAIChatCompletion(BaseLLM):
            )
            raise e

    def audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        aspeech: Optional[bool] = None,
        client=None,
    ) -> ResponseContextManager[StreamedBinaryAPIResponse]:
        if aspeech is not None and aspeech == True:
            return self.async_audio_speech(
                model=model,
                input=input,
                voice=voice,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                organization=organization,
                project=project,
                max_retries=max_retries,
                timeout=timeout,
                client=client,
            )  # type: ignore

        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                organization=organization,
                project=project,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            openai_client = client

        # pass through the caller's model/voice/input instead of hardcoded test values
        response = openai_client.audio.speech.with_streaming_response.create(
            model=model,
            voice=voice,
            input=input,
            **optional_params,
        )
        return response

    def async_audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        client=None,
    ) -> AsyncResponseContextManager[AsyncStreamedBinaryAPIResponse]:
        if client is None:
            openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url=api_base,
                organization=organization,
                project=project,
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            openai_client = client

        response = openai_client.audio.speech.with_streaming_response.create(
            model=model,
            voice=voice,
            input=input,
            **optional_params,
        )
        return response
    async def ahealth_check(
        self,
        model: Optional[str],


@@ -91,6 +91,12 @@ import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict, Union, Mapping
from .caching import enable_cache, disable_cache, update_cache
from .types.llms.openai import (
    StreamedBinaryAPIResponse,
    ResponseContextManager,
    AsyncResponseContextManager,
    AsyncStreamedBinaryAPIResponse,
)

encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (

@@ -4163,6 +4169,134 @@ def transcription(
    return response

def aspeech(
    *args, **kwargs
) -> AsyncResponseContextManager[AsyncStreamedBinaryAPIResponse]:
    """
    Calls openai tts endpoints.
    """
    loop = asyncio.get_event_loop()
    model = args[0] if len(args) > 0 else kwargs["model"]
    ### PASS ARGS TO SPEECH ###
    kwargs["aspeech"] = True
    custom_llm_provider = kwargs.get("custom_llm_provider", None)
    try:
        # # Use a partial function to pass your keyword arguments
        # func = partial(speech, *args, **kwargs)

        # # Add the context to the function
        # ctx = contextvars.copy_context()
        # func_with_context = partial(ctx.run, func)

        # _, custom_llm_provider, _, _ = get_llm_provider(
        #     model=model, api_base=kwargs.get("api_base", None)
        # )

        # # Await normally
        # init_response = await loop.run_in_executor(None, func_with_context)
        # if asyncio.iscoroutine(init_response):
        #     response = await init_response
        # else:
        #     # Call the synchronous function using run_in_executor
        #     response = await loop.run_in_executor(None, func_with_context)
        return speech(*args, **kwargs)  # type: ignore
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=args,
            extra_kwargs=kwargs,
        )

def speech(
    model: str,
    input: str,
    voice: str,
    optional_params: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    organization: Optional[str],
    project: Optional[str],
    max_retries: int,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    response_format: Optional[str] = None,
    speed: Optional[int] = None,
    client=None,
    headers: Optional[dict] = None,
    custom_llm_provider: Optional[str] = None,
    aspeech: Optional[bool] = None,
) -> ResponseContextManager[StreamedBinaryAPIResponse]:
    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore

    optional_params = {}
    if response_format is not None:
        optional_params["response_format"] = response_format
    if speed is not None:
        optional_params["speed"] = speed

    if timeout is None:
        timeout = litellm.request_timeout

    response: Optional[ResponseContextManager[StreamedBinaryAPIResponse]] = None
    if custom_llm_provider == "openai":
        api_base = (
            api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or get_secret("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )  # type: ignore
        # set API KEY
        api_key = (
            api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or get_secret("OPENAI_API_KEY")
        )  # type: ignore
        organization = (
            organization
            or litellm.organization
            or get_secret("OPENAI_ORGANIZATION")
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )  # type: ignore
        project = (
            project
            or litellm.project
            or get_secret("OPENAI_PROJECT")
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )  # type: ignore

        headers = headers or litellm.headers

        response = openai_chat_completions.audio_speech(
            model=model,
            input=input,
            voice=voice,
            optional_params=optional_params,
            api_key=api_key,
            api_base=api_base,
            organization=organization,
            project=project,
            max_retries=max_retries,
            timeout=timeout,
            client=client,  # pass AsyncOpenAI, OpenAI client
            aspeech=aspeech,
        )

    if response is None:
        raise Exception(
            "Unable to map the custom llm provider={} to a known provider={}.".format(
                custom_llm_provider, litellm.provider_list
            )
        )
    return response

##### Health Endpoints #######################
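
For the async path, aspeech() sets aspeech=True and hands back the SDK's async streaming
context manager. A minimal sketch, mirroring the async branch of the tests added below
(assumes OPENAI_API_KEY is set in the environment; the save_speech() wrapper is only
for illustration):

    import asyncio
    from pathlib import Path
    import litellm

    async def save_speech():
        speech_file_path = Path(__file__).parent / "speech.mp3"
        async with litellm.aspeech(
            model="openai/tts-1",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        ) as response:
            # AsyncStreamedBinaryAPIResponse.stream_to_file is awaitable
            await response.stream_to_file(speech_file_path)

    asyncio.run(save_speech())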


@@ -0,0 +1,91 @@
# What is this?
## unit tests for openai tts endpoint
import sys, os, asyncio, time, random, uuid
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm, openai
from pathlib import Path


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_audio_speech_openai(sync_mode):
    speech_file_path = Path(__file__).parent / "speech.mp3"
    openai_chat_completions = litellm.OpenAIChatCompletion()
    if sync_mode:
        with openai_chat_completions.audio_speech(
            model="tts-1",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        ) as response:
            response.stream_to_file(speech_file_path)
    else:
        async with openai_chat_completions.async_audio_speech(
            model="tts-1",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        ) as response:
            speech = await response.parse()


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_audio_speech_litellm(sync_mode):
    speech_file_path = Path(__file__).parent / "speech.mp3"
    if sync_mode:
        with litellm.speech(
            model="openai/tts-1",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        ) as response:
            response.stream_to_file(speech_file_path)
    else:
        async with litellm.aspeech(
            model="openai/tts-1",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        ) as response:
            await response.stream_to_file(speech_file_path)


@@ -8,7 +8,6 @@ from typing import (
)
from typing_extensions import override, Required, Dict
from pydantic import BaseModel
from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.thread_create_params import (

@@ -21,7 +20,12 @@ from openai.pagination import SyncCursorPage
from os import PathLike
from openai.types import FileObject, Batch
from openai._legacy_response import HttpxBinaryResponseContent
from openai._response import (
    StreamedBinaryAPIResponse,
    ResponseContextManager,
    AsyncStreamedBinaryAPIResponse,
    AsyncResponseContextManager,
)
from typing import TypedDict, List, Optional, Tuple, Mapping, IO

FileContent = Union[IO[bytes], bytes, PathLike]