Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00

Merge pull request #3928 from BerriAI/litellm_audio_speech_endpoint

feat(main.py): support openai tts endpoint

Commit: 73e3dba2f6
11 changed files with 574 additions and 7 deletions
@@ -227,7 +227,7 @@ default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
max_end_user_budget: Optional[float] = None
#### RELIABILITY ####
-request_timeout: Optional[float] = 6000
+request_timeout: float = 6000
num_retries: Optional[int] = None  # per model endpoint
default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None
@@ -304,6 +304,7 @@ api_base = None
headers = None
api_version = None
organization = None
project = None
config_path = None
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []
@@ -26,6 +26,7 @@ import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import *
import openai


class OpenAIError(Exception):
@@ -1180,6 +1181,95 @@ class OpenAIChatCompletion(BaseLLM):
            )
            raise e

    def audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        aspeech: Optional[bool] = None,
        client=None,
    ) -> HttpxBinaryResponseContent:

        if aspeech is not None and aspeech == True:
            return self.async_audio_speech(
                model=model,
                input=input,
                voice=voice,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                organization=organization,
                project=project,
                max_retries=max_retries,
                timeout=timeout,
                client=client,
            )  # type: ignore

        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                organization=organization,
                project=project,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            openai_client = client

        response = openai_client.audio.speech.create(
            model=model,
            voice=voice,  # type: ignore
            input=input,
            **optional_params,
        )
        return response

    async def async_audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        client=None,
    ) -> HttpxBinaryResponseContent:

        if client is None:
            openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url=api_base,
                organization=organization,
                project=project,
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            openai_client = client

        response = await openai_client.audio.speech.create(
            model=model,
            voice=voice,  # type: ignore
            input=input,
            **optional_params,
        )

        return response

    async def ahealth_check(
        self,
        model: Optional[str],
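For context, audio_speech / async_audio_speech are thin wrappers over the OpenAI SDK's speech API shown above. A minimal sketch of the underlying SDK call they delegate to (illustrative only; assumes OPENAI_API_KEY is set in the environment and the output filename is arbitrary):

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.audio.speech.create(
    model="tts-1",
    voice="alloy",
    input="the quick brown fox jumped over the lazy dogs",
)
# response is an HttpxBinaryResponseContent; persist the audio bytes to disk
response.stream_to_file("speech.mp3")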
litellm/main.py (+132)
@@ -91,6 +91,7 @@ import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict, Union, Mapping
from .caching import enable_cache, disable_cache, update_cache
from .types.llms.openai import HttpxBinaryResponseContent

encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@@ -4163,6 +4164,137 @@ def transcription(
    return response


@client
async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:
    """
    Calls openai tts endpoints.
    """
    loop = asyncio.get_event_loop()
    model = args[0] if len(args) > 0 else kwargs["model"]
    ### PASS ARGS TO Image Generation ###
    kwargs["aspeech"] = True
    custom_llm_provider = kwargs.get("custom_llm_provider", None)
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(speech, *args, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        _, custom_llm_provider, _, _ = get_llm_provider(
            model=model, api_base=kwargs.get("api_base", None)
        )

        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
        return response  # type: ignore
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=args,
            extra_kwargs=kwargs,
        )


@client
def speech(
    model: str,
    input: str,
    voice: str,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    organization: Optional[str] = None,
    project: Optional[str] = None,
    max_retries: Optional[int] = None,
    metadata: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    response_format: Optional[str] = None,
    speed: Optional[int] = None,
    client=None,
    headers: Optional[dict] = None,
    custom_llm_provider: Optional[str] = None,
    aspeech: Optional[bool] = None,
    **kwargs,
) -> HttpxBinaryResponseContent:

    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore

    optional_params = {}
    if response_format is not None:
        optional_params["response_format"] = response_format
    if speed is not None:
        optional_params["speed"] = speed  # type: ignore

    if timeout is None:
        timeout = litellm.request_timeout

    if max_retries is None:
        max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
    response: Optional[HttpxBinaryResponseContent] = None
    if custom_llm_provider == "openai":
        api_base = (
            api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or get_secret("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )  # type: ignore
        # set API KEY
        api_key = (
            api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or get_secret("OPENAI_API_KEY")
        )  # type: ignore

        organization = (
            organization
            or litellm.organization
            or get_secret("OPENAI_ORGANIZATION")
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )  # type: ignore

        project = (
            project
            or litellm.project
            or get_secret("OPENAI_PROJECT")
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )  # type: ignore

        headers = headers or litellm.headers

        response = openai_chat_completions.audio_speech(
            model=model,
            input=input,
            voice=voice,
            optional_params=optional_params,
            api_key=api_key,
            api_base=api_base,
            organization=organization,
            project=project,
            max_retries=max_retries,
            timeout=timeout,
            client=client,  # pass AsyncOpenAI, OpenAI client
            aspeech=aspeech,
        )

    if response is None:
        raise Exception(
            "Unable to map the custom llm provider={} to a known provider={}.".format(
                custom_llm_provider, litellm.provider_list
            )
        )
    return response


##### Health Endpoints #######################
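A minimal usage sketch for the new litellm.speech / litellm.aspeech helpers, mirroring the unit tests added later in this PR (assumes OPENAI_API_KEY is set; the output path is arbitrary):

from pathlib import Path
import litellm

speech_file_path = Path("speech.mp3")

# synchronous call - returns the OpenAI SDK's HttpxBinaryResponseContent
response = litellm.speech(
    model="openai/tts-1",
    voice="alloy",
    input="the quick brown fox jumped over the lazy dogs",
)
response.stream_to_file(speech_file_path)

# async variant (inside an async function):
# response = await litellm.aspeech(model="openai/tts-1", voice="alloy", input="...")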
@@ -26,10 +26,8 @@ model_list:
      api_version: '2023-05-15'
      model: azure/chatgpt-v-2
    model_name: gpt-3.5-turbo
-  - model_name: mistral
+  - model_name: tts
    litellm_params:
-      model: azure/mistral-large-latest
-      api_base: https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1/
-      api_key: zEJhgmw1FAKk0XzPWoLEg7WU1cXbWYYn
+      model: openai/tts-1
router_settings:
  enable_pre_call_checks: true
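With a config like the one above, an OpenAI SDK client pointed at the proxy can call the tts model group. A sketch, assuming the proxy is running on http://0.0.0.0:4000 and sk-1234 is a placeholder virtual key:

from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
response = client.audio.speech.create(
    model="tts",  # the model_name group defined in the config above
    voice="alloy",
    input="the quick brown fox jumped over the lazy dogs",
)
response.stream_to_file("speech.mp3")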
@@ -79,6 +79,9 @@ def generate_feedback_box():


import litellm
from litellm.types.llms.openai import (
    HttpxBinaryResponseContent,
)
from litellm.proxy.utils import (
    PrismaClient,
    DBClient,
@@ -4883,6 +4886,169 @@ async def image_generation(
        )


@router.post(
    "/v1/audio/speech",
    dependencies=[Depends(user_api_key_auth)],
    tags=["audio"],
)
@router.post(
    "/audio/speech",
    dependencies=[Depends(user_api_key_auth)],
    tags=["audio"],
)
async def audio_speech(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Same params as:

    https://platform.openai.com/docs/api-reference/audio/createSpeech
    """
    global proxy_logging_obj
    data: Dict = {}
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {  # type: ignore
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }

        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            data["user"] = user_api_key_dict.user_id

        if user_model:
            data["model"] = user_model

        if "metadata" not in data:
            data["metadata"] = {}
        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
        _headers = dict(request.headers)
        _headers.pop(
            "authorization", None
        )  # do not store the original `sk-..` api key in the db
        data["metadata"]["headers"] = _headers
        data["metadata"]["user_api_key_alias"] = getattr(
            user_api_key_dict, "key_alias", None
        )
        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
        data["metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
            "global_max_parallel_requests", None
        )
        data["metadata"]["user_api_key_team_alias"] = getattr(
            user_api_key_dict, "team_alias", None
        )
        data["metadata"]["endpoint"] = str(request.url)

        ### TEAM-SPECIFIC PARAMS ###
        if user_api_key_dict.team_id is not None:
            team_config = await proxy_config.load_team_config(
                team_id=user_api_key_dict.team_id
            )
            if len(team_config) == 0:
                pass
            else:
                team_id = team_config.pop("team_id", None)
                data["metadata"]["team_id"] = team_id
                data = {
                    **team_config,
                    **data,
                }  # add the team-specific configs to the completion call

        router_model_names = llm_router.model_names if llm_router is not None else []

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
        )

        ## ROUTE TO CORRECT ENDPOINT ##
        # skip router if user passed their key
        if "api_key" in data:
            response = await litellm.aspeech(**data)
        elif (
            llm_router is not None and data["model"] in router_model_names
        ):  # model in router model list
            response = await llm_router.aspeech(**data)
        elif (
            llm_router is not None and data["model"] in llm_router.deployment_names
        ):  # model in router deployments, calling a specific deployment on the router
            response = await llm_router.aspeech(**data, specific_deployment=True)
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):  # model set in model_group_alias
            response = await llm_router.aspeech(
                **data
            )  # ensure this goes the llm_router, router will do the correct alias mapping
        elif (
            llm_router is not None
            and data["model"] not in router_model_names
            and llm_router.default_deployment is not None
        ):  # model in router deployments, calling a specific deployment on the router
            response = await llm_router.aspeech(**data)
        elif user_model is not None:  # `litellm --model <your-model-name>`
            response = await litellm.aspeech(**data)
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "audio_speech: Invalid model name passed in model="
                    + data.get("model", "")
                },
            )

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        # Printing each chunk size
        async def generate(_response: HttpxBinaryResponseContent):
            _generator = await _response.aiter_bytes(chunk_size=1024)
            async for chunk in _generator:
                yield chunk

        custom_headers = get_custom_headers(
            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,
            version=version,
            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            fastest_response_batch_completion=None,
        )

        selected_data_generator = select_data_generator(
            response=response,
            user_api_key_dict=user_api_key_dict,
            request_data=data,
        )
        return StreamingResponse(
            generate(response), media_type="audio/mpeg", headers=custom_headers
        )

    except Exception as e:
        traceback.print_exc()
        raise e


@router.post(
    "/v1/audio/transcriptions",
    dependencies=[Depends(user_api_key_auth)],
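The new route mirrors OpenAI's POST /v1/audio/speech and streams back an audio/mpeg body. A sketch of the raw HTTP shape using httpx (placeholder proxy URL and virtual key):

import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/v1/audio/speech",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "tts",
        "voice": "alloy",
        "input": "the quick brown fox jumped over the lazy dogs",
    },
)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)  # the streamed audio/mpeg payload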
@@ -1204,6 +1204,84 @@ class Router:
                self.fail_calls[model_name] += 1
            raise e

    async def aspeech(self, model: str, input: str, voice: str, **kwargs):
        """
        Example Usage:

        ```
        from litellm import Router
        client = Router(model_list = [
            {
                "model_name": "tts",
                "litellm_params": {
                    "model": "tts-1",
                },
            },
        ])

        async with client.aspeech(
            model="tts",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        ) as response:
            response.stream_to_file(speech_file_path)

        ```
        """
        try:
            kwargs["input"] = input
            kwargs["voice"] = voice

            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "prompt"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            response = await litellm.aspeech(**data, **kwargs)

            return response
        except Exception as e:
            raise e

    async def amoderation(self, model: str, input: str, **kwargs):
        try:
            kwargs["model"] = model
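Note that the docstring example above wraps the call in async with, while the unit test added below simply awaits the coroutine. A minimal sketch of that awaited pattern (assumes an OpenAI key is available to the router):

from litellm import Router

router = Router(
    model_list=[{"model_name": "tts", "litellm_params": {"model": "openai/tts-1"}}]
)

# inside an async function:
response = await router.aspeech(
    model="tts",
    voice="alloy",
    input="the quick brown fox jumped over the lazy dogs",
)
response.stream_to_file("speech.mp3")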
litellm/tests/test_audio_speech.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# What is this?
## unit tests for openai tts endpoint

import sys, os, asyncio, time, random, uuid
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm, openai
from pathlib import Path


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_audio_speech_litellm(sync_mode):
    speech_file_path = Path(__file__).parent / "speech.mp3"

    if sync_mode:
        response = litellm.speech(
            model="openai/tts-1",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        )

        from litellm.llms.openai import HttpxBinaryResponseContent

        assert isinstance(response, HttpxBinaryResponseContent)
    else:
        response = await litellm.aspeech(
            model="openai/tts-1",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        )

        from litellm.llms.openai import HttpxBinaryResponseContent

        assert isinstance(response, HttpxBinaryResponseContent)


@pytest.mark.parametrize("mode", ["iterator"])  # "file",
@pytest.mark.asyncio
async def test_audio_speech_router(mode):
    speech_file_path = Path(__file__).parent / "speech.mp3"

    from litellm import Router

    client = Router(
        model_list=[
            {
                "model_name": "tts",
                "litellm_params": {
                    "model": "openai/tts-1",
                },
            },
        ]
    )

    response = await client.aspeech(
        model="tts",
        voice="alloy",
        input="the quick brown fox jumped over the lazy dogs",
        api_base=None,
        api_key=None,
        organization=None,
        project=None,
        max_retries=1,
        timeout=600,
        client=None,
        optional_params={},
    )

    from litellm.llms.openai import HttpxBinaryResponseContent

    assert isinstance(response, HttpxBinaryResponseContent)
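These tests call the live OpenAI TTS endpoint, so they need OPENAI_API_KEY in the environment. A small sketch for running just this file programmatically (equivalent to invoking pytest on the path):

import pytest

# run only the new TTS tests; requires OPENAI_API_KEY to be set
pytest.main(["-x", "litellm/tests/test_audio_speech.py"])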
@@ -8,7 +8,6 @@ from typing import (
)
from typing_extensions import override, Required, Dict
from pydantic import BaseModel

from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.thread_create_params import (

@@ -21,7 +20,6 @@ from openai.pagination import SyncCursorPage
from os import PathLike
from openai.types import FileObject, Batch
from openai._legacy_response import HttpxBinaryResponseContent

from typing import TypedDict, List, Optional, Tuple, Mapping, IO

FileContent = Union[IO[bytes], bytes, PathLike]
@@ -1136,6 +1136,8 @@ class CallTypes(Enum):
    amoderation = "amoderation"
    atranscription = "atranscription"
    transcription = "transcription"
    aspeech = "aspeech"
    speech = "speech"


# Logging function -> log the exact model details + what's being sent | Non-Blocking

@@ -3005,6 +3007,10 @@ def function_setup(
    ):
        _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
        messages = "audio_file"
    elif (
        call_type == CallTypes.aspeech.value or call_type == CallTypes.speech.value
    ):
        messages = kwargs.get("input", "speech")
    stream = True if "stream" in kwargs and kwargs["stream"] == True else False
    logging_obj = Logging(
        model=model,

@@ -3346,6 +3352,8 @@ def client(original_function):
                return result
            elif "atranscription" in kwargs and kwargs["atranscription"] == True:
                return result
            elif "aspeech" in kwargs and kwargs["aspeech"] == True:
                return result

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model or None)