Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)

fix(openai.py): fix openai response for /audio/speech endpoint

commit eb159b64e1 (parent 1e89a1f56e)
7 changed files with 311 additions and 127 deletions
@@ -91,12 +91,7 @@ import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict, Union, Mapping
 from .caching import enable_cache, disable_cache, update_cache
-from .types.llms.openai import (
-    StreamedBinaryAPIResponse,
-    ResponseContextManager,
-    AsyncResponseContextManager,
-    AsyncStreamedBinaryAPIResponse,
-)
+from .types.llms.openai import HttpxBinaryResponseContent

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
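This import swap is the heart of the fix: the streaming context-manager types are dropped in favor of HttpxBinaryResponseContent, the openai client's buffered binary-response wrapper. A minimal consumer sketch of what that buys callers; the stream_to_file helper comes from the openai wrapper, and the exact method set is an assumption rather than something this diff shows:

    # hypothetical consumer sketch, not part of the diff
    import litellm

    response = litellm.speech(
        model="openai/tts-1",
        voice="alloy",
        input="Hello from litellm",
    )
    # HttpxBinaryResponseContent holds the audio bytes; stream_to_file
    # (from the openai wrapper) writes them out in chunks
    response.stream_to_file("speech.mp3")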
@@ -4169,9 +4164,7 @@ def transcription(
     return response


-def aspeech(
-    *args, **kwargs
-) -> AsyncResponseContextManager[AsyncStreamedBinaryAPIResponse]:
+async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:
     """
     Calls openai tts endpoints.
     """
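Because aspeech is now a plain coroutine rather than a function returning an async context manager, callers can await it directly. A hypothetical caller sketch, assuming a .content bytes attribute on the openai binary-response wrapper:

    import asyncio
    import litellm

    async def main():
        audio = await litellm.aspeech(
            model="openai/tts-1",
            voice="alloy",
            input="Hello from litellm",
        )
        with open("speech.mp3", "wb") as f:
            f.write(audio.content)  # .content assumed from the openai wrapper

    asyncio.run(main())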
@@ -4181,25 +4174,25 @@ def aspeech(
     kwargs["aspeech"] = True
     custom_llm_provider = kwargs.get("custom_llm_provider", None)
     try:
-        # # Use a partial function to pass your keyword arguments
-        # func = partial(speech, *args, **kwargs)
+        # Use a partial function to pass your keyword arguments
+        func = partial(speech, *args, **kwargs)

-        # # Add the context to the function
-        # ctx = contextvars.copy_context()
-        # func_with_context = partial(ctx.run, func)
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)

-        # _, custom_llm_provider, _, _ = get_llm_provider(
-        #     model=model, api_base=kwargs.get("api_base", None)
-        # )
+        _, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, api_base=kwargs.get("api_base", None)
+        )

-        # # Await normally
-        # init_response = await loop.run_in_executor(None, func_with_context)
-        # if asyncio.iscoroutine(init_response):
-        #     response = await init_response
-        # else:
-        #     # Call the synchronous function using run_in_executor
-        #     response = await loop.run_in_executor(None, func_with_context)
-        return speech(*args, **kwargs)  # type: ignore
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        return response  # type: ignore
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
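The body swaps the earlier synchronous passthrough (return speech(*args, **kwargs)) for the standard "run a blocking function on the event loop's executor" pattern, with contextvars.copy_context() so context-local state set by the caller survives the thread hop. A self-contained, runnable sketch of that pattern, independent of litellm (blocking_call stands in for the sync speech()):

    import asyncio
    import contextvars
    from functools import partial

    def blocking_call(x: int) -> int:
        # placeholder for a blocking SDK call such as the sync speech()
        return x * 2

    async def run_in_thread(x: int) -> int:
        loop = asyncio.get_running_loop()
        func = partial(blocking_call, x)
        # copy the caller's context so contextvars are visible
        # inside the executor thread
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)
        # if the wrapped callable itself returned a coroutine, await it
        if asyncio.iscoroutine(init_response):
            return await init_response
        return init_response

    print(asyncio.run(run_in_thread(21)))  # prints 42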
@@ -4215,12 +4208,12 @@ def speech(
     model: str,
     input: str,
     voice: str,
-    optional_params: dict,
-    api_key: Optional[str],
-    api_base: Optional[str],
-    organization: Optional[str],
-    project: Optional[str],
-    max_retries: int,
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+    organization: Optional[str] = None,
+    project: Optional[str] = None,
+    max_retries: Optional[int] = None,
     metadata: Optional[dict] = None,
     timeout: Optional[Union[float, httpx.Timeout]] = None,
     response_format: Optional[str] = None,
     speed: Optional[int] = None,
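The signature change removes the required optional_params positional parameter from the public surface and turns the connection settings into optional keywords defaulting to None, so a minimal call stays minimal while explicit overrides remain possible. A hypothetical call sketch (parameter values are illustrative; omitted credentials presumably fall back to litellm's usual env-based configuration):

    import litellm

    audio = litellm.speech(
        model="openai/tts-1",
        voice="alloy",
        input="Hello",
        timeout=600,      # optional override
        max_retries=2,    # optional override
    )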
@@ -4228,7 +4221,8 @@ def speech(
     headers: Optional[dict] = None,
     custom_llm_provider: Optional[str] = None,
     aspeech: Optional[bool] = None,
-) -> ResponseContextManager[StreamedBinaryAPIResponse]:
+    **kwargs,
+) -> HttpxBinaryResponseContent:

     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore

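The unchanged get_llm_provider call resolves the provider from the model string (e.g. an "openai/" prefix) and may rewrite the model name, API key, and base URL. An illustrative expectation based only on the call shape above; the import path and return values are assumptions:

    from litellm.utils import get_llm_provider  # assumed import path

    model, provider, dynamic_api_key, api_base = get_llm_provider(
        model="openai/tts-1"
    )
    # expected: model == "tts-1", provider == "openai"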
@@ -4236,12 +4230,14 @@ def speech(
     if response_format is not None:
         optional_params["response_format"] = response_format
     if speed is not None:
-        optional_params["speed"] = speed
+        optional_params["speed"] = speed  # type: ignore

     if timeout is None:
         timeout = litellm.request_timeout

-    response: Optional[ResponseContextManager[StreamedBinaryAPIResponse]] = None
+    if max_retries is None:
+        max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
+    response: Optional[HttpxBinaryResponseContent] = None
     if custom_llm_provider == "openai":
         api_base = (
             api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
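The last hunk fills in max_retries when the caller leaves it unset, preferring the library-level litellm.num_retries and falling back to the openai client default. A sketch of that precedence, extracted into a standalone helper (the helper name is hypothetical; openai.DEFAULT_MAX_RETRIES is the real constant from the openai package):

    import openai

    def resolve_max_retries(arg, num_retries):
        # caller-supplied value wins, then the library-level setting,
        # then the openai client default
        if arg is not None:
            return arg
        return num_retries or openai.DEFAULT_MAX_RETRIES

    assert resolve_max_retries(5, None) == 5
    assert resolve_max_retries(None, 3) == 3
    assert resolve_max_retries(None, None) == openai.DEFAULT_MAX_RETRIES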