mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00

feat(main.py): support openai tts endpoint

Closes https://github.com/BerriAI/litellm/issues/3094

parent 3167bee25a
commit a67cbf47f6

5 changed files with 322 additions and 3 deletions
@@ -227,7 +227,7 @@ default_team_settings: Optional[List] = None
 max_user_budget: Optional[float] = None
 max_end_user_budget: Optional[float] = None
 #### RELIABILITY ####
-request_timeout: Optional[float] = 6000
+request_timeout: float = 6000
 num_retries: Optional[int] = None  # per model endpoint
 default_fallbacks: Optional[List] = None
 fallbacks: Optional[List] = None

@@ -304,6 +304,7 @@ api_base = None
 headers = None
 api_version = None
 organization = None
+project = None
 config_path = None
 ####### COMPLETION MODELS ###################
 open_ai_chat_completion_models: List = []
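Both settings touched here feed the new speech() path added later in this commit: request_timeout is now typed as a plain float (speech() falls back to it when no timeout is passed), and project becomes a module-level default for the OpenAI project header. A minimal usage sketch; the project id below is hypothetical:

    import litellm

    litellm.project = "proj_example"   # hypothetical id; used as the OpenAI project fallback in speech()
    print(litellm.request_timeout)     # now a plain float, 6000 by default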
@@ -26,6 +26,7 @@ import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from openai import OpenAI, AsyncOpenAI
 from ..types.llms.openai import *
+import openai


 class OpenAIError(Exception):

@@ -1180,6 +1181,94 @@ class OpenAIChatCompletion(BaseLLM):
            )
            raise e

+    def audio_speech(
+        self,
+        model: str,
+        input: str,
+        voice: str,
+        optional_params: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        organization: Optional[str],
+        project: Optional[str],
+        max_retries: int,
+        timeout: Union[float, httpx.Timeout],
+        aspeech: Optional[bool] = None,
+        client=None,
+    ) -> ResponseContextManager[StreamedBinaryAPIResponse]:
+
+        if aspeech is not None and aspeech == True:
+            return self.async_audio_speech(
+                model=model,
+                input=input,
+                voice=voice,
+                optional_params=optional_params,
+                api_key=api_key,
+                api_base=api_base,
+                organization=organization,
+                project=project,
+                max_retries=max_retries,
+                timeout=timeout,
+                client=client,
+            )  # type: ignore
+
+        if client is None:
+            openai_client = OpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                organization=organization,
+                project=project,
+                http_client=litellm.client_session,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+        else:
+            openai_client = client
+
+        response = openai_client.audio.speech.with_streaming_response.create(
+            model="tts-1",
+            voice="alloy",
+            input="the quick brown fox jumped over the lazy dogs",
+            **optional_params,
+        )
+        return response
+
+    def async_audio_speech(
+        self,
+        model: str,
+        input: str,
+        voice: str,
+        optional_params: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        organization: Optional[str],
+        project: Optional[str],
+        max_retries: int,
+        timeout: Union[float, httpx.Timeout],
+        client=None,
+    ) -> AsyncResponseContextManager[AsyncStreamedBinaryAPIResponse]:
+
+        if client is None:
+            openai_client = AsyncOpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                organization=organization,
+                project=project,
+                http_client=litellm.aclient_session,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+        else:
+            openai_client = client
+
+        response = openai_client.audio.speech.with_streaming_response.create(
+            model="tts-1",
+            voice="alloy",
+            input="the quick brown fox jumped over the lazy dogs",
+            **optional_params,
+        )
+        return response
+
     async def ahealth_check(
         self,
         model: Optional[str],
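Both new methods wrap the OpenAI SDK's streaming speech endpoint; note that in this revision the create() calls pass literal model/voice/input values rather than the method arguments. For reference, the underlying SDK pattern they delegate to looks roughly like this (a sketch; assumes openai>=1.x and an OPENAI_API_KEY in the environment):

    from pathlib import Path
    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    with client.audio.speech.with_streaming_response.create(
        model="tts-1",
        voice="alloy",
        input="hello world",
    ) as response:
        # stream the audio to disk without buffering the whole file in memory
        response.stream_to_file(Path("speech.mp3"))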
litellm/main.py (134 lines changed)
@@ -91,6 +91,12 @@ import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict, Union, Mapping
 from .caching import enable_cache, disable_cache, update_cache
+from .types.llms.openai import (
+    StreamedBinaryAPIResponse,
+    ResponseContextManager,
+    AsyncResponseContextManager,
+    AsyncStreamedBinaryAPIResponse,
+)

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (

@@ -4163,6 +4169,134 @@ def transcription(
     return response


+def aspeech(
+    *args, **kwargs
+) -> AsyncResponseContextManager[AsyncStreamedBinaryAPIResponse]:
+    """
+    Calls openai tts endpoints.
+    """
+    loop = asyncio.get_event_loop()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    ### PASS ARGS TO Image Generation ###
+    kwargs["aspeech"] = True
+    custom_llm_provider = kwargs.get("custom_llm_provider", None)
+    try:
+        # # Use a partial function to pass your keyword arguments
+        # func = partial(speech, *args, **kwargs)
+
+        # # Add the context to the function
+        # ctx = contextvars.copy_context()
+        # func_with_context = partial(ctx.run, func)
+
+        # _, custom_llm_provider, _, _ = get_llm_provider(
+        #     model=model, api_base=kwargs.get("api_base", None)
+        # )
+
+        # # Await normally
+        # init_response = await loop.run_in_executor(None, func_with_context)
+        # if asyncio.iscoroutine(init_response):
+        #     response = await init_response
+        # else:
+        #     # Call the synchronous function using run_in_executor
+        #     response = await loop.run_in_executor(None, func_with_context)
+        return speech(*args, **kwargs)  # type: ignore
+    except Exception as e:
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs=args,
+            extra_kwargs=kwargs,
+        )
+
+
+def speech(
+    model: str,
+    input: str,
+    voice: str,
+    optional_params: dict,
+    api_key: Optional[str],
+    api_base: Optional[str],
+    organization: Optional[str],
+    project: Optional[str],
+    max_retries: int,
+    timeout: Optional[Union[float, httpx.Timeout]] = None,
+    response_format: Optional[str] = None,
+    speed: Optional[int] = None,
+    client=None,
+    headers: Optional[dict] = None,
+    custom_llm_provider: Optional[str] = None,
+    aspeech: Optional[bool] = None,
+) -> ResponseContextManager[StreamedBinaryAPIResponse]:
+
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+
+    optional_params = {}
+    if response_format is not None:
+        optional_params["response_format"] = response_format
+    if speed is not None:
+        optional_params["speed"] = speed
+
+    if timeout is None:
+        timeout = litellm.request_timeout
+
+    response: Optional[ResponseContextManager[StreamedBinaryAPIResponse]] = None
+    if custom_llm_provider == "openai":
+        api_base = (
+            api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or get_secret("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )  # type: ignore
+        # set API KEY
+        api_key = (
+            api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or get_secret("OPENAI_API_KEY")
+        )  # type: ignore
+
+        organization = (
+            organization
+            or litellm.organization
+            or get_secret("OPENAI_ORGANIZATION")
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )  # type: ignore
+
+        project = (
+            project
+            or litellm.project
+            or get_secret("OPENAI_PROJECT")
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )  # type: ignore
+
+        headers = headers or litellm.headers
+
+        response = openai_chat_completions.audio_speech(
+            model=model,
+            input=input,
+            voice=voice,
+            optional_params=optional_params,
+            api_key=api_key,
+            api_base=api_base,
+            organization=organization,
+            project=project,
+            max_retries=max_retries,
+            timeout=timeout,
+            client=client,  # pass AsyncOpenAI, OpenAI client
+            aspeech=aspeech,
+        )
+
+    if response is None:
+        raise Exception(
+            "Unable to map the custom llm provider={} to a known provider={}.".format(
+                custom_llm_provider, litellm.provider_list
+            )
+        )
+    return response
+
+
 ##### Health Endpoints #######################

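Routing in speech() hinges on the provider prefix handled by get_llm_provider; a minimal sketch of that step, assuming get_llm_provider is imported from litellm.utils as it is in main.py:

    from litellm.utils import get_llm_provider

    # "openai/tts-1" splits into the bare model name and the provider key
    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
        model="openai/tts-1"
    )
    assert model == "tts-1"
    assert custom_llm_provider == "openai"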
litellm/tests/test_audio_speech.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+# What is this?
+## unit tests for openai tts endpoint
+
+import sys, os, asyncio, time, random, uuid
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm, openai
+from pathlib import Path
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_audio_speech_openai(sync_mode):
+
+    speech_file_path = Path(__file__).parent / "speech.mp3"
+    openai_chat_completions = litellm.OpenAIChatCompletion()
+    if sync_mode:
+        with openai_chat_completions.audio_speech(
+            model="tts-1",
+            voice="alloy",
+            input="the quick brown fox jumped over the lazy dogs",
+            api_base=None,
+            api_key=None,
+            organization=None,
+            project=None,
+            max_retries=1,
+            timeout=600,
+            client=None,
+            optional_params={},
+        ) as response:
+            response.stream_to_file(speech_file_path)
+    else:
+        async with openai_chat_completions.async_audio_speech(
+            model="tts-1",
+            voice="alloy",
+            input="the quick brown fox jumped over the lazy dogs",
+            api_base=None,
+            api_key=None,
+            organization=None,
+            project=None,
+            max_retries=1,
+            timeout=600,
+            client=None,
+            optional_params={},
+        ) as response:
+            speech = await response.parse()
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_audio_speech_litellm(sync_mode):
+    speech_file_path = Path(__file__).parent / "speech.mp3"
+
+    if sync_mode:
+        with litellm.speech(
+            model="openai/tts-1",
+            voice="alloy",
+            input="the quick brown fox jumped over the lazy dogs",
+            api_base=None,
+            api_key=None,
+            organization=None,
+            project=None,
+            max_retries=1,
+            timeout=600,
+            client=None,
+            optional_params={},
+        ) as response:
+            response.stream_to_file(speech_file_path)
+    else:
+        async with litellm.aspeech(
+            model="openai/tts-1",
+            voice="alloy",
+            input="the quick brown fox jumped over the lazy dogs",
+            api_base=None,
+            api_key=None,
+            organization=None,
+            project=None,
+            max_retries=1,
+            timeout=600,
+            client=None,
+            optional_params={},
+        ) as response:
+            await response.stream_to_file(speech_file_path)
@@ -8,7 +8,6 @@ from typing import (
 )
 from typing_extensions import override, Required, Dict
 from pydantic import BaseModel
-
 from openai.types.beta.threads.message_content import MessageContent
 from openai.types.beta.threads.message import Message as OpenAIMessage
 from openai.types.beta.thread_create_params import (

@@ -21,7 +20,12 @@ from openai.pagination import SyncCursorPage
 from os import PathLike
 from openai.types import FileObject, Batch
 from openai._legacy_response import HttpxBinaryResponseContent
+from openai._response import (
+    StreamedBinaryAPIResponse,
+    ResponseContextManager,
+    AsyncStreamedBinaryAPIResponse,
+    AsyncResponseContextManager,
+)
 from typing import TypedDict, List, Optional, Tuple, Mapping, IO

 FileContent = Union[IO[bytes], bytes, PathLike]
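These imports re-export the OpenAI SDK's streaming-response types so litellm code can annotate against them from its own types module. A minimal sketch of how the annotations are meant to be read, assuming this commit is installed and the openai>=1.x module layout imported above:

    from litellm.types.llms.openai import (
        ResponseContextManager,
        StreamedBinaryAPIResponse,
    )

    def save_speech(
        ctx: ResponseContextManager[StreamedBinaryAPIResponse], path: str
    ) -> None:
        # the context manager yields a streamed binary response that can be
        # written to disk chunk by chunk
        with ctx as response:
            response.stream_to_file(path)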