LiteLLM OpenAI audio streaming (#6325)

* refactor(main.py): streaming_chunk_builder

reduce the function to under 100 lines of code

refactor each component into a separate function - easier to maintain and test

* fix(utils.py): handle choices being None

required after the OpenAI pydantic schema update, which allows choices to be None

* fix(main.py): fix linting error

* feat(streaming_chunk_builder_utils.py): update the stream chunk builder to support rebuilding audio chunks from OpenAI

* test(test_custom_callback_input.py): test message redaction works for audio output

* fix(streaming_chunk_builder_utils.py): return Anthropic token usage info directly

* fix(stream_chunk_builder_utils.py): run validation check before entering chunk processor

* fix(main.py): fix import
Krish Dholakia 2024-10-19 16:16:51 -07:00 committed by GitHub
parent 979e8ea526
commit c58d542282
10 changed files with 638 additions and 282 deletions
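
To make the refactor above concrete: the rebuilt stream_chunk_builder delegates each concern to the new ChunkProcessor helper (the new streaming_chunk_builder_utils.py file in the first hunk below). The following is a minimal sketch of that flow, not the committed implementation; it uses only names visible in this diff.

from typing import Optional

from litellm.litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor


def stream_chunk_builder_sketch(chunks: list, messages: Optional[list] = None):
    # Validation runs before the chunk processor is entered.
    if not chunks:
        return None
    # ChunkProcessor sorts chunks by their hidden created_at param on init.
    processor = ChunkProcessor(chunks, messages)
    chunks = processor.chunks
    # Base response: id / object / created / model / system_fingerprint + an empty assistant message.
    response = processor.build_base_response(chunks)
    # Tool calls, function calls, text content, and audio are each rebuilt by a dedicated helper:
    # get_combined_tool_content, get_combined_function_call_content,
    # get_combined_content, get_combined_audio_content.
    # Usage is computed once, from provider-reported usage chunks or by token counting.
    usage = processor.calculate_usage(
        chunks=chunks,
        model=response.model,
        completion_output="",
        messages=messages,
    )
    setattr(response, "usage", usage)
    return response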

@@ -0,0 +1,355 @@
import base64
import time
from typing import Any, Dict, List, Optional, Union
from litellm.exceptions import APIError
from litellm.types.llms.openai import (
ChatCompletionAssistantContentValue,
ChatCompletionAudioDelta,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
)
from litellm.types.utils import (
ChatCompletionAudioResponse,
ChatCompletionMessageToolCall,
CompletionTokensDetails,
Function,
FunctionCall,
ModelResponse,
PromptTokensDetails,
Usage,
)
from litellm.utils import print_verbose, token_counter
class ChunkProcessor:
def __init__(self, chunks: List, messages: Optional[list] = None):
self.chunks = self._sort_chunks(chunks)
self.messages = messages
self.first_chunk = chunks[0]
def _sort_chunks(self, chunks: list) -> list:
if not chunks:
return []
if chunks[0]._hidden_params.get("created_at"):
return sorted(
chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
)
return chunks
def update_model_response_with_hidden_params(
self, model_response: ModelResponse, chunk: Optional[Dict[str, Any]] = None
) -> ModelResponse:
if chunk is None:
return model_response
# set hidden params from chunk to model_response
if model_response is not None and hasattr(model_response, "_hidden_params"):
model_response._hidden_params = chunk.get("_hidden_params", {})
return model_response
def build_base_response(self, chunks: List[Dict[str, Any]]) -> ModelResponse:
chunk = self.first_chunk
id = chunk["id"]
object = chunk["object"]
created = chunk["created"]
model = chunk["model"]
system_fingerprint = chunk.get("system_fingerprint", None)
role = chunk["choices"][0]["delta"]["role"]
finish_reason = "stop"
for chunk in chunks:
if "choices" in chunk and len(chunk["choices"]) > 0:
if hasattr(chunk["choices"][0], "finish_reason"):
finish_reason = chunk["choices"][0].finish_reason
elif "finish_reason" in chunk["choices"][0]:
finish_reason = chunk["choices"][0]["finish_reason"]
# Initialize the response dictionary
response = ModelResponse(
**{
"id": id,
"object": object,
"created": created,
"model": model,
"system_fingerprint": system_fingerprint,
"choices": [
{
"index": 0,
"message": {"role": role, "content": ""},
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": 0, # Modify as needed
"completion_tokens": 0, # Modify as needed
"total_tokens": 0, # Modify as needed
},
}
)
response = self.update_model_response_with_hidden_params(
model_response=response, chunk=chunk
)
return response
def get_combined_tool_content(
self, tool_call_chunks: List[Dict[str, Any]]
) -> List[ChatCompletionMessageToolCall]:
argument_list: List = []
delta = tool_call_chunks[0]["choices"][0]["delta"]
id = None
name = None
type = None
tool_calls_list: List[ChatCompletionMessageToolCall] = []
prev_index = None
prev_name = None
prev_id = None
curr_id = None
curr_index = 0
for chunk in tool_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
tool_calls = delta.get("tool_calls", "")
# Check if a tool call is present
if tool_calls and tool_calls[0].function is not None:
if tool_calls[0].id:
id = tool_calls[0].id
curr_id = id
if prev_id is None:
prev_id = curr_id
if tool_calls[0].index:
curr_index = tool_calls[0].index
if tool_calls[0].function.arguments:
# Now, tool_calls is expected to be a dictionary
arguments = tool_calls[0].function.arguments
argument_list.append(arguments)
if tool_calls[0].function.name:
name = tool_calls[0].function.name
if tool_calls[0].type:
type = tool_calls[0].type
if prev_index is None:
prev_index = curr_index
if prev_name is None:
prev_name = name
if curr_index != prev_index: # new tool call
combined_arguments = "".join(argument_list)
tool_calls_list.append(
ChatCompletionMessageToolCall(
id=prev_id,
function=Function(
arguments=combined_arguments,
name=prev_name,
),
type=type,
)
)
argument_list = [] # reset
prev_index = curr_index
prev_id = curr_id
prev_name = name
combined_arguments = (
"".join(argument_list) or "{}"
) # base case, return empty dict
tool_calls_list.append(
ChatCompletionMessageToolCall(
id=id,
type="function",
function=Function(
arguments=combined_arguments,
name=name,
),
)
)
return tool_calls_list
def get_combined_function_call_content(
self, function_call_chunks: List[Dict[str, Any]]
) -> FunctionCall:
argument_list = []
delta = function_call_chunks[0]["choices"][0]["delta"]
function_call = delta.get("function_call", "")
function_call_name = function_call.name
for chunk in function_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
function_call = delta.get("function_call", "")
# Check if a function call is present
if function_call:
# Now, function_call is expected to be a dictionary
arguments = function_call.arguments
argument_list.append(arguments)
combined_arguments = "".join(argument_list)
return FunctionCall(
name=function_call_name,
arguments=combined_arguments,
)
def get_combined_content(
self, chunks: List[Dict[str, Any]]
) -> ChatCompletionAssistantContentValue:
content_list: List[str] = []
for chunk in chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
content = delta.get("content", "")
if content is None:
continue # openai v1.0.0 sets content = None for chunks
content_list.append(content)
# Combine the "content" strings into a single string || combine the 'function' strings into a single string
combined_content = "".join(content_list)
# Update the "content" field within the response dictionary
return combined_content
def get_combined_audio_content(
self, chunks: List[Dict[str, Any]]
) -> ChatCompletionAudioResponse:
base64_data_list: List[str] = []
transcript_list: List[str] = []
expires_at: Optional[int] = None
id: Optional[str] = None
for chunk in chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta") or {}
audio: Optional[ChatCompletionAudioDelta] = delta.get("audio")
if audio is not None:
for k, v in audio.items():
if k == "data" and v is not None and isinstance(v, str):
base64_data_list.append(v)
elif k == "transcript" and v is not None and isinstance(v, str):
transcript_list.append(v)
elif k == "expires_at" and v is not None and isinstance(v, int):
expires_at = v
elif k == "id" and v is not None and isinstance(v, str):
id = v
concatenated_audio = concatenate_base64_list(base64_data_list)
return ChatCompletionAudioResponse(
data=concatenated_audio,
expires_at=expires_at or int(time.time() + 3600),
transcript="".join(transcript_list),
id=id,
)
def calculate_usage(
self,
chunks: List[Union[Dict[str, Any], ModelResponse]],
model: str,
completion_output: str,
messages: Optional[List] = None,
) -> Usage:
"""
Calculate usage for the given chunks.
"""
returned_usage = Usage()
# # Update usage information if needed
prompt_tokens = 0
completion_tokens = 0
## anthropic prompt caching information ##
cache_creation_input_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
completion_tokens_details: Optional[CompletionTokensDetails] = None
prompt_tokens_details: Optional[PromptTokensDetails] = None
for chunk in chunks:
usage_chunk: Optional[Usage] = None
if "usage" in chunk:
usage_chunk = chunk["usage"]
elif isinstance(chunk, ModelResponse) and hasattr(chunk, "_hidden_params"):
usage_chunk = chunk._hidden_params.get("usage", None)
if usage_chunk is not None:
if "prompt_tokens" in usage_chunk:
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
if "completion_tokens" in usage_chunk:
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
if "cache_creation_input_tokens" in usage_chunk:
cache_creation_input_tokens = usage_chunk.get(
"cache_creation_input_tokens"
)
if "cache_read_input_tokens" in usage_chunk:
cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
if hasattr(usage_chunk, "completion_tokens_details"):
if isinstance(usage_chunk.completion_tokens_details, dict):
completion_tokens_details = CompletionTokensDetails(
**usage_chunk.completion_tokens_details
)
elif isinstance(
usage_chunk.completion_tokens_details, CompletionTokensDetails
):
completion_tokens_details = (
usage_chunk.completion_tokens_details
)
if hasattr(usage_chunk, "prompt_tokens_details"):
if isinstance(usage_chunk.prompt_tokens_details, dict):
prompt_tokens_details = PromptTokensDetails(
**usage_chunk.prompt_tokens_details
)
elif isinstance(
usage_chunk.prompt_tokens_details, PromptTokensDetails
):
prompt_tokens_details = usage_chunk.prompt_tokens_details
try:
returned_usage.prompt_tokens = prompt_tokens or token_counter(
model=model, messages=messages
)
except (
Exception
): # don't allow this failing to block a complete streaming response from being returned
print_verbose("token_counter failed, assuming prompt tokens is 0")
returned_usage.prompt_tokens = 0
returned_usage.completion_tokens = completion_tokens or token_counter(
model=model,
text=completion_output,
count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
)
returned_usage.total_tokens = (
returned_usage.prompt_tokens + returned_usage.completion_tokens
)
if cache_creation_input_tokens is not None:
returned_usage._cache_creation_input_tokens = cache_creation_input_tokens
setattr(
returned_usage,
"cache_creation_input_tokens",
cache_creation_input_tokens,
) # for anthropic
if cache_read_input_tokens is not None:
returned_usage._cache_read_input_tokens = cache_read_input_tokens
setattr(
returned_usage, "cache_read_input_tokens", cache_read_input_tokens
) # for anthropic
if completion_tokens_details is not None:
returned_usage.completion_tokens_details = completion_tokens_details
if prompt_tokens_details is not None:
returned_usage.prompt_tokens_details = prompt_tokens_details
return returned_usage
def concatenate_base64_list(base64_strings: List[str]) -> str:
"""
Concatenates a list of base64-encoded strings.
Args:
base64_strings (List[str]): A list of base64 strings to concatenate.
Returns:
str: The concatenated result as a base64-encoded string.
"""
# Decode each base64 string and collect the resulting bytes
combined_bytes = b"".join(base64.b64decode(b64_str) for b64_str in base64_strings)
# Encode the concatenated bytes back to base64
return base64.b64encode(combined_bytes).decode("utf-8")
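
A standalone check of the decode-join-re-encode approach used by concatenate_base64_list above. Base64 text generally cannot be joined as plain strings because of padding, so the helper works on the decoded bytes; the fragment values here are arbitrary.

import base64

# Two audio fragments as a provider might stream them: raw bytes, base64-encoded.
fragments_bytes = [b"\x00\x01\x02\x03", b"\x04\x05"]
fragments_b64 = [base64.b64encode(b).decode("utf-8") for b in fragments_bytes]

# Decode each fragment, join the bytes, then re-encode the combined payload.
combined_bytes = b"".join(base64.b64decode(b64) for b64 in fragments_b64)
combined_b64 = base64.b64encode(combined_bytes).decode("utf-8")

# The round trip reconstructs the original audio byte stream.
assert base64.b64decode(combined_b64) == b"".join(fragments_bytes)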

@@ -4,12 +4,13 @@ Common utility functions used for translating messages across providers
import json
from copy import deepcopy
from typing import Dict, List, Literal, Optional
from typing import Dict, List, Literal, Optional, Union
import litellm
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionResponseMessage,
ChatCompletionUserMessage,
)
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
@@ -67,12 +68,18 @@ def convert_openai_message_to_only_content_messages(
return converted_messages
def get_content_from_model_response(response: ModelResponse) -> str:
def get_content_from_model_response(response: Union[ModelResponse, dict]) -> str:
"""
Gets content from model response
"""
if isinstance(response, dict):
new_response = ModelResponse(**response)
else:
new_response = response
content = ""
for choice in response.choices:
for choice in new_response.choices:
if isinstance(choice, Choices):
content += choice.message.content if choice.message.content else ""
if choice.message.function_call:
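
The widened signature above lets callers (such as the rebuilt stream_chunk_builder) pass either a ModelResponse or its plain-dict form. A minimal usage sketch, assuming a litellm install that includes this change; the response content is illustrative.

from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
from litellm.types.utils import ModelResponse

response = ModelResponse(
    choices=[{"index": 0, "message": {"role": "assistant", "content": "yes"}}]
)

# The object and its dict dump should yield the same extracted content.
assert get_content_from_model_response(response) == "yes"
assert get_content_from_model_response(response.model_dump()) == "yes"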

@@ -23,7 +23,18 @@ from concurrent import futures
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Type,
Union,
cast,
)
import dotenv
import httpx
@@ -47,6 +58,7 @@ from litellm.litellm_core_utils.mock_functions import (
mock_image_generation,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
from litellm.secret_managers.main import get_secret_str
from litellm.utils import (
CustomStreamWrapper,
@@ -70,6 +82,7 @@ from litellm.utils import (
from ._logging import verbose_logger
from .caching.caching import disable_cache, enable_cache, update_cache
from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
from .llms import (
aleph_alpha,
baseten,
@@ -5390,73 +5403,37 @@ def stream_chunk_builder_text_completion(
return TextCompletionResponse(**response)
def stream_chunk_builder( # noqa: PLR0915
def stream_chunk_builder(
chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
try:
model_response = litellm.ModelResponse()
if chunks is None:
raise litellm.APIError(
status_code=500,
message="Error building chunks for logging/streaming usage calculation",
llm_provider="",
model="",
)
if not chunks:
return None
processor = ChunkProcessor(chunks, messages)
chunks = processor.chunks
### BASE-CASE ###
if len(chunks) == 0:
return None
### SORT CHUNKS BASED ON CREATED ORDER ##
print_verbose("Goes into checking if chunk has hiddden created at param")
if chunks[0]._hidden_params.get("created_at", None):
print_verbose("Chunks have a created at hidden param")
# Sort chunks based on created_at in ascending order
chunks = sorted(
chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
)
print_verbose("Chunks sorted")
# set hidden params from chunk to model_response
if model_response is not None and hasattr(model_response, "_hidden_params"):
model_response._hidden_params = chunks[0].get("_hidden_params", {})
id = chunks[0]["id"]
object = chunks[0]["object"]
created = chunks[0]["created"]
model = chunks[0]["model"]
system_fingerprint = chunks[0].get("system_fingerprint", None)
## Route to the text completion logic
if isinstance(
chunks[0]["choices"][0], litellm.utils.TextChoices
): # route to the text completion logic
return stream_chunk_builder_text_completion(
chunks=chunks, messages=messages
)
role = chunks[0]["choices"][0]["delta"]["role"]
finish_reason = "stop"
for chunk in chunks:
if "choices" in chunk and len(chunk["choices"]) > 0:
if hasattr(chunk["choices"][0], "finish_reason"):
finish_reason = chunk["choices"][0].finish_reason
elif "finish_reason" in chunk["choices"][0]:
finish_reason = chunk["choices"][0]["finish_reason"]
model = chunks[0]["model"]
# Initialize the response dictionary
response = {
"id": id,
"object": object,
"created": created,
"model": model,
"system_fingerprint": system_fingerprint,
"choices": [
{
"index": 0,
"message": {"role": role, "content": ""},
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": 0, # Modify as needed
"completion_tokens": 0, # Modify as needed
"total_tokens": 0, # Modify as needed
},
}
# Extract the "content" strings from the nested dictionaries within "choices"
content_list = []
combined_content = ""
combined_arguments = ""
response = processor.build_base_response(chunks)
tool_call_chunks = [
chunk
@@ -5467,75 +5444,10 @@ def stream_chunk_builder( # noqa: PLR0915
]
if len(tool_call_chunks) > 0:
argument_list: List = []
delta = tool_call_chunks[0]["choices"][0]["delta"]
message = response["choices"][0]["message"]
message["tool_calls"] = []
id = None
name = None
type = None
tool_calls_list = []
prev_index = None
prev_name = None
prev_id = None
curr_id = None
curr_index = 0
for chunk in tool_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
tool_calls = delta.get("tool_calls", "")
# Check if a tool call is present
if tool_calls and tool_calls[0].function is not None:
if tool_calls[0].id:
id = tool_calls[0].id
curr_id = id
if prev_id is None:
prev_id = curr_id
if tool_calls[0].index:
curr_index = tool_calls[0].index
if tool_calls[0].function.arguments:
# Now, tool_calls is expected to be a dictionary
arguments = tool_calls[0].function.arguments
argument_list.append(arguments)
if tool_calls[0].function.name:
name = tool_calls[0].function.name
if tool_calls[0].type:
type = tool_calls[0].type
if prev_index is None:
prev_index = curr_index
if prev_name is None:
prev_name = name
if curr_index != prev_index: # new tool call
combined_arguments = "".join(argument_list)
tool_calls_list.append(
{
"id": prev_id,
"function": {
"arguments": combined_arguments,
"name": prev_name,
},
"type": type,
}
)
argument_list = [] # reset
prev_index = curr_index
prev_id = curr_id
prev_name = name
combined_arguments = (
"".join(argument_list) or "{}"
) # base case, return empty dict
tool_calls_list.append(
{
"id": id,
"function": {"arguments": combined_arguments, "name": name},
"type": type,
}
)
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["tool_calls"] = tool_calls_list
tool_calls_list = processor.get_combined_tool_content(tool_call_chunks)
_choice = cast(Choices, response.choices[0])
_choice.message.content = None
_choice.message.tool_calls = tool_calls_list
function_call_chunks = [
chunk
@@ -5546,32 +5458,11 @@ def stream_chunk_builder( # noqa: PLR0915
]
if len(function_call_chunks) > 0:
argument_list = []
delta = function_call_chunks[0]["choices"][0]["delta"]
function_call = delta.get("function_call", "")
function_call_name = function_call.name
message = response["choices"][0]["message"]
message["function_call"] = {}
message["function_call"]["name"] = function_call_name
for chunk in function_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
function_call = delta.get("function_call", "")
# Check if a function call is present
if function_call:
# Now, function_call is expected to be a dictionary
arguments = function_call.arguments
argument_list.append(arguments)
combined_arguments = "".join(argument_list)
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["function_call"][
"arguments"
] = combined_arguments
_choice = cast(Choices, response.choices[0])
_choice.message.content = None
_choice.message.function_call = (
processor.get_combined_function_call_content(function_call_chunks)
)
content_chunks = [
chunk
@@ -5582,109 +5473,34 @@ def stream_chunk_builder( # noqa: PLR0915
]
if len(content_chunks) > 0:
for chunk in chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
content = delta.get("content", "")
if content is None:
continue # openai v1.0.0 sets content = None for chunks
content_list.append(content)
# Combine the "content" strings into a single string || combine the 'function' strings into a single string
combined_content = "".join(content_list)
# Update the "content" field within the response dictionary
response["choices"][0]["message"]["content"] = combined_content
completion_output = ""
if len(combined_content) > 0:
completion_output += combined_content
if len(combined_arguments) > 0:
completion_output += combined_arguments
# # Update usage information if needed
prompt_tokens = 0
completion_tokens = 0
## anthropic prompt caching information ##
cache_creation_input_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
completion_tokens_details: Optional[CompletionTokensDetails] = None
prompt_tokens_details: Optional[PromptTokensDetails] = None
for chunk in chunks:
usage_chunk: Optional[Usage] = None
if "usage" in chunk:
usage_chunk = chunk.usage
elif hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
usage_chunk = chunk._hidden_params["usage"]
if usage_chunk is not None:
if "prompt_tokens" in usage_chunk:
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
if "completion_tokens" in usage_chunk:
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
if "cache_creation_input_tokens" in usage_chunk:
cache_creation_input_tokens = usage_chunk.get(
"cache_creation_input_tokens"
response["choices"][0]["message"]["content"] = (
processor.get_combined_content(content_chunks)
)
if "cache_read_input_tokens" in usage_chunk:
cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
if hasattr(usage_chunk, "completion_tokens_details"):
if isinstance(usage_chunk.completion_tokens_details, dict):
completion_tokens_details = CompletionTokensDetails(
**usage_chunk.completion_tokens_details
)
elif isinstance(
usage_chunk.completion_tokens_details, CompletionTokensDetails
):
completion_tokens_details = (
usage_chunk.completion_tokens_details
)
if hasattr(usage_chunk, "prompt_tokens_details"):
if isinstance(usage_chunk.prompt_tokens_details, dict):
prompt_tokens_details = PromptTokensDetails(
**usage_chunk.prompt_tokens_details
)
elif isinstance(
usage_chunk.prompt_tokens_details, PromptTokensDetails
):
prompt_tokens_details = usage_chunk.prompt_tokens_details
try:
response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
model=model, messages=messages
)
except (
Exception
): # don't allow this failing to block a complete streaming response from being returned
print_verbose("token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = completion_tokens or token_counter(
audio_chunks = [
chunk
for chunk in chunks
if len(chunk["choices"]) > 0
and "audio" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["audio"] is not None
]
if len(audio_chunks) > 0:
_choice = cast(Choices, response.choices[0])
_choice.message.audio = processor.get_combined_audio_content(audio_chunks)
completion_output = get_content_from_model_response(response)
usage = processor.calculate_usage(
chunks=chunks,
model=model,
text=completion_output,
count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
)
response["usage"]["total_tokens"] = (
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
completion_output=completion_output,
messages=messages,
)
if cache_creation_input_tokens is not None:
response["usage"][
"cache_creation_input_tokens"
] = cache_creation_input_tokens
if cache_read_input_tokens is not None:
response["usage"]["cache_read_input_tokens"] = cache_read_input_tokens
setattr(response, "usage", usage)
if completion_tokens_details is not None:
response["usage"]["completion_tokens_details"] = completion_tokens_details
if prompt_tokens_details is not None:
response["usage"]["prompt_tokens_details"] = prompt_tokens_details
return convert_to_model_response_object(
response_object=response,
model_response_object=model_response,
start_time=start_time,
end_time=end_time,
) # type: ignore
return response
except Exception as e:
verbose_logger.exception(
"litellm.main.py::stream_chunk_builder() - Exception occurred - {}".format(

@@ -1,12 +1,6 @@
model_list:
- model_name: "gpt-3.5-turbo"
- model_name: "gpt-4o-audio-preview"
litellm_params:
model: gpt-3.5-turbo
model: gpt-4o-audio-preview
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
callbacks: ["argilla"]
argilla_transformation_object:
user_input: "messages"
llm_output: "response"

@@ -295,6 +295,13 @@ class ListBatchRequest(TypedDict, total=False):
timeout: Optional[float]
class ChatCompletionAudioDelta(TypedDict, total=False):
data: str
transcript: str
expires_at: int
id: str
class ChatCompletionToolCallFunctionChunk(TypedDict, total=False):
name: Optional[str]
arguments: str
@@ -482,8 +489,13 @@ class ChatCompletionDeltaChunk(TypedDict, total=False):
role: str
ChatCompletionAssistantContentValue = (
str # keep as var, used in stream_chunk_builder as well
)
class ChatCompletionResponseMessage(TypedDict, total=False):
content: Optional[str]
content: Optional[ChatCompletionAssistantContentValue]
tool_calls: List[ChatCompletionToolCallChunk]
role: Literal["assistant"]
function_call: ChatCompletionToolCallFunctionChunk
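
For reference, a single streamed audio delta as the new ChatCompletionAudioDelta TypedDict above models it; get_combined_audio_content walks exactly these keys when stitching chunks back together. The values are made up.

from litellm.types.llms.openai import ChatCompletionAudioDelta

# total=False: a given chunk may carry only a subset of these keys.
audio_delta: ChatCompletionAudioDelta = {
    "id": "audio_abc123",      # illustrative identifier
    "data": "AAECAwQ=",        # base64-encoded audio fragment
    "transcript": "yes",       # incremental transcript text
    "expires_at": 1729400000,  # illustrative unix timestamp
}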

@@ -321,7 +321,11 @@ class ChatCompletionMessageToolCall(OpenAIObject):
setattr(self, key, value)
class ChatCompletionAudioResponse(OpenAIObject):
from openai.types.chat.chat_completion_audio import ChatCompletionAudio
class ChatCompletionAudioResponse(ChatCompletionAudio):
def __init__(
self,
data: str,
@@ -330,27 +334,13 @@ class ChatCompletionAudioResponse(OpenAIObject):
id: Optional[str] = None,
**params,
):
super(ChatCompletionAudioResponse, self).__init__(**params)
if id is not None:
self.id = id
id = id
else:
self.id = f"{uuid.uuid4()}"
"""Unique identifier for this audio response."""
self.data = data
"""
Base64 encoded audio bytes generated by the model, in the format specified in
the request.
"""
self.expires_at = expires_at
"""
The Unix timestamp (in seconds) for when this audio response will no longer be
accessible on the server for use in multi-turn conversations.
"""
self.transcript = transcript
"""Transcript of the audio generated by the model."""
id = f"{uuid.uuid4()}"
super(ChatCompletionAudioResponse, self).__init__(
data=data, expires_at=expires_at, transcript=transcript, id=id, **params
)
def __contains__(self, key):
# Define custom behavior for the 'in' operator
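
A quick construction sketch for the rebased ChatCompletionAudioResponse above (now a subclass of openai's ChatCompletionAudio). When no id is supplied, the constructor generates a uuid4, as the diff shows; the data and transcript values are illustrative.

import time

from litellm.types.utils import ChatCompletionAudioResponse

audio = ChatCompletionAudioResponse(
    data="AAECAwQ=",                     # base64 audio payload
    expires_at=int(time.time()) + 3600,  # one hour from now
    transcript="yes",
)

assert audio.id  # auto-generated because no id was passed
assert audio.transcript == "yes"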

@@ -7573,7 +7573,7 @@ class CustomStreamWrapper:
original_chunk = response_obj.get("original_chunk", None)
model_response.id = original_chunk.id
self.response_id = original_chunk.id
if len(original_chunk.choices) > 0:
if original_chunk.choices and len(original_chunk.choices) > 0:
delta = original_chunk.choices[0].delta
if delta is not None and (
delta.function_call is not None or delta.tool_calls is not None

@@ -2365,3 +2365,32 @@ async def test_caching_kwargs_input(sync_mode):
else:
input["original_function"] = acompletion
await llm_caching_handler.async_set_cache(**input)
@pytest.mark.skip(reason="audio caching not supported yet")
@pytest.mark.parametrize("stream", [False]) # True,
@pytest.mark.asyncio()
async def test_audio_caching(stream):
litellm.cache = Cache(type="local")
## CALL 1 - no cache hit
completion = await litellm.acompletion(
model="gpt-4o-audio-preview",
modalities=["text", "audio"],
audio={"voice": "alloy", "format": "pcm16"},
messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
stream=stream,
)
assert "cache_hit" not in completion._hidden_params
## CALL 2 - cache hit
completion = await litellm.acompletion(
model="gpt-4o-audio-preview",
modalities=["text", "audio"],
audio={"voice": "alloy", "format": "pcm16"},
messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
stream=stream,
)
assert "cache_hit" in completion._hidden_params

@@ -1267,6 +1267,100 @@ def test_standard_logging_payload(model, turn_off_message_logging):
assert "redacted-by-litellm" == slobject["response"]
@pytest.mark.parametrize(
"stream",
[True, False],
)
@pytest.mark.parametrize(
"turn_off_message_logging",
[
True,
],
) # False
def test_standard_logging_payload_audio(turn_off_message_logging, stream):
"""
Ensure valid standard_logging_payload is passed for logging calls to s3
Motivation: provide a standard set of things that are logged to s3/gcs/future integrations across all llm calls
"""
from litellm.types.utils import StandardLoggingPayload
# sync completion
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.turn_off_message_logging = turn_off_message_logging
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
response = litellm.completion(
model="gpt-4o-audio-preview",
modalities=["text", "audio"],
audio={"voice": "alloy", "format": "pcm16"},
messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
stream=stream,
)
if stream:
for chunk in response:
continue
time.sleep(2)
mock_client.assert_called_once()
print(
f"mock_client_post.call_args: {mock_client.call_args.kwargs['kwargs'].keys()}"
)
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
print(
"Standard Logging Object - {}".format(
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
)
keys_list = list(StandardLoggingPayload.__annotations__.keys())
for k in keys_list:
assert (
k in mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
## json serializable
json_str_payload = json.dumps(
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
json.loads(json_str_payload)
## response cost
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
"response_cost"
]
> 0
)
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
"model_map_information"
]["model_map_value"]
is not None
)
## turn off message logging
slobject: StandardLoggingPayload = mock_client.call_args.kwargs["kwargs"][
"standard_logging_object"
]
if turn_off_message_logging:
print("checks redacted-by-litellm")
assert "redacted-by-litellm" == slobject["messages"][0]["content"]
assert "redacted-by-litellm" == slobject["response"]
@pytest.mark.skip(reason="Works locally. Flaky on ci/cd")
def test_aaastandard_logging_payload_cache_hit():
from litellm.types.utils import StandardLoggingPayload

@@ -6,6 +6,17 @@ import traceback
import pytest
from typing import List
from litellm.types.utils import StreamingChoices, ChatCompletionAudioResponse
def check_non_streaming_response(completion):
assert completion.choices[0].message.audio is not None, "Audio response is missing"
print("audio", completion.choices[0].message.audio)
assert isinstance(
completion.choices[0].message.audio, ChatCompletionAudioResponse
), "Invalid audio response type"
assert len(completion.choices[0].message.audio.data) > 0, "Audio data is empty"
sys.path.insert(
0, os.path.abspath("../..")
@@ -656,12 +667,60 @@ def test_stream_chunk_builder_openai_prompt_caching():
response = stream_chunk_builder(chunks=chunks)
print(f"response: {response}")
print(f"response usage: {response.usage}")
for k, v in usage_obj.model_dump().items():
for k, v in usage_obj.model_dump(exclude_none=True).items():
print(k, v)
response_usage_value = getattr(response.usage, k) # type: ignore
print(f"response_usage_value: {response_usage_value}")
print(f"type: {type(response_usage_value)}")
if isinstance(response_usage_value, BaseModel):
assert response_usage_value.model_dump() == v
assert response_usage_value.model_dump(exclude_none=True) == v
else:
assert response_usage_value == v
def test_stream_chunk_builder_openai_audio_output_usage():
from pydantic import BaseModel
from openai import OpenAI
from typing import Optional
client = OpenAI(
# This is the default and can be omitted
api_key=os.getenv("OPENAI_API_KEY"),
)
completion = client.chat.completions.create(
model="gpt-4o-audio-preview",
modalities=["text", "audio"],
audio={"voice": "alloy", "format": "pcm16"},
messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
stream=True,
stream_options={"include_usage": True},
)
chunks = []
for chunk in completion:
chunks.append(litellm.ModelResponse(**chunk.model_dump(), stream=True))
usage_obj: Optional[litellm.Usage] = None
for index, chunk in enumerate(chunks):
if hasattr(chunk, "usage"):
usage_obj = chunk.usage
print(f"chunk usage: {chunk.usage}")
print(f"index: {index}")
print(f"len chunks: {len(chunks)}")
print(f"usage_obj: {usage_obj}")
response = stream_chunk_builder(chunks=chunks)
print(f"response usage: {response.usage}")
check_non_streaming_response(response)
print(f"response: {response}")
for k, v in usage_obj.model_dump(exclude_none=True).items():
print(k, v)
response_usage_value = getattr(response.usage, k) # type: ignore
print(f"response_usage_value: {response_usage_value}")
print(f"type: {type(response_usage_value)}")
if isinstance(response_usage_value, BaseModel):
assert response_usage_value.model_dump(exclude_none=True) == v
else:
assert response_usage_value == v