(feat) Support audio param in responses streaming (#6312)

* add audio, modalities param

* add test for gpt audio models

* add get_supported_openai_params for GPT audio models

* add supported params for audio

* test_audio_output_from_model

* bump openai to openai==1.52.0

* bump openai on pyproject

* fix audio test

* fix test mock_chat_response

* handle audio for Message

* fix handling audio for OAI compatible API endpoints

* fix linting

* fix mock dbrx test

* add audio to Delta

* handle model_response.choices.delta.audio

* fix linting
This commit is contained in:
Ishaan Jaff 2024-10-18 19:16:14 +05:30 committed by GitHub
parent 13e0b3f626
commit a0d45ba516
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 81 additions and 20 deletions

View file

@@ -451,12 +451,18 @@ class Delta(OpenAIObject):
role=None, role=None,
function_call=None, function_call=None,
tool_calls=None, tool_calls=None,
audio: Optional[ChatCompletionAudioResponse] = None,
**params, **params,
): ):
super(Delta, self).__init__(**params) super(Delta, self).__init__(**params)
self.content = content self.content = content
self.role = role self.role = role
# Set default values and correct types
self.function_call: Optional[Union[FunctionCall, Any]] = None
self.tool_calls: Optional[List[Union[ChatCompletionDeltaToolCall, Any]]] = None
self.audio: Optional[ChatCompletionAudioResponse] = None
if function_call is not None and isinstance(function_call, dict): if function_call is not None and isinstance(function_call, dict):
self.function_call = FunctionCall(**function_call) self.function_call = FunctionCall(**function_call)
else: else:
@@ -473,6 +479,8 @@ class Delta(OpenAIObject):
else: else:
self.tool_calls = tool_calls self.tool_calls = tool_calls
self.audio = audio
def __contains__(self, key): def __contains__(self, key):
# Define custom behavior for the 'in' operator # Define custom behavior for the 'in' operator
return hasattr(self, key) return hasattr(self, key)

View file

@@ -7639,6 +7639,10 @@ class CustomStreamWrapper:
) )
) )
model_response.choices[0].delta = Delta() model_response.choices[0].delta = Delta()
elif (
delta is not None and getattr(delta, "audio", None) is not None
):
model_response.choices[0].delta.audio = delta.audio
else: else:
try: try:
delta = ( delta = (
@@ -7805,6 +7809,12 @@ class CustomStreamWrapper:
model_response.choices[0].delta["role"] = "assistant" model_response.choices[0].delta["role"] = "assistant"
self.sent_first_chunk = True self.sent_first_chunk = True
return model_response return model_response
elif (
len(model_response.choices) > 0
and hasattr(model_response.choices[0].delta, "audio")
and model_response.choices[0].delta.audio is not None
):
return model_response
else: else:
if hasattr(model_response, "usage"): if hasattr(model_response, "usage"):
self.chunks.append(model_response) self.chunks.append(model_response)

Binary file not shown.

View file

@@ -15,40 +15,74 @@ from respx import MockRouter
import litellm import litellm
from litellm import Choices, Message, ModelResponse from litellm import Choices, Message, ModelResponse
from litellm.types.utils import StreamingChoices, ChatCompletionAudioResponse
import base64 import base64
import requests import requests
def check_non_streaming_response(completion):
    """Validate a non-streaming completion's audio payload.

    Checks that the first choice's message carries an audio object of the
    expected type and that its base64 data field is non-empty.
    """
    audio = completion.choices[0].message.audio
    assert audio is not None, "Audio response is missing"
    assert isinstance(audio, ChatCompletionAudioResponse), "Invalid audio response type"
    assert len(audio.data) > 0, "Audio data is empty"
async def check_streaming_response(completion):
    """Consume a streaming completion and verify its audio chunks.

    At least one chunk across the stream must supply each of the audio
    fields "data", "transcript", and "id"; otherwise an assertion fails.
    """
    found = {"data": None, "transcript": None, "id": None}
    async for chunk in completion:
        print(chunk)
        audio = chunk.choices[0].delta.audio
        if audio is not None:
            for key in found:
                value = audio.get(key)
                if value is not None:
                    found[key] = value
    assert found["data"] is not None
    assert found["transcript"] is not None
    assert found["id"] is not None
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1) # @pytest.mark.flaky(retries=3, delay=1)
async def test_audio_output_from_model(): @pytest.mark.parametrize("stream", [True, False])
litellm.set_verbose = True async def test_audio_output_from_model(stream):
audio_format = "pcm16"
if stream is False:
audio_format = "wav"
litellm.set_verbose = False
completion = await litellm.acompletion( completion = await litellm.acompletion(
model="gpt-4o-audio-preview", model="gpt-4o-audio-preview",
modalities=["text", "audio"], modalities=["text", "audio"],
audio={"voice": "alloy", "format": "wav"}, audio={"voice": "alloy", "format": "pcm16"},
messages=[{"role": "user", "content": "response in 1 word - yes or no"}], messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
stream=stream,
) )
print("response= ", completion) if stream is True:
await check_streaming_response(completion)
print(completion.choices[0]) else:
print("response= ", completion)
assert completion.choices[0].message.audio is not None check_non_streaming_response(completion)
assert isinstance( wav_bytes = base64.b64decode(completion.choices[0].message.audio.data)
completion.choices[0].message.audio, with open("dog.wav", "wb") as f:
litellm.types.utils.ChatCompletionAudioResponse, f.write(wav_bytes)
)
assert len(completion.choices[0].message.audio.data) > 0
wav_bytes = base64.b64decode(completion.choices[0].message.audio.data)
with open("dog.wav", "wb") as f:
f.write(wav_bytes)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_audio_input_to_model(): @pytest.mark.parametrize("stream", [True, False])
async def test_audio_input_to_model(stream):
# Fetch the audio file and convert it to a base64 encoded string # Fetch the audio file and convert it to a base64 encoded string
audio_format = "pcm16"
if stream is False:
audio_format = "wav"
litellm.set_verbose = True
url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav" url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
response = requests.get(url) response = requests.get(url)
response.raise_for_status() response.raise_for_status()
@@ -58,7 +92,8 @@ async def test_audio_input_to_model():
completion = await litellm.acompletion( completion = await litellm.acompletion(
model="gpt-4o-audio-preview", model="gpt-4o-audio-preview",
modalities=["text", "audio"], modalities=["text", "audio"],
audio={"voice": "alloy", "format": "wav"}, audio={"voice": "alloy", "format": audio_format},
stream=stream,
messages=[ messages=[
{ {
"role": "user", "role": "user",
@@ -73,4 +108,12 @@ async def test_audio_input_to_model():
], ],
) )
print(completion.choices[0].message) if stream is True:
await check_streaming_response(completion)
else:
print("response= ", completion)
check_non_streaming_response(completion)
wav_bytes = base64.b64decode(completion.choices[0].message.audio.data)
with open("dog.wav", "wb") as f:
f.write(wav_bytes)