Litellm openai audio streaming (#6325)

* refactor(main.py): streaming_chunk_builder

use <100 lines of code

refactor each component into a separate function - easier to maintain + test

* fix(utils.py): handle choices being None

openai pydantic schema updated

* fix(main.py): fix linting error

* feat(streaming_chunk_builder_utils.py): update stream chunk builder to support rebuilding audio chunks from openai

* test(test_custom_callback_input.py): test message redaction works for audio output

* fix(streaming_chunk_builder_utils.py): return anthropic token usage info directly

* fix(streaming_chunk_builder_utils.py): run validation check before entering chunk processor

* fix(main.py): fix import
Krish Dholakia 2024-10-19 16:16:51 -07:00 committed by GitHub
parent 979e8ea526
commit c58d542282
10 changed files with 638 additions and 282 deletions
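
Before the diff: the change replaces the long inline chunk-merging logic in stream_chunk_builder with calls into a ChunkProcessor class imported from litellm_core_utils/streaming_chunk_builder_utils.py. The skeleton below lists the methods exactly as they are called in the hunks that follow; the signatures come from those call sites, while the bodies (including the dict-based audio example, whose data/transcript/id/expires_at fields are assumed here, mirroring OpenAI's audio output shape) are illustrative sketches rather than the actual implementation.

from typing import Any, Dict, Optional


class ChunkProcessor:
    """Illustrative skeleton only; method names match the call sites in this diff."""

    def __init__(self, chunks: list, messages: Optional[list] = None):
        # the real class validates and orders the chunks once, then exposes them as .chunks
        self.chunks = chunks
        self.messages = messages

    def build_base_response(self, chunks: list) -> Any:
        """Seed the response (id, object, created, model, one empty choice) from chunks[0]."""
        ...

    def get_combined_content(self, content_chunks: list) -> str:
        """Concatenate the delta.content fragments into the final message content."""
        ...

    def get_combined_tool_content(self, tool_call_chunks: list) -> list:
        """Merge streamed tool-call fragments into complete tool_calls entries."""
        ...

    def get_combined_function_call_content(self, function_call_chunks: list) -> Any:
        """Merge streamed (legacy) function_call name/argument fragments."""
        ...

    def get_combined_audio_content(self, audio_chunks: list) -> Dict[str, Any]:
        # sketch of the new audio rebuilding: stitch partial data/transcript strings
        # back together, keeping id/expires_at from whichever delta carries them
        # (fields treated as a plain dict purely for illustration)
        data_parts, transcript_parts = [], []
        audio_id, expires_at = None, None
        for chunk in audio_chunks:
            audio = chunk["choices"][0]["delta"]["audio"]
            data_parts.append(audio.get("data") or "")
            transcript_parts.append(audio.get("transcript") or "")
            audio_id = audio.get("id") or audio_id
            expires_at = audio.get("expires_at") or expires_at
        return {
            "id": audio_id,
            "expires_at": expires_at,
            "data": "".join(data_parts),  # base64 audio, reassembled
            "transcript": "".join(transcript_parts),
        }

    def calculate_usage(
        self,
        chunks: list,
        model: str,
        completion_output: str,
        messages: Optional[list] = None,
    ) -> Any:
        """Aggregate prompt/completion tokens, plus provider extras such as Anthropic cache tokens."""
        ...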


@@ -23,7 +23,18 @@ from concurrent import futures
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Type,
Union,
cast,
)
import dotenv
import httpx
@@ -47,6 +58,7 @@ from litellm.litellm_core_utils.mock_functions import (
mock_image_generation,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
from litellm.secret_managers.main import get_secret_str
from litellm.utils import (
CustomStreamWrapper,
@@ -70,6 +82,7 @@ from litellm.utils import (
from ._logging import verbose_logger
from .caching.caching import disable_cache, enable_cache, update_cache
from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
from .llms import (
aleph_alpha,
baseten,
@@ -5390,73 +5403,37 @@ def stream_chunk_builder_text_completion(
return TextCompletionResponse(**response)
def stream_chunk_builder( # noqa: PLR0915
def stream_chunk_builder(
chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
try:
model_response = litellm.ModelResponse()
if chunks is None:
raise litellm.APIError(
status_code=500,
message="Error building chunks for logging/streaming usage calculation",
llm_provider="",
model="",
)
if not chunks:
return None
processor = ChunkProcessor(chunks, messages)
chunks = processor.chunks
### BASE-CASE ###
if len(chunks) == 0:
return None
### SORT CHUNKS BASED ON CREATED ORDER ##
print_verbose("Goes into checking if chunk has hiddden created at param")
if chunks[0]._hidden_params.get("created_at", None):
print_verbose("Chunks have a created at hidden param")
# Sort chunks based on created_at in ascending order
chunks = sorted(
chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
)
print_verbose("Chunks sorted")
# set hidden params from chunk to model_response
if model_response is not None and hasattr(model_response, "_hidden_params"):
model_response._hidden_params = chunks[0].get("_hidden_params", {})
id = chunks[0]["id"]
object = chunks[0]["object"]
created = chunks[0]["created"]
model = chunks[0]["model"]
system_fingerprint = chunks[0].get("system_fingerprint", None)
## Route to the text completion logic
if isinstance(
chunks[0]["choices"][0], litellm.utils.TextChoices
): # route to the text completion logic
return stream_chunk_builder_text_completion(
chunks=chunks, messages=messages
)
role = chunks[0]["choices"][0]["delta"]["role"]
finish_reason = "stop"
for chunk in chunks:
if "choices" in chunk and len(chunk["choices"]) > 0:
if hasattr(chunk["choices"][0], "finish_reason"):
finish_reason = chunk["choices"][0].finish_reason
elif "finish_reason" in chunk["choices"][0]:
finish_reason = chunk["choices"][0]["finish_reason"]
model = chunks[0]["model"]
# Initialize the response dictionary
response = {
"id": id,
"object": object,
"created": created,
"model": model,
"system_fingerprint": system_fingerprint,
"choices": [
{
"index": 0,
"message": {"role": role, "content": ""},
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": 0, # Modify as needed
"completion_tokens": 0, # Modify as needed
"total_tokens": 0, # Modify as needed
},
}
# Extract the "content" strings from the nested dictionaries within "choices"
content_list = []
combined_content = ""
combined_arguments = ""
response = processor.build_base_response(chunks)
tool_call_chunks = [
chunk
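
The hunk above is where the inline base-response construction disappears in favour of processor.build_base_response(chunks). A condensed sketch of the logic being moved, pieced together from the removed lines (a plain dict is returned for brevity; the real method builds a ModelResponse and may differ in detail):

def build_base_response_sketch(chunks: list) -> dict:
    # sort into arrival order when chunks carry a created_at hidden param
    if chunks[0]._hidden_params.get("created_at", None):
        chunks = sorted(
            chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
        )

    # the last finish_reason seen wins; default to "stop"
    finish_reason = "stop"
    for chunk in chunks:
        if chunk["choices"] and getattr(chunk["choices"][0], "finish_reason", None):
            finish_reason = chunk["choices"][0].finish_reason

    first = chunks[0]
    return {
        "id": first["id"],
        "object": first["object"],
        "created": first["created"],
        "model": first["model"],
        "system_fingerprint": first.get("system_fingerprint", None),
        "choices": [
            {
                "index": 0,
                "message": {"role": first["choices"][0]["delta"]["role"], "content": ""},
                "finish_reason": finish_reason,
            }
        ],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }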
@@ -5467,75 +5444,10 @@ def stream_chunk_builder( # noqa: PLR0915
]
if len(tool_call_chunks) > 0:
argument_list: List = []
delta = tool_call_chunks[0]["choices"][0]["delta"]
message = response["choices"][0]["message"]
message["tool_calls"] = []
id = None
name = None
type = None
tool_calls_list = []
prev_index = None
prev_name = None
prev_id = None
curr_id = None
curr_index = 0
for chunk in tool_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
tool_calls = delta.get("tool_calls", "")
# Check if a tool call is present
if tool_calls and tool_calls[0].function is not None:
if tool_calls[0].id:
id = tool_calls[0].id
curr_id = id
if prev_id is None:
prev_id = curr_id
if tool_calls[0].index:
curr_index = tool_calls[0].index
if tool_calls[0].function.arguments:
# Now, tool_calls is expected to be a dictionary
arguments = tool_calls[0].function.arguments
argument_list.append(arguments)
if tool_calls[0].function.name:
name = tool_calls[0].function.name
if tool_calls[0].type:
type = tool_calls[0].type
if prev_index is None:
prev_index = curr_index
if prev_name is None:
prev_name = name
if curr_index != prev_index: # new tool call
combined_arguments = "".join(argument_list)
tool_calls_list.append(
{
"id": prev_id,
"function": {
"arguments": combined_arguments,
"name": prev_name,
},
"type": type,
}
)
argument_list = [] # reset
prev_index = curr_index
prev_id = curr_id
prev_name = name
combined_arguments = (
"".join(argument_list) or "{}"
) # base case, return empty dict
tool_calls_list.append(
{
"id": id,
"function": {"arguments": combined_arguments, "name": name},
"type": type,
}
)
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["tool_calls"] = tool_calls_list
tool_calls_list = processor.get_combined_tool_content(tool_call_chunks)
_choice = cast(Choices, response.choices[0])
_choice.message.content = None
_choice.message.tool_calls = tool_calls_list
function_call_chunks = [
chunk
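
The hunk above collapses the prev_index/curr_index bookkeeping for streamed tool calls into processor.get_combined_tool_content(tool_call_chunks). The underlying idea, sketched here by grouping fragments by their index (a simplification for illustration, not the actual implementation):

def get_combined_tool_content_sketch(tool_call_chunks: list) -> list:
    calls: dict = {}  # tool-call index -> accumulated call
    for chunk in tool_call_chunks:
        for choice in chunk["choices"]:
            delta = choice.get("delta", {})
            for fragment in delta.get("tool_calls", None) or []:
                if fragment.function is None:
                    continue
                entry = calls.setdefault(
                    fragment.index or 0,
                    {"id": None, "type": None, "function": {"name": None, "arguments": ""}},
                )
                if fragment.id:
                    entry["id"] = fragment.id
                if fragment.type:
                    entry["type"] = fragment.type
                if fragment.function.name:
                    entry["function"]["name"] = fragment.function.name
                if fragment.function.arguments:
                    entry["function"]["arguments"] += fragment.function.arguments
    # empty argument streams fall back to "{}" so json.loads on the result still works
    for entry in calls.values():
        entry["function"]["arguments"] = entry["function"]["arguments"] or "{}"
    return [calls[index] for index in sorted(calls)]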
@@ -5546,32 +5458,11 @@ def stream_chunk_builder( # noqa: PLR0915
]
if len(function_call_chunks) > 0:
argument_list = []
delta = function_call_chunks[0]["choices"][0]["delta"]
function_call = delta.get("function_call", "")
function_call_name = function_call.name
message = response["choices"][0]["message"]
message["function_call"] = {}
message["function_call"]["name"] = function_call_name
for chunk in function_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
function_call = delta.get("function_call", "")
# Check if a function call is present
if function_call:
# Now, function_call is expected to be a dictionary
arguments = function_call.arguments
argument_list.append(arguments)
combined_arguments = "".join(argument_list)
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["function_call"][
"arguments"
] = combined_arguments
_choice = cast(Choices, response.choices[0])
_choice.message.content = None
_choice.message.function_call = (
processor.get_combined_function_call_content(function_call_chunks)
)
content_chunks = [
chunk
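
Likewise, the legacy function_call path above becomes processor.get_combined_function_call_content(function_call_chunks): the name arrives on the first delta and the arguments arrive as string pieces to be joined. A minimal sketch (plain dict for illustration; the real method returns whatever message.function_call expects):

def get_combined_function_call_content_sketch(function_call_chunks: list) -> dict:
    first_delta = function_call_chunks[0]["choices"][0]["delta"]
    name = first_delta.get("function_call").name  # the name is sent on the first fragment
    argument_pieces = []
    for chunk in function_call_chunks:
        for choice in chunk["choices"]:
            function_call = choice.get("delta", {}).get("function_call", None)
            if function_call and function_call.arguments:
                argument_pieces.append(function_call.arguments)
    return {"name": name, "arguments": "".join(argument_pieces)}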
@@ -5582,109 +5473,34 @@ def stream_chunk_builder( # noqa: PLR0915
]
if len(content_chunks) > 0:
for chunk in chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
content = delta.get("content", "")
if content is None:
continue # openai v1.0.0 sets content = None for chunks
content_list.append(content)
# Combine the "content" strings into a single string || combine the 'function' strings into a single string
combined_content = "".join(content_list)
# Update the "content" field within the response dictionary
response["choices"][0]["message"]["content"] = combined_content
completion_output = ""
if len(combined_content) > 0:
completion_output += combined_content
if len(combined_arguments) > 0:
completion_output += combined_arguments
# # Update usage information if needed
prompt_tokens = 0
completion_tokens = 0
## anthropic prompt caching information ##
cache_creation_input_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
completion_tokens_details: Optional[CompletionTokensDetails] = None
prompt_tokens_details: Optional[PromptTokensDetails] = None
for chunk in chunks:
usage_chunk: Optional[Usage] = None
if "usage" in chunk:
usage_chunk = chunk.usage
elif hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
usage_chunk = chunk._hidden_params["usage"]
if usage_chunk is not None:
if "prompt_tokens" in usage_chunk:
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
if "completion_tokens" in usage_chunk:
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
if "cache_creation_input_tokens" in usage_chunk:
cache_creation_input_tokens = usage_chunk.get(
"cache_creation_input_tokens"
)
if "cache_read_input_tokens" in usage_chunk:
cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
if hasattr(usage_chunk, "completion_tokens_details"):
if isinstance(usage_chunk.completion_tokens_details, dict):
completion_tokens_details = CompletionTokensDetails(
**usage_chunk.completion_tokens_details
)
elif isinstance(
usage_chunk.completion_tokens_details, CompletionTokensDetails
):
completion_tokens_details = (
usage_chunk.completion_tokens_details
)
if hasattr(usage_chunk, "prompt_tokens_details"):
if isinstance(usage_chunk.prompt_tokens_details, dict):
prompt_tokens_details = PromptTokensDetails(
**usage_chunk.prompt_tokens_details
)
elif isinstance(
usage_chunk.prompt_tokens_details, PromptTokensDetails
):
prompt_tokens_details = usage_chunk.prompt_tokens_details
try:
response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
model=model, messages=messages
response["choices"][0]["message"]["content"] = (
processor.get_combined_content(content_chunks)
)
except (
Exception
): # don't allow this failing to block a complete streaming response from being returned
print_verbose("token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = completion_tokens or token_counter(
audio_chunks = [
chunk
for chunk in chunks
if len(chunk["choices"]) > 0
and "audio" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["audio"] is not None
]
if len(audio_chunks) > 0:
_choice = cast(Choices, response.choices[0])
_choice.message.audio = processor.get_combined_audio_content(audio_chunks)
completion_output = get_content_from_model_response(response)
usage = processor.calculate_usage(
chunks=chunks,
model=model,
text=completion_output,
count_response_tokens=True,  # flag telling the token counter this is a response, so it should not add the extra tokens applied to input messages
)
response["usage"]["total_tokens"] = (
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
completion_output=completion_output,
messages=messages,
)
if cache_creation_input_tokens is not None:
response["usage"][
"cache_creation_input_tokens"
] = cache_creation_input_tokens
if cache_read_input_tokens is not None:
response["usage"]["cache_read_input_tokens"] = cache_read_input_tokens
setattr(response, "usage", usage)
if completion_tokens_details is not None:
response["usage"]["completion_tokens_details"] = completion_tokens_details
if prompt_tokens_details is not None:
response["usage"]["prompt_tokens_details"] = prompt_tokens_details
return convert_to_model_response_object(
response_object=response,
model_response_object=model_response,
start_time=start_time,
end_time=end_time,
) # type: ignore
return response
except Exception as e:
verbose_logger.exception(
"litellm.main.py::stream_chunk_builder() - Exception occurred - {}".format(