Merge pull request #3928 from BerriAI/litellm_audio_speech_endpoint

feat(main.py): support openai tts endpoint
Krish Dholakia 2024-05-30 17:30:42 -07:00 committed by GitHub
commit 73e3dba2f6
11 changed files with 574 additions and 7 deletions
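In short, this PR adds `litellm.speech()` / `litellm.aspeech()` and a `/v1/audio/speech` proxy route that wrap OpenAI's text-to-speech API. A minimal sketch of the new library entrypoint (assumes `OPENAI_API_KEY` is set in the environment; the output filename is illustrative):

```python
import litellm

# litellm.speech() wraps OpenAI's /v1/audio/speech endpoint and returns
# an openai HttpxBinaryResponseContent object.
response = litellm.speech(
    model="openai/tts-1",
    voice="alloy",
    input="the quick brown fox jumped over the lazy dogs",
)
response.stream_to_file("speech.mp3")  # write the returned audio bytes to disk
```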


@@ -227,7 +227,7 @@ default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
max_end_user_budget: Optional[float] = None
#### RELIABILITY ####
-request_timeout: Optional[float] = 6000
+request_timeout: float = 6000
num_retries: Optional[int] = None  # per model endpoint
default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None
@@ -304,6 +304,7 @@ api_base = None
headers = None
api_version = None
organization = None
+project = None
config_path = None
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []


@@ -26,6 +26,7 @@ import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import *
+import openai

class OpenAIError(Exception):
@@ -1180,6 +1181,95 @@ class OpenAIChatCompletion(BaseLLM):
            )
            raise e

    def audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        aspeech: Optional[bool] = None,
        client=None,
    ) -> HttpxBinaryResponseContent:
        if aspeech is not None and aspeech == True:
            return self.async_audio_speech(
                model=model,
                input=input,
                voice=voice,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                organization=organization,
                project=project,
                max_retries=max_retries,
                timeout=timeout,
                client=client,
            )  # type: ignore

        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                organization=organization,
                project=project,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            openai_client = client

        response = openai_client.audio.speech.create(
            model=model,
            voice=voice,  # type: ignore
            input=input,
            **optional_params,
        )
        return response

    async def async_audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        client=None,
    ) -> HttpxBinaryResponseContent:
        if client is None:
            openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url=api_base,
                organization=organization,
                project=project,
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            openai_client = client

        response = await openai_client.audio.speech.create(
            model=model,
            voice=voice,  # type: ignore
            input=input,
            **optional_params,
        )

        return response

    async def ahealth_check(
        self,
        model: Optional[str],


@@ -91,6 +91,7 @@ import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict, Union, Mapping
from .caching import enable_cache, disable_cache, update_cache
+from .types.llms.openai import HttpxBinaryResponseContent

encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@@ -4163,6 +4164,137 @@ def transcription(
    return response


@client
async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:
    """
    Calls openai tts endpoints.
    """
    loop = asyncio.get_event_loop()
    model = args[0] if len(args) > 0 else kwargs["model"]
    ### PASS ARGS TO Image Generation ###
    kwargs["aspeech"] = True
    custom_llm_provider = kwargs.get("custom_llm_provider", None)
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(speech, *args, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        _, custom_llm_provider, _, _ = get_llm_provider(
            model=model, api_base=kwargs.get("api_base", None)
        )

        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
        return response  # type: ignore
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=args,
            extra_kwargs=kwargs,
        )


@client
def speech(
    model: str,
    input: str,
    voice: str,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    organization: Optional[str] = None,
    project: Optional[str] = None,
    max_retries: Optional[int] = None,
    metadata: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    response_format: Optional[str] = None,
    speed: Optional[int] = None,
    client=None,
    headers: Optional[dict] = None,
    custom_llm_provider: Optional[str] = None,
    aspeech: Optional[bool] = None,
    **kwargs,
) -> HttpxBinaryResponseContent:

    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore

    optional_params = {}
    if response_format is not None:
        optional_params["response_format"] = response_format
    if speed is not None:
        optional_params["speed"] = speed  # type: ignore

    if timeout is None:
        timeout = litellm.request_timeout

    if max_retries is None:
        max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
    response: Optional[HttpxBinaryResponseContent] = None
    if custom_llm_provider == "openai":
        api_base = (
            api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or get_secret("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )  # type: ignore
        # set API KEY
        api_key = (
            api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or get_secret("OPENAI_API_KEY")
        )  # type: ignore

        organization = (
            organization
            or litellm.organization
            or get_secret("OPENAI_ORGANIZATION")
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )  # type: ignore

        project = (
            project
            or litellm.project
            or get_secret("OPENAI_PROJECT")
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )  # type: ignore

        headers = headers or litellm.headers

        response = openai_chat_completions.audio_speech(
            model=model,
            input=input,
            voice=voice,
            optional_params=optional_params,
            api_key=api_key,
            api_base=api_base,
            organization=organization,
            project=project,
            max_retries=max_retries,
            timeout=timeout,
            client=client,  # pass AsyncOpenAI, OpenAI client
            aspeech=aspeech,
        )

    if response is None:
        raise Exception(
            "Unable to map the custom llm provider={} to a known provider={}.".format(
                custom_llm_provider, litellm.provider_list
            )
        )
    return response


##### Health Endpoints #######################
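As a usage sketch of the mapping above: `response_format` and `speed` are the only call-specific options forwarded to OpenAI via `optional_params`, while the remaining arguments configure the client. An async variant might look like this (assumes `OPENAI_API_KEY` is set; the filename is illustrative):

```python
import asyncio

import litellm


async def main():
    # aspeech() runs speech() in an executor and awaits the coroutine
    # produced by the async OpenAI client path.
    response = await litellm.aspeech(
        model="openai/tts-1",
        voice="alloy",
        input="the quick brown fox jumped over the lazy dogs",
        response_format="mp3",  # optional, forwarded via optional_params
        speed=1,                # optional, forwarded via optional_params
    )
    response.stream_to_file("speech_async.mp3")


asyncio.run(main())
```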


@@ -26,10 +26,8 @@ model_list:
    api_version: '2023-05-15'
    model: azure/chatgpt-v-2
  model_name: gpt-3.5-turbo
-- model_name: mistral
+- model_name: tts
  litellm_params:
-    model: azure/mistral-large-latest
-    api_base: https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1/
-    api_key: zEJhgmw1FAKk0XzPWoLEg7WU1cXbWYYn
+    model: openai/tts-1
router_settings:
  enable_pre_call_checks: true


@@ -79,6 +79,9 @@ def generate_feedback_box():
import litellm
+from litellm.types.llms.openai import (
+    HttpxBinaryResponseContent,
+)
from litellm.proxy.utils import (
    PrismaClient,
    DBClient,
@@ -4883,6 +4886,169 @@ async def image_generation(
        )


@router.post(
    "/v1/audio/speech",
    dependencies=[Depends(user_api_key_auth)],
    tags=["audio"],
)
@router.post(
    "/audio/speech",
    dependencies=[Depends(user_api_key_auth)],
    tags=["audio"],
)
async def audio_speech(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Same params as:
    https://platform.openai.com/docs/api-reference/audio/createSpeech
    """
    global proxy_logging_obj
    data: Dict = {}
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {  # type: ignore
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }

        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            data["user"] = user_api_key_dict.user_id

        if user_model:
            data["model"] = user_model

        if "metadata" not in data:
            data["metadata"] = {}
        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
        _headers = dict(request.headers)
        _headers.pop(
            "authorization", None
        )  # do not store the original `sk-..` api key in the db
        data["metadata"]["headers"] = _headers
        data["metadata"]["user_api_key_alias"] = getattr(
            user_api_key_dict, "key_alias", None
        )
        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
        data["metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
            "global_max_parallel_requests", None
        )
        data["metadata"]["user_api_key_team_alias"] = getattr(
            user_api_key_dict, "team_alias", None
        )
        data["metadata"]["endpoint"] = str(request.url)

        ### TEAM-SPECIFIC PARAMS ###
        if user_api_key_dict.team_id is not None:
            team_config = await proxy_config.load_team_config(
                team_id=user_api_key_dict.team_id
            )
            if len(team_config) == 0:
                pass
            else:
                team_id = team_config.pop("team_id", None)
                data["metadata"]["team_id"] = team_id
                data = {
                    **team_config,
                    **data,
                }  # add the team-specific configs to the completion call

        router_model_names = llm_router.model_names if llm_router is not None else []

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
        )

        ## ROUTE TO CORRECT ENDPOINT ##
        # skip router if user passed their key
        if "api_key" in data:
            response = await litellm.aspeech(**data)
        elif (
            llm_router is not None and data["model"] in router_model_names
        ):  # model in router model list
            response = await llm_router.aspeech(**data)
        elif (
            llm_router is not None and data["model"] in llm_router.deployment_names
        ):  # model in router deployments, calling a specific deployment on the router
            response = await llm_router.aspeech(**data, specific_deployment=True)
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):  # model set in model_group_alias
            response = await llm_router.aspeech(
                **data
            )  # ensure this goes the llm_router, router will do the correct alias mapping
        elif (
            llm_router is not None
            and data["model"] not in router_model_names
            and llm_router.default_deployment is not None
        ):  # model in router deployments, calling a specific deployment on the router
            response = await llm_router.aspeech(**data)
        elif user_model is not None:  # `litellm --model <your-model-name>`
            response = await litellm.aspeech(**data)
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "audio_speech: Invalid model name passed in model="
                    + data.get("model", "")
                },
            )

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        # Printing each chunk size
        async def generate(_response: HttpxBinaryResponseContent):
            _generator = await _response.aiter_bytes(chunk_size=1024)
            async for chunk in _generator:
                yield chunk

        custom_headers = get_custom_headers(
            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,
            version=version,
            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            fastest_response_batch_completion=None,
        )

        selected_data_generator = select_data_generator(
            response=response,
            user_api_key_dict=user_api_key_dict,
            request_data=data,
        )

        return StreamingResponse(
            generate(response), media_type="audio/mpeg", headers=custom_headers
        )

    except Exception as e:
        traceback.print_exc()
        raise e


@router.post(
    "/v1/audio/transcriptions",
    dependencies=[Depends(user_api_key_auth)],


@@ -1204,6 +1204,84 @@ class Router:
                self.fail_calls[model_name] += 1
            raise e

    async def aspeech(self, model: str, input: str, voice: str, **kwargs):
        """
        Example Usage:

        ```
        from litellm import Router
        client = Router(model_list = [
            {
                "model_name": "tts",
                "litellm_params": {
                    "model": "tts-1",
                },
            },
        ])

        async with client.aspeech(
            model="tts",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        ) as response:
            response.stream_to_file(speech_file_path)
        ```
        """
        try:
            kwargs["input"] = input
            kwargs["voice"] = voice

            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "prompt"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            response = await litellm.aspeech(**data, **kwargs)

            return response
        except Exception as e:
            raise e

    async def amoderation(self, model: str, input: str, **kwargs):
        try:
            kwargs["model"] = model


@@ -0,0 +1,96 @@
# What is this?
## unit tests for openai tts endpoint

import sys, os, asyncio, time, random, uuid
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest

import litellm, openai
from pathlib import Path


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_audio_speech_litellm(sync_mode):
    speech_file_path = Path(__file__).parent / "speech.mp3"

    if sync_mode:
        response = litellm.speech(
            model="openai/tts-1",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        )

        from litellm.llms.openai import HttpxBinaryResponseContent

        assert isinstance(response, HttpxBinaryResponseContent)
    else:
        response = await litellm.aspeech(
            model="openai/tts-1",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        )

        from litellm.llms.openai import HttpxBinaryResponseContent

        assert isinstance(response, HttpxBinaryResponseContent)


@pytest.mark.parametrize("mode", ["iterator"])  # "file",
@pytest.mark.asyncio
async def test_audio_speech_router(mode):
    speech_file_path = Path(__file__).parent / "speech.mp3"

    from litellm import Router

    client = Router(
        model_list=[
            {
                "model_name": "tts",
                "litellm_params": {
                    "model": "openai/tts-1",
                },
            },
        ]
    )

    response = await client.aspeech(
        model="tts",
        voice="alloy",
        input="the quick brown fox jumped over the lazy dogs",
        api_base=None,
        api_key=None,
        organization=None,
        project=None,
        max_retries=1,
        timeout=600,
        client=None,
        optional_params={},
    )

    from litellm.llms.openai import HttpxBinaryResponseContent

    assert isinstance(response, HttpxBinaryResponseContent)


@@ -8,7 +8,6 @@ from typing import (
)
from typing_extensions import override, Required, Dict
from pydantic import BaseModel
from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.thread_create_params import (
@@ -21,7 +20,6 @@ from openai.pagination import SyncCursorPage
from os import PathLike
from openai.types import FileObject, Batch
from openai._legacy_response import HttpxBinaryResponseContent
from typing import TypedDict, List, Optional, Tuple, Mapping, IO

FileContent = Union[IO[bytes], bytes, PathLike]


@@ -1136,6 +1136,8 @@ class CallTypes(Enum):
    amoderation = "amoderation"
    atranscription = "atranscription"
    transcription = "transcription"
+    aspeech = "aspeech"
+    speech = "speech"

# Logging function -> log the exact model details + what's being sent | Non-Blocking
@@ -3005,6 +3007,10 @@ def function_setup(
        ):
            _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
            messages = "audio_file"
+        elif (
+            call_type == CallTypes.aspeech.value or call_type == CallTypes.speech.value
+        ):
+            messages = kwargs.get("input", "speech")
        stream = True if "stream" in kwargs and kwargs["stream"] == True else False
        logging_obj = Logging(
            model=model,
@@ -3346,6 +3352,8 @@ def client(original_function):
                return result
            elif "atranscription" in kwargs and kwargs["atranscription"] == True:
                return result
+            elif "aspeech" in kwargs and kwargs["aspeech"] == True:
+                return result

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model or None)