fix(openai.py): fix openai response for /audio/speech endpoint

Krrish Dholakia 2024-05-30 16:41:06 -07:00
parent 1e89a1f56e
commit eb159b64e1
7 changed files with 311 additions and 127 deletions
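
After this change, litellm.speech() and litellm.aspeech() return openai's HttpxBinaryResponseContent directly instead of a streaming-response context manager. A minimal usage sketch, based on the updated tests in this commit (assumes OPENAI_API_KEY is set in the environment; the output path is illustrative):

import litellm

# synchronous call - returns openai's HttpxBinaryResponseContent
response = litellm.speech(
    model="openai/tts-1",
    voice="alloy",
    input="the quick brown fox jumped over the lazy dogs",
)
# write the returned audio bytes to disk
response.stream_to_file("speech.mp3")

The async variant is response = await litellm.aspeech(...) with the same arguments.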

@@ -1195,7 +1195,7 @@ class OpenAIChatCompletion(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         aspeech: Optional[bool] = None,
         client=None,
-    ) -> ResponseContextManager[StreamedBinaryAPIResponse]:
+    ) -> HttpxBinaryResponseContent:
         if aspeech is not None and aspeech == True:
             return self.async_audio_speech(
@@ -1225,15 +1225,15 @@ class OpenAIChatCompletion(BaseLLM):
         else:
             openai_client = client

-        response = openai_client.audio.speech.with_streaming_response.create(
-            model="tts-1",
-            voice="alloy",
-            input="the quick brown fox jumped over the lazy dogs",
+        response = openai_client.audio.speech.create(
+            model=model,
+            voice=voice,  # type: ignore
+            input=input,
             **optional_params,
         )
         return response

-    def async_audio_speech(
+    async def async_audio_speech(
         self,
         model: str,
         input: str,
@@ -1246,7 +1246,7 @@ class OpenAIChatCompletion(BaseLLM):
         max_retries: int,
         timeout: Union[float, httpx.Timeout],
         client=None,
-    ) -> AsyncResponseContextManager[AsyncStreamedBinaryAPIResponse]:
+    ) -> HttpxBinaryResponseContent:
         if client is None:
             openai_client = AsyncOpenAI(
@@ -1261,12 +1261,13 @@ class OpenAIChatCompletion(BaseLLM):
         else:
             openai_client = client

-        response = openai_client.audio.speech.with_streaming_response.create(
-            model="tts-1",
-            voice="alloy",
-            input="the quick brown fox jumped over the lazy dogs",
+        response = await openai_client.audio.speech.create(
+            model=model,
+            voice=voice,  # type: ignore
+            input=input,
             **optional_params,
         )
         return response

     async def ahealth_check(

@@ -91,12 +91,7 @@ import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict, Union, Mapping
 from .caching import enable_cache, disable_cache, update_cache
-from .types.llms.openai import (
-    StreamedBinaryAPIResponse,
-    ResponseContextManager,
-    AsyncResponseContextManager,
-    AsyncStreamedBinaryAPIResponse,
-)
+from .types.llms.openai import HttpxBinaryResponseContent

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -4169,9 +4164,7 @@ def transcription(
     return response


-def aspeech(
-    *args, **kwargs
-) -> AsyncResponseContextManager[AsyncStreamedBinaryAPIResponse]:
+async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:
     """
     Calls openai tts endpoints.
     """
@@ -4181,25 +4174,25 @@ def aspeech(
     kwargs["aspeech"] = True
     custom_llm_provider = kwargs.get("custom_llm_provider", None)
     try:
-        # # Use a partial function to pass your keyword arguments
-        # func = partial(speech, *args, **kwargs)
-        # # Add the context to the function
-        # ctx = contextvars.copy_context()
-        # func_with_context = partial(ctx.run, func)
-        # _, custom_llm_provider, _, _ = get_llm_provider(
-        #     model=model, api_base=kwargs.get("api_base", None)
-        # )
-        # # Await normally
-        # init_response = await loop.run_in_executor(None, func_with_context)
-        # if asyncio.iscoroutine(init_response):
-        #     response = await init_response
-        # else:
-        #     # Call the synchronous function using run_in_executor
-        #     response = await loop.run_in_executor(None, func_with_context)
-        return speech(*args, **kwargs)  # type: ignore
+        # Use a partial function to pass your keyword arguments
+        func = partial(speech, *args, **kwargs)
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        _, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, api_base=kwargs.get("api_base", None)
+        )
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        return response  # type: ignore
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
@@ -4215,12 +4208,12 @@ def speech(
     model: str,
     input: str,
     voice: str,
-    optional_params: dict,
-    api_key: Optional[str],
-    api_base: Optional[str],
-    organization: Optional[str],
-    project: Optional[str],
-    max_retries: int,
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+    organization: Optional[str] = None,
+    project: Optional[str] = None,
+    max_retries: Optional[int] = None,
+    metadata: Optional[dict] = None,
     timeout: Optional[Union[float, httpx.Timeout]] = None,
     response_format: Optional[str] = None,
     speed: Optional[int] = None,
@@ -4228,7 +4221,8 @@ def speech(
     headers: Optional[dict] = None,
     custom_llm_provider: Optional[str] = None,
     aspeech: Optional[bool] = None,
-) -> ResponseContextManager[StreamedBinaryAPIResponse]:
+    **kwargs,
+) -> HttpxBinaryResponseContent:

     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
@@ -4236,12 +4230,14 @@ def speech(
     if response_format is not None:
         optional_params["response_format"] = response_format
     if speed is not None:
-        optional_params["speed"] = speed
+        optional_params["speed"] = speed  # type: ignore
     if timeout is None:
         timeout = litellm.request_timeout
-    response: Optional[ResponseContextManager[StreamedBinaryAPIResponse]] = None
+    if max_retries is None:
+        max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
+    response: Optional[HttpxBinaryResponseContent] = None
     if custom_llm_provider == "openai":
         api_base = (
             api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there

@@ -1,31 +1,3 @@
-general_settings:
-  alert_to_webhook_url:
-    budget_alerts: https://hooks.slack.com/services/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
-    daily_reports: https://hooks.slack.com/services/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
-    db_exceptions: https://hooks.slack.com/services/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
-    llm_exceptions: https://hooks.slack.com/services/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
-    llm_requests_hanging: https://hooks.slack.com/services/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
-    llm_too_slow: https://hooks.slack.com/services/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
-    outage_alerts: https://hooks.slack.com/services/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
-  alert_types:
-    - llm_exceptions
-    - llm_too_slow
-    - llm_requests_hanging
-    - budget_alerts
-    - db_exceptions
-    - daily_reports
-    - spend_reports
-    - cooldown_deployment
-    - new_model_added
-    - outage_alerts
-  alerting:
-    - slack
-  database_connection_pool_limit: 100
-  database_connection_timeout: 60
-  health_check_interval: 300
-  ui_access_mode: all
-# litellm_settings:
-#   json_logs: true
 model_list:
   - litellm_params:
       api_base: http://0.0.0.0:8080
@@ -52,10 +24,8 @@ model_list:
       api_version: '2023-05-15'
       model: azure/chatgpt-v-2
     model_name: gpt-3.5-turbo
-  - model_name: mistral
+  - model_name: tts
     litellm_params:
-      model: azure/mistral-large-latest
-      api_base: https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1/
-      api_key: zEJhgmw1FAKk0XzPWoLEg7WU1cXbWYYn
+      model: openai/tts-1
 router_settings:
   enable_pre_call_checks: true

@@ -79,6 +79,9 @@ def generate_feedback_box():
 import litellm
+from litellm.types.llms.openai import (
+    HttpxBinaryResponseContent,
+)
 from litellm.proxy.utils import (
     PrismaClient,
     DBClient,
@@ -4875,6 +4878,143 @@ async def image_generation(
     )


+@router.post(
+    "/v1/audio/speech",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["audio"],
+)
+@router.post(
+    "/audio/speech",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["audio"],
+)
+async def audio_speech(
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Same params as:
+    https://platform.openai.com/docs/api-reference/audio/createSpeech
+    """
+    global proxy_logging_obj
+    data: Dict = {}
+    try:
+        # Use orjson to parse JSON data, orjson speeds up requests significantly
+        body = await request.body()
+        data = orjson.loads(body)
+
+        # Include original request and headers in the data
+        data["proxy_server_request"] = {  # type: ignore
+            "url": str(request.url),
+            "method": request.method,
+            "headers": dict(request.headers),
+            "body": copy.copy(data),  # use copy instead of deepcopy
+        }
+
+        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+            data["user"] = user_api_key_dict.user_id
+
+        if user_model:
+            data["model"] = user_model
+
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        _headers = dict(request.headers)
+        _headers.pop(
+            "authorization", None
+        )  # do not store the original `sk-..` api key in the db
+        data["metadata"]["headers"] = _headers
+        data["metadata"]["user_api_key_alias"] = getattr(
+            user_api_key_dict, "key_alias", None
+        )
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["user_api_key_team_id"] = getattr(
+            user_api_key_dict, "team_id", None
+        )
+        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
+            "global_max_parallel_requests", None
+        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
+        data["metadata"]["endpoint"] = str(request.url)
+
+        ### TEAM-SPECIFIC PARAMS ###
+        if user_api_key_dict.team_id is not None:
+            team_config = await proxy_config.load_team_config(
+                team_id=user_api_key_dict.team_id
+            )
+            if len(team_config) == 0:
+                pass
+            else:
+                team_id = team_config.pop("team_id", None)
+                data["metadata"]["team_id"] = team_id
+                data = {
+                    **team_config,
+                    **data,
+                }  # add the team-specific configs to the completion call
+
+        router_model_names = llm_router.model_names if llm_router is not None else []
+
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        data = await proxy_logging_obj.pre_call_hook(
+            user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
+        )
+
+        ## ROUTE TO CORRECT ENDPOINT ##
+        # skip router if user passed their key
+        if "api_key" in data:
+            response = await litellm.aspeech(**data)
+        elif (
+            llm_router is not None and data["model"] in router_model_names
+        ):  # model in router model list
+            response = await llm_router.aspeech(**data)
+        elif (
+            llm_router is not None and data["model"] in llm_router.deployment_names
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.aspeech(**data, specific_deployment=True)
+        elif (
+            llm_router is not None
+            and llm_router.model_group_alias is not None
+            and data["model"] in llm_router.model_group_alias
+        ):  # model set in model_group_alias
+            response = await llm_router.aspeech(
+                **data
+            )  # ensure this goes the llm_router, router will do the correct alias mapping
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.aspeech(**data)
+        elif user_model is not None:  # `litellm --model <your-model-name>`
+            response = await litellm.aspeech(**data)
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "error": "audio_speech: Invalid model name passed in model="
+                    + data.get("model", "")
+                },
+            )
+        # stream the returned audio bytes back to the caller in 1 KB chunks
+        async def generate(_response: HttpxBinaryResponseContent):
+            _generator = await _response.aiter_bytes(chunk_size=1024)
+            async for chunk in _generator:
+                yield chunk
+
+        return StreamingResponse(generate(response), media_type="audio/mpeg")
+
+    except Exception as e:
+        traceback.print_exc()
+        raise e
+
+
 @router.post(
     "/v1/audio/transcriptions",
     dependencies=[Depends(user_api_key_auth)],
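
For reference, a minimal sketch of calling the new proxy route with the official OpenAI Python SDK; the base_url, virtual key, and output path are placeholders for your own deployment, and "tts" is the model_name mapped to openai/tts-1 in the config change above:

from openai import OpenAI

# placeholder proxy URL and virtual key - adjust to your deployment
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.audio.speech.create(
    model="tts",  # proxy model_name from the config above
    voice="alloy",
    input="the quick brown fox jumped over the lazy dogs",
)
response.stream_to_file("speech.mp3")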

@@ -1202,6 +1202,84 @@ class Router:
             self.fail_calls[model_name] += 1
             raise e
+    async def aspeech(self, model: str, input: str, voice: str, **kwargs):
+        """
+        Example Usage:
+
+        ```
+        from litellm import Router
+        client = Router(model_list = [
+            {
+                "model_name": "tts",
+                "litellm_params": {
+                    "model": "tts-1",
+                },
+            },
+        ])
+
+        response = await client.aspeech(
+            model="tts",
+            voice="alloy",
+            input="the quick brown fox jumped over the lazy dogs",
+            api_base=None,
+            api_key=None,
+            organization=None,
+            project=None,
+            max_retries=1,
+            timeout=600,
+            client=None,
+            optional_params={},
+        )
+        response.stream_to_file(speech_file_path)
+        ```
+        """
+        try:
+            kwargs["input"] = input
+            kwargs["voice"] = voice
+
+            deployment = await self.async_get_available_deployment(
+                model=model,
+                messages=[{"role": "user", "content": "prompt"}],
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+
+            response = await litellm.aspeech(**data, **kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
     async def amoderation(self, model: str, input: str, **kwargs):
         try:
             kwargs["model"] = model

@@ -16,51 +16,13 @@ import litellm, openai
 from pathlib import Path


-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.asyncio
-async def test_audio_speech_openai(sync_mode):
-    speech_file_path = Path(__file__).parent / "speech.mp3"
-    openai_chat_completions = litellm.OpenAIChatCompletion()
-    if sync_mode:
-        with openai_chat_completions.audio_speech(
-            model="tts-1",
-            voice="alloy",
-            input="the quick brown fox jumped over the lazy dogs",
-            api_base=None,
-            api_key=None,
-            organization=None,
-            project=None,
-            max_retries=1,
-            timeout=600,
-            client=None,
-            optional_params={},
-        ) as response:
-            response.stream_to_file(speech_file_path)
-    else:
-        async with openai_chat_completions.async_audio_speech(
-            model="tts-1",
-            voice="alloy",
-            input="the quick brown fox jumped over the lazy dogs",
-            api_base=None,
-            api_key=None,
-            organization=None,
-            project=None,
-            max_retries=1,
-            timeout=600,
-            client=None,
-            optional_params={},
-        ) as response:
-            speech = await response.parse()
-
-
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_audio_speech_litellm(sync_mode):
     speech_file_path = Path(__file__).parent / "speech.mp3"

     if sync_mode:
-        with litellm.speech(
+        response = litellm.speech(
             model="openai/tts-1",
             voice="alloy",
             input="the quick brown fox jumped over the lazy dogs",
@@ -72,10 +34,13 @@ async def test_audio_speech_litellm(sync_mode):
             timeout=600,
             client=None,
             optional_params={},
-        ) as response:
-            response.stream_to_file(speech_file_path)
+        )
+
+        from litellm.llms.openai import HttpxBinaryResponseContent
+
+        assert isinstance(response, HttpxBinaryResponseContent)
     else:
-        async with litellm.aspeech(
+        response = await litellm.aspeech(
             model="openai/tts-1",
             voice="alloy",
             input="the quick brown fox jumped over the lazy dogs",
@@ -87,5 +52,45 @@ async def test_audio_speech_litellm(sync_mode):
             timeout=600,
             client=None,
             optional_params={},
-        ) as response:
-            await response.stream_to_file(speech_file_path)
+        )
+
+        from litellm.llms.openai import HttpxBinaryResponseContent
+
+        assert isinstance(response, HttpxBinaryResponseContent)
+
+
+@pytest.mark.parametrize("mode", ["iterator"])  # "file",
+@pytest.mark.asyncio
+async def test_audio_speech_router(mode):
+    speech_file_path = Path(__file__).parent / "speech.mp3"
+
+    from litellm import Router
+
+    client = Router(
+        model_list=[
+            {
+                "model_name": "tts",
+                "litellm_params": {
+                    "model": "openai/tts-1",
+                },
+            },
+        ]
+    )
+
+    response = await client.aspeech(
+        model="tts",
+        voice="alloy",
+        input="the quick brown fox jumped over the lazy dogs",
+        api_base=None,
+        api_key=None,
+        organization=None,
+        project=None,
+        max_retries=1,
+        timeout=600,
+        client=None,
+        optional_params={},
+    )
+
+    from litellm.llms.openai import HttpxBinaryResponseContent
+
+    assert isinstance(response, HttpxBinaryResponseContent)

@@ -20,12 +20,6 @@ from openai.pagination import SyncCursorPage
 from os import PathLike
 from openai.types import FileObject, Batch
 from openai._legacy_response import HttpxBinaryResponseContent
-from openai._response import (
-    StreamedBinaryAPIResponse,
-    ResponseContextManager,
-    AsyncStreamedBinaryAPIResponse,
-    AsyncResponseContextManager,
-)
 from typing import TypedDict, List, Optional, Tuple, Mapping, IO

 FileContent = Union[IO[bytes], bytes, PathLike]