forked from phoenix/litellm-mirror
Litellm openai audio streaming (#6325)
* refactor(main.py): streaming_chunk_builder uses <100 lines of code; refactor each component into a separate function - easier to maintain + test
* fix(utils.py): handle choices being None - openai pydantic schema updated
* fix(main.py): fix linting error
* feat(streaming_chunk_builder_utils.py): update stream chunk builder to support rebuilding audio chunks from openai
* test(test_custom_callback_input.py): test message redaction works for audio output
* fix(streaming_chunk_builder_utils.py): return anthropic token usage info directly
* fix(stream_chunk_builder_utils.py): run validation check before entering chunk processor
* fix(main.py): fix import
This commit is contained in:
parent 979e8ea526
commit c58d542282
10 changed files with 638 additions and 282 deletions
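The public entrypoint is unchanged by this refactor: callers still collect the streamed chunks and hand them to litellm.stream_chunk_builder, which now delegates the rebuilding to ChunkProcessor. A minimal usage sketch, assuming a configured OPENAI_API_KEY; the model, prompt and audio settings are illustrative only:

import litellm
from litellm import completion, stream_chunk_builder

messages = [{"role": "user", "content": "response in 1 word - yes or no"}]

# stream a completion and keep every chunk litellm yields
response = completion(
    model="gpt-4o-audio-preview",
    modalities=["text", "audio"],
    audio={"voice": "alloy", "format": "pcm16"},
    messages=messages,
    stream=True,
    stream_options={"include_usage": True},
)
chunks = [chunk for chunk in response]

# rebuild one ModelResponse (content, tool calls, audio, usage) from the chunks
rebuilt = stream_chunk_builder(chunks=chunks, messages=messages)
print(rebuilt.choices[0].message.audio)
print(rebuilt.usage)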
355  litellm/litellm_core_utils/streaming_chunk_builder_utils.py  Normal file
@@ -0,0 +1,355 @@
import base64
import time
from typing import Any, Dict, List, Optional, Union

from litellm.exceptions import APIError
from litellm.types.llms.openai import (
    ChatCompletionAssistantContentValue,
    ChatCompletionAudioDelta,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
)
from litellm.types.utils import (
    ChatCompletionAudioResponse,
    ChatCompletionMessageToolCall,
    CompletionTokensDetails,
    Function,
    FunctionCall,
    ModelResponse,
    PromptTokensDetails,
    Usage,
)
from litellm.utils import print_verbose, token_counter


class ChunkProcessor:
    def __init__(self, chunks: List, messages: Optional[list] = None):
        self.chunks = self._sort_chunks(chunks)
        self.messages = messages
        self.first_chunk = chunks[0]

    def _sort_chunks(self, chunks: list) -> list:
        if not chunks:
            return []
        if chunks[0]._hidden_params.get("created_at"):
            return sorted(
                chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
            )
        return chunks

    def update_model_response_with_hidden_params(
        self, model_response: ModelResponse, chunk: Optional[Dict[str, Any]] = None
    ) -> ModelResponse:
        if chunk is None:
            return model_response
        # set hidden params from chunk to model_response
        if model_response is not None and hasattr(model_response, "_hidden_params"):
            model_response._hidden_params = chunk.get("_hidden_params", {})
        return model_response

    def build_base_response(self, chunks: List[Dict[str, Any]]) -> ModelResponse:
        chunk = self.first_chunk
        id = chunk["id"]
        object = chunk["object"]
        created = chunk["created"]
        model = chunk["model"]
        system_fingerprint = chunk.get("system_fingerprint", None)

        role = chunk["choices"][0]["delta"]["role"]
        finish_reason = "stop"
        for chunk in chunks:
            if "choices" in chunk and len(chunk["choices"]) > 0:
                if hasattr(chunk["choices"][0], "finish_reason"):
                    finish_reason = chunk["choices"][0].finish_reason
                elif "finish_reason" in chunk["choices"][0]:
                    finish_reason = chunk["choices"][0]["finish_reason"]

        # Initialize the response dictionary
        response = ModelResponse(
            **{
                "id": id,
                "object": object,
                "created": created,
                "model": model,
                "system_fingerprint": system_fingerprint,
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": role, "content": ""},
                        "finish_reason": finish_reason,
                    }
                ],
                "usage": {
                    "prompt_tokens": 0,  # Modify as needed
                    "completion_tokens": 0,  # Modify as needed
                    "total_tokens": 0,  # Modify as needed
                },
            }
        )

        response = self.update_model_response_with_hidden_params(
            model_response=response, chunk=chunk
        )
        return response

    def get_combined_tool_content(
        self, tool_call_chunks: List[Dict[str, Any]]
    ) -> List[ChatCompletionMessageToolCall]:
        argument_list: List = []
        delta = tool_call_chunks[0]["choices"][0]["delta"]
        id = None
        name = None
        type = None
        tool_calls_list: List[ChatCompletionMessageToolCall] = []
        prev_index = None
        prev_name = None
        prev_id = None
        curr_id = None
        curr_index = 0
        for chunk in tool_call_chunks:
            choices = chunk["choices"]
            for choice in choices:
                delta = choice.get("delta", {})
                tool_calls = delta.get("tool_calls", "")
                # Check if a tool call is present
                if tool_calls and tool_calls[0].function is not None:
                    if tool_calls[0].id:
                        id = tool_calls[0].id
                        curr_id = id
                        if prev_id is None:
                            prev_id = curr_id
                    if tool_calls[0].index:
                        curr_index = tool_calls[0].index
                    if tool_calls[0].function.arguments:
                        # Now, tool_calls is expected to be a dictionary
                        arguments = tool_calls[0].function.arguments
                        argument_list.append(arguments)
                    if tool_calls[0].function.name:
                        name = tool_calls[0].function.name
                    if tool_calls[0].type:
                        type = tool_calls[0].type
            if prev_index is None:
                prev_index = curr_index
            if prev_name is None:
                prev_name = name
            if curr_index != prev_index:  # new tool call
                combined_arguments = "".join(argument_list)
                tool_calls_list.append(
                    ChatCompletionMessageToolCall(
                        id=prev_id,
                        function=Function(
                            arguments=combined_arguments,
                            name=prev_name,
                        ),
                        type=type,
                    )
                )
                argument_list = []  # reset
                prev_index = curr_index
                prev_id = curr_id
                prev_name = name

        combined_arguments = (
            "".join(argument_list) or "{}"
        )  # base case, return empty dict

        tool_calls_list.append(
            ChatCompletionMessageToolCall(
                id=id,
                type="function",
                function=Function(
                    arguments=combined_arguments,
                    name=name,
                ),
            )
        )
        return tool_calls_list

    def get_combined_function_call_content(
        self, function_call_chunks: List[Dict[str, Any]]
    ) -> FunctionCall:
        argument_list = []
        delta = function_call_chunks[0]["choices"][0]["delta"]
        function_call = delta.get("function_call", "")
        function_call_name = function_call.name

        for chunk in function_call_chunks:
            choices = chunk["choices"]
            for choice in choices:
                delta = choice.get("delta", {})
                function_call = delta.get("function_call", "")

                # Check if a function call is present
                if function_call:
                    # Now, function_call is expected to be a dictionary
                    arguments = function_call.arguments
                    argument_list.append(arguments)

        combined_arguments = "".join(argument_list)

        return FunctionCall(
            name=function_call_name,
            arguments=combined_arguments,
        )

    def get_combined_content(
        self, chunks: List[Dict[str, Any]]
    ) -> ChatCompletionAssistantContentValue:
        content_list: List[str] = []
        for chunk in chunks:
            choices = chunk["choices"]
            for choice in choices:
                delta = choice.get("delta", {})
                content = delta.get("content", "")
                if content is None:
                    continue  # openai v1.0.0 sets content = None for chunks
                content_list.append(content)

        # Combine the "content" strings into a single string || combine the 'function' strings into a single string
        combined_content = "".join(content_list)

        # Update the "content" field within the response dictionary
        return combined_content

    def get_combined_audio_content(
        self, chunks: List[Dict[str, Any]]
    ) -> ChatCompletionAudioResponse:
        base64_data_list: List[str] = []
        transcript_list: List[str] = []
        expires_at: Optional[int] = None
        id: Optional[str] = None

        for chunk in chunks:
            choices = chunk["choices"]
            for choice in choices:
                delta = choice.get("delta") or {}
                audio: Optional[ChatCompletionAudioDelta] = delta.get("audio")
                if audio is not None:
                    for k, v in audio.items():
                        if k == "data" and v is not None and isinstance(v, str):
                            base64_data_list.append(v)
                        elif k == "transcript" and v is not None and isinstance(v, str):
                            transcript_list.append(v)
                        elif k == "expires_at" and v is not None and isinstance(v, int):
                            expires_at = v
                        elif k == "id" and v is not None and isinstance(v, str):
                            id = v

        concatenated_audio = concatenate_base64_list(base64_data_list)
        return ChatCompletionAudioResponse(
            data=concatenated_audio,
            expires_at=expires_at or int(time.time() + 3600),
            transcript="".join(transcript_list),
            id=id,
        )

    def calculate_usage(
        self,
        chunks: List[Union[Dict[str, Any], ModelResponse]],
        model: str,
        completion_output: str,
        messages: Optional[List] = None,
    ) -> Usage:
        """
        Calculate usage for the given chunks.
        """
        returned_usage = Usage()
        # # Update usage information if needed
        prompt_tokens = 0
        completion_tokens = 0
        ## anthropic prompt caching information ##
        cache_creation_input_tokens: Optional[int] = None
        cache_read_input_tokens: Optional[int] = None
        completion_tokens_details: Optional[CompletionTokensDetails] = None
        prompt_tokens_details: Optional[PromptTokensDetails] = None
        for chunk in chunks:
            usage_chunk: Optional[Usage] = None
            if "usage" in chunk:
                usage_chunk = chunk["usage"]
            elif isinstance(chunk, ModelResponse) and hasattr(chunk, "_hidden_params"):
                usage_chunk = chunk._hidden_params.get("usage", None)
            if usage_chunk is not None:
                if "prompt_tokens" in usage_chunk:
                    prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
                if "completion_tokens" in usage_chunk:
                    completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
                if "cache_creation_input_tokens" in usage_chunk:
                    cache_creation_input_tokens = usage_chunk.get(
                        "cache_creation_input_tokens"
                    )
                if "cache_read_input_tokens" in usage_chunk:
                    cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
                if hasattr(usage_chunk, "completion_tokens_details"):
                    if isinstance(usage_chunk.completion_tokens_details, dict):
                        completion_tokens_details = CompletionTokensDetails(
                            **usage_chunk.completion_tokens_details
                        )
                    elif isinstance(
                        usage_chunk.completion_tokens_details, CompletionTokensDetails
                    ):
                        completion_tokens_details = (
                            usage_chunk.completion_tokens_details
                        )
                if hasattr(usage_chunk, "prompt_tokens_details"):
                    if isinstance(usage_chunk.prompt_tokens_details, dict):
                        prompt_tokens_details = PromptTokensDetails(
                            **usage_chunk.prompt_tokens_details
                        )
                    elif isinstance(
                        usage_chunk.prompt_tokens_details, PromptTokensDetails
                    ):
                        prompt_tokens_details = usage_chunk.prompt_tokens_details

        try:
            returned_usage.prompt_tokens = prompt_tokens or token_counter(
                model=model, messages=messages
            )
        except (
            Exception
        ):  # don't allow this failing to block a complete streaming response from being returned
            print_verbose("token_counter failed, assuming prompt tokens is 0")
            returned_usage.prompt_tokens = 0
        returned_usage.completion_tokens = completion_tokens or token_counter(
            model=model,
            text=completion_output,
            count_response_tokens=True,  # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
        )
        returned_usage.total_tokens = (
            returned_usage.prompt_tokens + returned_usage.completion_tokens
        )

        if cache_creation_input_tokens is not None:
            returned_usage._cache_creation_input_tokens = cache_creation_input_tokens
            setattr(
                returned_usage,
                "cache_creation_input_tokens",
                cache_creation_input_tokens,
            )  # for anthropic
        if cache_read_input_tokens is not None:
            returned_usage._cache_read_input_tokens = cache_read_input_tokens
            setattr(
                returned_usage, "cache_read_input_tokens", cache_read_input_tokens
            )  # for anthropic
        if completion_tokens_details is not None:
            returned_usage.completion_tokens_details = completion_tokens_details
        if prompt_tokens_details is not None:
            returned_usage.prompt_tokens_details = prompt_tokens_details

        return returned_usage


def concatenate_base64_list(base64_strings: List[str]) -> str:
    """
    Concatenates a list of base64-encoded strings.

    Args:
        base64_strings (List[str]): A list of base64 strings to concatenate.

    Returns:
        str: The concatenated result as a base64-encoded string.
    """
    # Decode each base64 string and collect the resulting bytes
    combined_bytes = b"".join(base64.b64decode(b64_str) for b64_str in base64_strings)

    # Encode the concatenated bytes back to base64
    return base64.b64encode(combined_bytes).decode("utf-8")
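One detail of the audio rebuild worth flagging: audio deltas arrive as independent base64 strings, and concatenate_base64_list decodes each one to bytes before re-encoding the joined result, since joining the base64 text directly would leave '=' padding in the middle of the payload. A small self-contained check against the new module (the byte values are made up):

import base64

from litellm.litellm_core_utils.streaming_chunk_builder_utils import (
    concatenate_base64_list,
)

# two illustrative audio segments, each independently base64-encoded (both carry padding)
segments = [
    base64.b64encode(b"RIFF").decode(),
    base64.b64encode(b"fake-pcm-bytes").decode(),
]

combined = concatenate_base64_list(segments)

# the combined string decodes back to the concatenated raw bytes
assert base64.b64decode(combined) == b"RIFF" + b"fake-pcm-bytes"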
litellm/llms/prompt_templates/common_utils.py
@@ -4,12 +4,13 @@ Common utility functions used for translating messages across providers
 import json
 from copy import deepcopy
-from typing import Dict, List, Literal, Optional
+from typing import Dict, List, Literal, Optional, Union

 import litellm
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionAssistantMessage,
+    ChatCompletionResponseMessage,
     ChatCompletionUserMessage,
 )
 from litellm.types.utils import Choices, ModelResponse, StreamingChoices
@@ -67,12 +68,18 @@ def convert_openai_message_to_only_content_messages(
     return converted_messages


-def get_content_from_model_response(response: ModelResponse) -> str:
+def get_content_from_model_response(response: Union[ModelResponse, dict]) -> str:
     """
     Gets content from model response
     """
+    if isinstance(response, dict):
+        new_response = ModelResponse(**response)
+    else:
+        new_response = response
+
     content = ""
-    for choice in response.choices:
+    for choice in new_response.choices:
         if isinstance(choice, Choices):
             content += choice.message.content if choice.message.content else ""
             if choice.message.function_call:
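get_content_from_model_response can now be called with either a ModelResponse or a plain dict; a dict is coerced through ModelResponse before its choices are walked. A quick sketch, with an illustrative OpenAI-style payload:

from litellm.llms.prompt_templates.common_utils import get_content_from_model_response

# a dict-shaped chat response, e.g. one rebuilt from streaming chunks (values illustrative)
response_dict = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1729000000,
    "model": "gpt-4o-audio-preview",
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "yes"},
        }
    ],
}

assert get_content_from_model_response(response_dict) == "yes"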
308  litellm/main.py
@@ -23,7 +23,18 @@ from concurrent import futures
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Type,
+    Union,
+    cast,
+)

 import dotenv
 import httpx
@@ -47,6 +58,7 @@ from litellm.litellm_core_utils.mock_functions import (
     mock_image_generation,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
 from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
@@ -70,6 +82,7 @@ from litellm.utils import (
 from ._logging import verbose_logger
 from .caching.caching import disable_cache, enable_cache, update_cache
+from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
 from .llms import (
     aleph_alpha,
     baseten,
@@ -5390,73 +5403,37 @@ def stream_chunk_builder_text_completion(
     return TextCompletionResponse(**response)


-def stream_chunk_builder(  # noqa: PLR0915
+def stream_chunk_builder(
     chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
 ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
     try:
-        model_response = litellm.ModelResponse()
+        if chunks is None:
+            raise litellm.APIError(
+                status_code=500,
+                message="Error building chunks for logging/streaming usage calculation",
+                llm_provider="",
+                model="",
+            )
+        if not chunks:
+            return None
+
+        processor = ChunkProcessor(chunks, messages)
+        chunks = processor.chunks
+
         ### BASE-CASE ###
         if len(chunks) == 0:
             return None
-        ### SORT CHUNKS BASED ON CREATED ORDER ##
-        print_verbose("Goes into checking if chunk has hiddden created at param")
-        if chunks[0]._hidden_params.get("created_at", None):
-            print_verbose("Chunks have a created at hidden param")
-            # Sort chunks based on created_at in ascending order
-            chunks = sorted(
-                chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
-            )
-            print_verbose("Chunks sorted")
-
-        # set hidden params from chunk to model_response
-        if model_response is not None and hasattr(model_response, "_hidden_params"):
-            model_response._hidden_params = chunks[0].get("_hidden_params", {})
-        id = chunks[0]["id"]
-        object = chunks[0]["object"]
-        created = chunks[0]["created"]
-        model = chunks[0]["model"]
-        system_fingerprint = chunks[0].get("system_fingerprint", None)
-
+
+        ## Route to the text completion logic
         if isinstance(
             chunks[0]["choices"][0], litellm.utils.TextChoices
         ):  # route to the text completion logic
             return stream_chunk_builder_text_completion(
                 chunks=chunks, messages=messages
             )
-        role = chunks[0]["choices"][0]["delta"]["role"]
-        finish_reason = "stop"
-        for chunk in chunks:
-            if "choices" in chunk and len(chunk["choices"]) > 0:
-                if hasattr(chunk["choices"][0], "finish_reason"):
-                    finish_reason = chunk["choices"][0].finish_reason
-                elif "finish_reason" in chunk["choices"][0]:
-                    finish_reason = chunk["choices"][0]["finish_reason"]

+        model = chunks[0]["model"]
         # Initialize the response dictionary
-        response = {
-            "id": id,
-            "object": object,
-            "created": created,
-            "model": model,
-            "system_fingerprint": system_fingerprint,
-            "choices": [
-                {
-                    "index": 0,
-                    "message": {"role": role, "content": ""},
-                    "finish_reason": finish_reason,
-                }
-            ],
-            "usage": {
-                "prompt_tokens": 0,  # Modify as needed
-                "completion_tokens": 0,  # Modify as needed
-                "total_tokens": 0,  # Modify as needed
-            },
-        }
-
-        # Extract the "content" strings from the nested dictionaries within "choices"
-        content_list = []
-        combined_content = ""
-        combined_arguments = ""
+        response = processor.build_base_response(chunks)

         tool_call_chunks = [
             chunk
@@ -5467,75 +5444,10 @@ def stream_chunk_builder(  # noqa: PLR0915
         ]

         if len(tool_call_chunks) > 0:
-            argument_list: List = []
-            delta = tool_call_chunks[0]["choices"][0]["delta"]
-            message = response["choices"][0]["message"]
-            message["tool_calls"] = []
-            id = None
-            name = None
-            type = None
-            tool_calls_list = []
-            prev_index = None
-            prev_name = None
-            prev_id = None
-            curr_id = None
-            curr_index = 0
-            for chunk in tool_call_chunks:
-                choices = chunk["choices"]
-                for choice in choices:
-                    delta = choice.get("delta", {})
-                    tool_calls = delta.get("tool_calls", "")
-                    # Check if a tool call is present
-                    if tool_calls and tool_calls[0].function is not None:
-                        if tool_calls[0].id:
-                            id = tool_calls[0].id
-                            curr_id = id
-                            if prev_id is None:
-                                prev_id = curr_id
-                        if tool_calls[0].index:
-                            curr_index = tool_calls[0].index
-                        if tool_calls[0].function.arguments:
-                            # Now, tool_calls is expected to be a dictionary
-                            arguments = tool_calls[0].function.arguments
-                            argument_list.append(arguments)
-                        if tool_calls[0].function.name:
-                            name = tool_calls[0].function.name
-                        if tool_calls[0].type:
-                            type = tool_calls[0].type
-                if prev_index is None:
-                    prev_index = curr_index
-                if prev_name is None:
-                    prev_name = name
-                if curr_index != prev_index:  # new tool call
-                    combined_arguments = "".join(argument_list)
-                    tool_calls_list.append(
-                        {
-                            "id": prev_id,
-                            "function": {
-                                "arguments": combined_arguments,
-                                "name": prev_name,
-                            },
-                            "type": type,
-                        }
-                    )
-                    argument_list = []  # reset
-                    prev_index = curr_index
-                    prev_id = curr_id
-                    prev_name = name
-
-            combined_arguments = (
-                "".join(argument_list) or "{}"
-            )  # base case, return empty dict
-
-            tool_calls_list.append(
-                {
-                    "id": id,
-                    "function": {"arguments": combined_arguments, "name": name},
-                    "type": type,
-                }
-            )
-            response["choices"][0]["message"]["content"] = None
-            response["choices"][0]["message"]["tool_calls"] = tool_calls_list
+            tool_calls_list = processor.get_combined_tool_content(tool_call_chunks)
+            _choice = cast(Choices, response.choices[0])
+            _choice.message.content = None
+            _choice.message.tool_calls = tool_calls_list

         function_call_chunks = [
             chunk
@@ -5546,32 +5458,11 @@ def stream_chunk_builder(  # noqa: PLR0915
         ]

         if len(function_call_chunks) > 0:
-            argument_list = []
-            delta = function_call_chunks[0]["choices"][0]["delta"]
-            function_call = delta.get("function_call", "")
-            function_call_name = function_call.name
-            message = response["choices"][0]["message"]
-            message["function_call"] = {}
-            message["function_call"]["name"] = function_call_name
-
-            for chunk in function_call_chunks:
-                choices = chunk["choices"]
-                for choice in choices:
-                    delta = choice.get("delta", {})
-                    function_call = delta.get("function_call", "")
-
-                    # Check if a function call is present
-                    if function_call:
-                        # Now, function_call is expected to be a dictionary
-                        arguments = function_call.arguments
-                        argument_list.append(arguments)
-
-            combined_arguments = "".join(argument_list)
-            response["choices"][0]["message"]["content"] = None
-            response["choices"][0]["message"]["function_call"][
-                "arguments"
-            ] = combined_arguments
+            _choice = cast(Choices, response.choices[0])
+            _choice.message.content = None
+            _choice.message.function_call = (
+                processor.get_combined_function_call_content(function_call_chunks)
+            )

         content_chunks = [
             chunk
@@ -5582,109 +5473,34 @@ def stream_chunk_builder(  # noqa: PLR0915
         ]

         if len(content_chunks) > 0:
-            for chunk in chunks:
-                choices = chunk["choices"]
-                for choice in choices:
-                    delta = choice.get("delta", {})
-                    content = delta.get("content", "")
-                    if content is None:
-                        continue  # openai v1.0.0 sets content = None for chunks
-                    content_list.append(content)
-
-        # Combine the "content" strings into a single string || combine the 'function' strings into a single string
-        combined_content = "".join(content_list)
-
-        # Update the "content" field within the response dictionary
-        response["choices"][0]["message"]["content"] = combined_content
-
-        completion_output = ""
-        if len(combined_content) > 0:
-            completion_output += combined_content
-        if len(combined_arguments) > 0:
-            completion_output += combined_arguments
-
-        # # Update usage information if needed
-        prompt_tokens = 0
-        completion_tokens = 0
-        ## anthropic prompt caching information ##
-        cache_creation_input_tokens: Optional[int] = None
-        cache_read_input_tokens: Optional[int] = None
-        completion_tokens_details: Optional[CompletionTokensDetails] = None
-        prompt_tokens_details: Optional[PromptTokensDetails] = None
-        for chunk in chunks:
-            usage_chunk: Optional[Usage] = None
-            if "usage" in chunk:
-                usage_chunk = chunk.usage
-            elif hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
-                usage_chunk = chunk._hidden_params["usage"]
-            if usage_chunk is not None:
-                if "prompt_tokens" in usage_chunk:
-                    prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
-                if "completion_tokens" in usage_chunk:
-                    completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
-                if "cache_creation_input_tokens" in usage_chunk:
-                    cache_creation_input_tokens = usage_chunk.get(
-                        "cache_creation_input_tokens"
-                    )
-                if "cache_read_input_tokens" in usage_chunk:
-                    cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
-                if hasattr(usage_chunk, "completion_tokens_details"):
-                    if isinstance(usage_chunk.completion_tokens_details, dict):
-                        completion_tokens_details = CompletionTokensDetails(
-                            **usage_chunk.completion_tokens_details
-                        )
-                    elif isinstance(
-                        usage_chunk.completion_tokens_details, CompletionTokensDetails
-                    ):
-                        completion_tokens_details = (
-                            usage_chunk.completion_tokens_details
-                        )
-                if hasattr(usage_chunk, "prompt_tokens_details"):
-                    if isinstance(usage_chunk.prompt_tokens_details, dict):
-                        prompt_tokens_details = PromptTokensDetails(
-                            **usage_chunk.prompt_tokens_details
-                        )
-                    elif isinstance(
-                        usage_chunk.prompt_tokens_details, PromptTokensDetails
-                    ):
-                        prompt_tokens_details = usage_chunk.prompt_tokens_details
-
-        try:
-            response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
-                model=model, messages=messages
-            )
-        except (
-            Exception
-        ):  # don't allow this failing to block a complete streaming response from being returned
-            print_verbose("token_counter failed, assuming prompt tokens is 0")
-            response["usage"]["prompt_tokens"] = 0
-        response["usage"]["completion_tokens"] = completion_tokens or token_counter(
-            model=model,
-            text=completion_output,
-            count_response_tokens=True,  # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
-        )
-        response["usage"]["total_tokens"] = (
-            response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
-        )
-
-        if cache_creation_input_tokens is not None:
-            response["usage"][
-                "cache_creation_input_tokens"
-            ] = cache_creation_input_tokens
-        if cache_read_input_tokens is not None:
-            response["usage"]["cache_read_input_tokens"] = cache_read_input_tokens
-
-        if completion_tokens_details is not None:
-            response["usage"]["completion_tokens_details"] = completion_tokens_details
-        if prompt_tokens_details is not None:
-            response["usage"]["prompt_tokens_details"] = prompt_tokens_details
-
-        return convert_to_model_response_object(
-            response_object=response,
-            model_response_object=model_response,
-            start_time=start_time,
-            end_time=end_time,
-        )  # type: ignore
+            response["choices"][0]["message"]["content"] = (
+                processor.get_combined_content(content_chunks)
+            )
+
+        audio_chunks = [
+            chunk
+            for chunk in chunks
+            if len(chunk["choices"]) > 0
+            and "audio" in chunk["choices"][0]["delta"]
+            and chunk["choices"][0]["delta"]["audio"] is not None
+        ]
+
+        if len(audio_chunks) > 0:
+            _choice = cast(Choices, response.choices[0])
+            _choice.message.audio = processor.get_combined_audio_content(audio_chunks)
+
+        completion_output = get_content_from_model_response(response)
+
+        usage = processor.calculate_usage(
+            chunks=chunks,
+            model=model,
+            completion_output=completion_output,
+            messages=messages,
+        )
+
+        setattr(response, "usage", usage)
+
+        return response
     except Exception as e:
         verbose_logger.exception(
             "litellm.main.py::stream_chunk_builder() - Exception occurred - {}".format(
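Note the guard that now runs before the chunk processor is built: a None chunk list is turned into a litellm.APIError, while an empty list simply short-circuits. A minimal sketch of the empty-stream case, using the import path from this PR:

from litellm.main import stream_chunk_builder

# an empty stream has nothing to rebuild; the builder returns None before ChunkProcessor is constructed
assert stream_chunk_builder(chunks=[], messages=None) is None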
@@ -1,12 +1,6 @@
 model_list:
-  - model_name: "gpt-3.5-turbo"
+  - model_name: "gpt-4o-audio-preview"
     litellm_params:
-      model: gpt-3.5-turbo
+      model: gpt-4o-audio-preview
       api_key: os.environ/OPENAI_API_KEY

-
-litellm_settings:
-  callbacks: ["argilla"]
-  argilla_transformation_object:
-    user_input: "messages"
-    llm_output: "response"
litellm/types/llms/openai.py
@@ -295,6 +295,13 @@ class ListBatchRequest(TypedDict, total=False):
     timeout: Optional[float]


+class ChatCompletionAudioDelta(TypedDict, total=False):
+    data: str
+    transcript: str
+    expires_at: int
+    id: str
+
+
 class ChatCompletionToolCallFunctionChunk(TypedDict, total=False):
     name: Optional[str]
     arguments: str
@@ -482,8 +489,13 @@ class ChatCompletionDeltaChunk(TypedDict, total=False):
     role: str


+ChatCompletionAssistantContentValue = (
+    str  # keep as var, used in stream_chunk_builder as well
+)
+
+
 class ChatCompletionResponseMessage(TypedDict, total=False):
-    content: Optional[str]
+    content: Optional[ChatCompletionAssistantContentValue]
     tool_calls: List[ChatCompletionToolCallChunk]
     role: Literal["assistant"]
     function_call: ChatCompletionToolCallFunctionChunk
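ChatCompletionAudioDelta mirrors what OpenAI streams for audio output: partial base64 data and transcript text, plus a single id and expires_at. The deltas below are hand-written (values illustrative) to show the shape ChunkProcessor.get_combined_audio_content folds into one ChatCompletionAudioResponse:

from typing import List

from litellm.types.llms.openai import ChatCompletionAudioDelta

# deltas as they might arrive mid-stream; every field is optional
deltas: List[ChatCompletionAudioDelta] = [
    {"id": "audio_abc123", "data": "UklGRg==", "transcript": "Yes", "expires_at": 1729100000},
    {"data": "ZmFrZQ==", "transcript": "."},
]

# the same fold the chunk builder performs: concatenate data/transcript fragments,
# keep the last id and expires_at seen
transcript = "".join(d.get("transcript", "") for d in deltas)
assert transcript == "Yes."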
litellm/types/utils.py
@@ -321,7 +321,11 @@ class ChatCompletionMessageToolCall(OpenAIObject):
             setattr(self, key, value)


-class ChatCompletionAudioResponse(OpenAIObject):
+from openai.types.chat.chat_completion_audio import ChatCompletionAudio
+
+
+class ChatCompletionAudioResponse(ChatCompletionAudio):
+
     def __init__(
         self,
         data: str,
@@ -330,27 +334,13 @@ class ChatCompletionAudioResponse(OpenAIObject):
         id: Optional[str] = None,
         **params,
     ):
-        super(ChatCompletionAudioResponse, self).__init__(**params)
         if id is not None:
-            self.id = id
+            id = id
         else:
-            self.id = f"{uuid.uuid4()}"
-        """Unique identifier for this audio response."""
-
-        self.data = data
-        """
-        Base64 encoded audio bytes generated by the model, in the format specified in
-        the request.
-        """
-
-        self.expires_at = expires_at
-        """
-        The Unix timestamp (in seconds) for when this audio response will no longer be
-        accessible on the server for use in multi-turn conversations.
-        """
-
-        self.transcript = transcript
-        """Transcript of the audio generated by the model."""
+            id = f"{uuid.uuid4()}"
+
+        super(ChatCompletionAudioResponse, self).__init__(
+            data=data, expires_at=expires_at, transcript=transcript, id=id, **params
+        )

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
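With ChatCompletionAudioResponse now inheriting from openai's ChatCompletionAudio pydantic model, the field definitions and docstrings come from the upstream type; litellm only backfills a uuid id when none is supplied. A construction sketch (field values illustrative):

from litellm.types.utils import ChatCompletionAudioResponse

audio = ChatCompletionAudioResponse(
    data="UklGRg==",        # base64-encoded audio bytes
    expires_at=1729100000,  # unix timestamp supplied by the provider
    transcript="yes",
    # id omitted on purpose: a uuid4 string is generated in that case
)
assert audio.id
print(audio.transcript, audio.expires_at)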
litellm/utils.py
@@ -7573,7 +7573,7 @@ class CustomStreamWrapper:
                 original_chunk = response_obj.get("original_chunk", None)
                 model_response.id = original_chunk.id
                 self.response_id = original_chunk.id
-                if len(original_chunk.choices) > 0:
+                if original_chunk.choices and len(original_chunk.choices) > 0:
                     delta = original_chunk.choices[0].delta
                     if delta is not None and (
                         delta.function_call is not None or delta.tool_calls is not None
@@ -2365,3 +2365,32 @@ async def test_caching_kwargs_input(sync_mode):
     else:
         input["original_function"] = acompletion
         await llm_caching_handler.async_set_cache(**input)
+
+
+@pytest.mark.skip(reason="audio caching not supported yet")
+@pytest.mark.parametrize("stream", [False])  # True,
+@pytest.mark.asyncio()
+async def test_audio_caching(stream):
+    litellm.cache = Cache(type="local")
+
+    ## CALL 1 - no cache hit
+    completion = await litellm.acompletion(
+        model="gpt-4o-audio-preview",
+        modalities=["text", "audio"],
+        audio={"voice": "alloy", "format": "pcm16"},
+        messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
+        stream=stream,
+    )
+
+    assert "cache_hit" not in completion._hidden_params
+
+    ## CALL 2 - cache hit
+    completion = await litellm.acompletion(
+        model="gpt-4o-audio-preview",
+        modalities=["text", "audio"],
+        audio={"voice": "alloy", "format": "pcm16"},
+        messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
+        stream=stream,
+    )
+
+    assert "cache_hit" in completion._hidden_params
test_custom_callback_input.py
@@ -1267,6 +1267,100 @@ def test_standard_logging_payload(model, turn_off_message_logging):
         assert "redacted-by-litellm" == slobject["response"]


+@pytest.mark.parametrize(
+    "stream",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "turn_off_message_logging",
+    [
+        True,
+    ],
+)  # False
+def test_standard_logging_payload_audio(turn_off_message_logging, stream):
+    """
+    Ensure valid standard_logging_payload is passed for logging calls to s3
+
+    Motivation: provide a standard set of things that are logged to s3/gcs/future integrations across all llm calls
+    """
+    from litellm.types.utils import StandardLoggingPayload
+
+    # sync completion
+    customHandler = CompletionCustomHandler()
+    litellm.callbacks = [customHandler]
+
+    litellm.turn_off_message_logging = turn_off_message_logging
+
+    with patch.object(
+        customHandler, "log_success_event", new=MagicMock()
+    ) as mock_client:
+        response = litellm.completion(
+            model="gpt-4o-audio-preview",
+            modalities=["text", "audio"],
+            audio={"voice": "alloy", "format": "pcm16"},
+            messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
+            stream=stream,
+        )
+
+        if stream:
+            for chunk in response:
+                continue
+
+        time.sleep(2)
+        mock_client.assert_called_once()
+
+        print(
+            f"mock_client_post.call_args: {mock_client.call_args.kwargs['kwargs'].keys()}"
+        )
+        assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
+        assert (
+            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
+            is not None
+        )
+
+        print(
+            "Standard Logging Object - {}".format(
+                mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
+            )
+        )
+
+        keys_list = list(StandardLoggingPayload.__annotations__.keys())
+
+        for k in keys_list:
+            assert (
+                k in mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
+            )
+
+        ## json serializable
+        json_str_payload = json.dumps(
+            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
+        )
+        json.loads(json_str_payload)
+
+        ## response cost
+        assert (
+            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
+                "response_cost"
+            ]
+            > 0
+        )
+        assert (
+            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
+                "model_map_information"
+            ]["model_map_value"]
+            is not None
+        )
+
+        ## turn off message logging
+        slobject: StandardLoggingPayload = mock_client.call_args.kwargs["kwargs"][
+            "standard_logging_object"
+        ]
+        if turn_off_message_logging:
+            print("checks redacted-by-litellm")
+            assert "redacted-by-litellm" == slobject["messages"][0]["content"]
+            assert "redacted-by-litellm" == slobject["response"]
+
+
 @pytest.mark.skip(reason="Works locally. Flaky on ci/cd")
 def test_aaastandard_logging_payload_cache_hit():
     from litellm.types.utils import StandardLoggingPayload
@@ -6,6 +6,17 @@ import traceback

 import pytest
 from typing import List
+from litellm.types.utils import StreamingChoices, ChatCompletionAudioResponse
+
+
+def check_non_streaming_response(completion):
+    assert completion.choices[0].message.audio is not None, "Audio response is missing"
+    print("audio", completion.choices[0].message.audio)
+    assert isinstance(
+        completion.choices[0].message.audio, ChatCompletionAudioResponse
+    ), "Invalid audio response type"
+    assert len(completion.choices[0].message.audio.data) > 0, "Audio data is empty"
+
+
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -656,12 +667,60 @@ def test_stream_chunk_builder_openai_prompt_caching():
     response = stream_chunk_builder(chunks=chunks)
     print(f"response: {response}")
     print(f"response usage: {response.usage}")
-    for k, v in usage_obj.model_dump().items():
+    for k, v in usage_obj.model_dump(exclude_none=True).items():
         print(k, v)
         response_usage_value = getattr(response.usage, k)  # type: ignore
         print(f"response_usage_value: {response_usage_value}")
         print(f"type: {type(response_usage_value)}")
         if isinstance(response_usage_value, BaseModel):
-            assert response_usage_value.model_dump() == v
+            assert response_usage_value.model_dump(exclude_none=True) == v
+        else:
+            assert response_usage_value == v
+
+
+def test_stream_chunk_builder_openai_audio_output_usage():
+    from pydantic import BaseModel
+    from openai import OpenAI
+    from typing import Optional
+
+    client = OpenAI(
+        # This is the default and can be omitted
+        api_key=os.getenv("OPENAI_API_KEY"),
+    )
+
+    completion = client.chat.completions.create(
+        model="gpt-4o-audio-preview",
+        modalities=["text", "audio"],
+        audio={"voice": "alloy", "format": "pcm16"},
+        messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
+        stream=True,
+        stream_options={"include_usage": True},
+    )
+
+    chunks = []
+    for chunk in completion:
+        chunks.append(litellm.ModelResponse(**chunk.model_dump(), stream=True))
+
+    usage_obj: Optional[litellm.Usage] = None
+
+    for index, chunk in enumerate(chunks):
+        if hasattr(chunk, "usage"):
+            usage_obj = chunk.usage
+            print(f"chunk usage: {chunk.usage}")
+            print(f"index: {index}")
+            print(f"len chunks: {len(chunks)}")
+
+    print(f"usage_obj: {usage_obj}")
+    response = stream_chunk_builder(chunks=chunks)
+    print(f"response usage: {response.usage}")
+    check_non_streaming_response(response)
+    print(f"response: {response}")
+    for k, v in usage_obj.model_dump(exclude_none=True).items():
+        print(k, v)
+        response_usage_value = getattr(response.usage, k)  # type: ignore
+        print(f"response_usage_value: {response_usage_value}")
+        print(f"type: {type(response_usage_value)}")
+        if isinstance(response_usage_value, BaseModel):
+            assert response_usage_value.model_dump(exclude_none=True) == v
         else:
             assert response_usage_value == v