fix(openai.py): fix openai response for /audio/speech endpoint

Krrish Dholakia 2024-05-30 16:41:06 -07:00
parent 1e89a1f56e
commit eb159b64e1
7 changed files with 311 additions and 127 deletions

@@ -91,12 +91,7 @@ import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict, Union, Mapping
 from .caching import enable_cache, disable_cache, update_cache
-from .types.llms.openai import (
-    StreamedBinaryAPIResponse,
-    ResponseContextManager,
-    AsyncResponseContextManager,
-    AsyncStreamedBinaryAPIResponse,
-)
+from .types.llms.openai import HttpxBinaryResponseContent

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
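
Net effect of the import swap: the speech endpoints now type against `HttpxBinaryResponseContent`, the openai SDK's fully buffered binary response re-exported through `litellm.types.llms.openai`, instead of the streaming context-manager wrappers. A minimal caller-side sketch of what that means for consumers, assuming litellm is importable, an `OPENAI_API_KEY` in the environment, and illustrative model/voice/path values (`stream_to_file` and `.content` are the openai SDK helpers on that class):

```python
import litellm

# No `with ... as response:` block anymore: the audio comes back as a
# plain HttpxBinaryResponseContent object with the bytes already buffered.
response = litellm.speech(
    model="tts-1",   # illustrative model name
    voice="alloy",   # illustrative voice
    input="the quick brown fox jumped over the lazy dogs",
)
response.stream_to_file("speech.mp3")  # SDK helper: dump the audio to disk
# or work with the raw bytes directly:
audio_bytes = response.content
```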
@@ -4169,9 +4164,7 @@ def transcription(
     return response


-def aspeech(
-    *args, **kwargs
-) -> AsyncResponseContextManager[AsyncStreamedBinaryAPIResponse]:
+async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:
     """
     Calls openai tts endpoints.
     """
@@ -4181,25 +4174,25 @@ def aspeech(
     kwargs["aspeech"] = True
     custom_llm_provider = kwargs.get("custom_llm_provider", None)
     try:
-        # # Use a partial function to pass your keyword arguments
-        # func = partial(speech, *args, **kwargs)
+        # Use a partial function to pass your keyword arguments
+        func = partial(speech, *args, **kwargs)

-        # # Add the context to the function
-        # ctx = contextvars.copy_context()
-        # func_with_context = partial(ctx.run, func)
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)

-        # _, custom_llm_provider, _, _ = get_llm_provider(
-        #     model=model, api_base=kwargs.get("api_base", None)
-        # )
+        _, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, api_base=kwargs.get("api_base", None)
+        )

-        # # Await normally
-        # init_response = await loop.run_in_executor(None, func_with_context)
-        # if asyncio.iscoroutine(init_response):
-        #     response = await init_response
-        # else:
-        #     # Call the synchronous function using run_in_executor
-        #     response = await loop.run_in_executor(None, func_with_context)
-        return speech(*args, **kwargs)  # type: ignore
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        return response  # type: ignore
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
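
The uncommented body is the standard async-bridging pattern used across this module: bind the arguments with `functools.partial`, copy the current `contextvars` context so request-scoped state survives the hop to a worker thread, and hand the wrapped callable to `loop.run_in_executor`. A self-contained sketch of that pattern under hypothetical names (`blocking_call` and `async_call` stand in for `speech` and `aspeech`; for brevity the sketch returns `init_response` directly in the synchronous case rather than re-submitting to the executor as the diff's else branch does):

```python
import asyncio
import contextvars
from functools import partial

request_id = contextvars.ContextVar("request_id", default="-")

def blocking_call(text: str) -> bytes:
    # Stand-in for the synchronous speech() call; because the context was
    # copied, it still sees the caller's ContextVar values in the worker thread.
    return f"[{request_id.get()}] {text}".encode()

async def async_call(text: str) -> bytes:
    loop = asyncio.get_running_loop()
    func = partial(blocking_call, text)          # bind args up front
    ctx = contextvars.copy_context()             # capture caller's context
    func_with_context = partial(ctx.run, func)   # run func inside that context
    # Run the blocking function without blocking the event loop
    init_response = await loop.run_in_executor(None, func_with_context)
    if asyncio.iscoroutine(init_response):
        # Some providers hand back a coroutine; await it like the diff does
        return await init_response
    return init_response

async def main() -> None:
    request_id.set("req-123")
    print(await async_call("hello"))  # b'[req-123] hello'

asyncio.run(main())
```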
@@ -4215,12 +4208,12 @@ def speech(
     model: str,
     input: str,
     voice: str,
-    optional_params: dict,
-    api_key: Optional[str],
-    api_base: Optional[str],
-    organization: Optional[str],
-    project: Optional[str],
-    max_retries: int,
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+    organization: Optional[str] = None,
+    project: Optional[str] = None,
+    max_retries: Optional[int] = None,
+    metadata: Optional[dict] = None,
     timeout: Optional[Union[float, httpx.Timeout]] = None,
     response_format: Optional[str] = None,
     speed: Optional[int] = None,
@@ -4228,7 +4221,8 @@ def speech(
     headers: Optional[dict] = None,
     custom_llm_provider: Optional[str] = None,
     aspeech: Optional[bool] = None,
-) -> ResponseContextManager[StreamedBinaryAPIResponse]:
+    **kwargs,
+) -> HttpxBinaryResponseContent:

     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore

@@ -4236,12 +4230,14 @@ def speech(
     if response_format is not None:
         optional_params["response_format"] = response_format
     if speed is not None:
-        optional_params["speed"] = speed
+        optional_params["speed"] = speed  # type: ignore

     if timeout is None:
         timeout = litellm.request_timeout
-    response: Optional[ResponseContextManager[StreamedBinaryAPIResponse]] = None
+    if max_retries is None:
+        max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
+    response: Optional[HttpxBinaryResponseContent] = None

     if custom_llm_provider == "openai":
         api_base = (
             api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
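
With `aspeech` now a true coroutine returning `HttpxBinaryResponseContent`, and `max_retries` falling back through `litellm.num_retries` to the openai SDK default, the async call site collapses to a plain `await`. A hedged usage sketch, again with illustrative model/voice/path values and an `OPENAI_API_KEY` assumed in the environment:

```python
import asyncio
import litellm

async def main() -> None:
    # A plain await replaces the old `async with` streaming context manager.
    response = await litellm.aspeech(
        model="tts-1",   # illustrative model name
        voice="alloy",   # illustrative voice
        input="hello world",
    )
    response.stream_to_file("speech_async.mp3")  # write the buffered audio

asyncio.run(main())
```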