fix(utils.py): return openai streaming prompt caching tokens (#6051)

* fix(utils.py): return openai streaming prompt caching tokens

Closes https://github.com/BerriAI/litellm/issues/6038

* fix(main.py): fix error in finish_reason updates
Krish Dholakia 2024-10-03 22:20:13 -04:00 committed by GitHub
parent 04ae095860
commit 09f0c09ba4
5 changed files with 91 additions and 10 deletions
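
Context for the diff below: when `stream_options={"include_usage": True}` is set, OpenAI ends the stream with a usage-only chunk whose `choices` list is empty and whose `usage` block nests `prompt_tokens_details` / `completion_tokens_details`. Before this commit, `stream_chunk_builder` both tripped over the empty `choices` list and dropped those nested detail objects. A rough sketch of that final chunk (field values illustrative, not taken from this commit):

```python
# Approximate shape of the final streaming chunk OpenAI sends when
# stream_options={"include_usage": True} is requested.
final_usage_chunk = {
    "id": "chatcmpl-abc123",
    "object": "chat.completion.chunk",
    "model": "gpt-4o-mini",
    "choices": [],  # no delta here; several guards in the diff exist for this
    "usage": {
        "prompt_tokens": 2048,
        "completion_tokens": 12,
        "total_tokens": 2060,
        "prompt_tokens_details": {"cached_tokens": 1024},  # prompt caching info
        "completion_tokens_details": {"reasoning_tokens": 0},
    },
}
```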


@@ -144,8 +144,10 @@ from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import (
     AdapterCompletionStreamWrapper,
     ChatCompletionMessageToolCall,
+    CompletionTokensDetails,
     FileTypes,
     HiddenParams,
+    PromptTokensDetails,
     all_litellm_params,
 )
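
The two new imports are the pydantic detail models the builder will now aggregate. Per the types hunk further down, they come from `openai.types.completion_usage` and are re-exported through `litellm.types.utils`. A minimal sketch of their shape (values illustrative; the exact field set depends on the installed openai SDK version):

```python
from litellm.types.utils import CompletionTokensDetails, PromptTokensDetails

# tokens that OpenAI served from its prompt cache for this request
prompt_details = PromptTokensDetails(cached_tokens=1024)
# e.g. o1-style hidden reasoning tokens counted inside completion_tokens
completion_details = CompletionTokensDetails(reasoning_tokens=0)

print(prompt_details.model_dump())      # includes {'cached_tokens': 1024, ...}
print(completion_details.model_dump())  # includes {'reasoning_tokens': 0, ...}
```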
@@ -5481,7 +5483,13 @@ def stream_chunk_builder(
         chunks=chunks, messages=messages
     )
     role = chunks[0]["choices"][0]["delta"]["role"]
-    finish_reason = chunks[-1]["choices"][0]["finish_reason"]
+    finish_reason = "stop"
+    for chunk in chunks:
+        if "choices" in chunk and len(chunk["choices"]) > 0:
+            if hasattr(chunk["choices"][0], "finish_reason"):
+                finish_reason = chunk["choices"][0].finish_reason
+            elif "finish_reason" in chunk["choices"][0]:
+                finish_reason = chunk["choices"][0]["finish_reason"]

     # Initialize the response dictionary
     response = {
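
Why the `finish_reason` lookup changed: the old code read it from the last chunk, but with `include_usage` the last chunk has an empty `choices` list. A toy illustration of the scan-all-chunks approach (simplified to dict chunks and only keeping non-empty values, whereas the real loop also handles `StreamingChoices` objects via `hasattr`):

```python
def pick_finish_reason(chunks: list) -> str:
    finish_reason = "stop"  # sensible default when no chunk reports one
    for chunk in chunks:
        if "choices" in chunk and len(chunk["choices"]) > 0:
            if chunk["choices"][0].get("finish_reason"):
                finish_reason = chunk["choices"][0]["finish_reason"]
    return finish_reason


chunks = [
    {"choices": [{"delta": {"content": "Hi"}, "finish_reason": None}]},
    {"choices": [{"delta": {}, "finish_reason": "length"}]},
    # final usage-only chunk: empty choices, so it must not be indexed
    {"choices": [], "usage": {"prompt_tokens": 5, "completion_tokens": 1, "total_tokens": 6}},
]
assert pick_finish_reason(chunks) == "length"
```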
@@ -5512,7 +5520,8 @@ def stream_chunk_builder(
     tool_call_chunks = [
         chunk
        for chunk in chunks
-        if "tool_calls" in chunk["choices"][0]["delta"]
+        if len(chunk["choices"]) > 0
+        and "tool_calls" in chunk["choices"][0]["delta"]
         and chunk["choices"][0]["delta"]["tool_calls"] is not None
     ]
@@ -5590,7 +5599,8 @@ def stream_chunk_builder(
     function_call_chunks = [
         chunk
         for chunk in chunks
-        if "function_call" in chunk["choices"][0]["delta"]
+        if len(chunk["choices"]) > 0
+        and "function_call" in chunk["choices"][0]["delta"]
         and chunk["choices"][0]["delta"]["function_call"] is not None
     ]
@@ -5625,7 +5635,8 @@ def stream_chunk_builder(
     content_chunks = [
         chunk
         for chunk in chunks
-        if "content" in chunk["choices"][0]["delta"]
+        if len(chunk["choices"]) > 0
+        and "content" in chunk["choices"][0]["delta"]
         and chunk["choices"][0]["delta"]["content"] is not None
     ]
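
The same `len(chunk["choices"]) > 0` guard is threaded through the `tool_calls`, `function_call`, and `content` filters above, because indexing `choices[0]` on the usage-only chunk raises `IndexError`. A condensed, self-contained illustration of the guarded comprehension (toy dict chunks):

```python
chunks = [
    {"choices": [{"delta": {"content": "Hello"}}]},
    {"choices": [{"delta": {"content": None, "tool_calls": None}}]},
    {"choices": [], "usage": {"prompt_tokens": 5}},  # usage-only chunk from include_usage
]

content_chunks = [
    chunk
    for chunk in chunks
    if len(chunk["choices"]) > 0                       # new guard: skip usage-only chunks
    and "content" in chunk["choices"][0]["delta"]
    and chunk["choices"][0]["delta"]["content"] is not None
]
assert len(content_chunks) == 1
```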
@@ -5657,6 +5668,8 @@ def stream_chunk_builder(
     ## anthropic prompt caching information ##
     cache_creation_input_tokens: Optional[int] = None
     cache_read_input_tokens: Optional[int] = None
+    completion_tokens_details: Optional[CompletionTokensDetails] = None
+    prompt_tokens_details: Optional[PromptTokensDetails] = None
     for chunk in chunks:
         usage_chunk: Optional[Usage] = None
         if "usage" in chunk:
@@ -5674,6 +5687,26 @@ def stream_chunk_builder(
                )
            if "cache_read_input_tokens" in usage_chunk:
                cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
+            if hasattr(usage_chunk, "completion_tokens_details"):
+                if isinstance(usage_chunk.completion_tokens_details, dict):
+                    completion_tokens_details = CompletionTokensDetails(
+                        **usage_chunk.completion_tokens_details
+                    )
+                elif isinstance(
+                    usage_chunk.completion_tokens_details, CompletionTokensDetails
+                ):
+                    completion_tokens_details = (
+                        usage_chunk.completion_tokens_details
+                    )
+            if hasattr(usage_chunk, "prompt_tokens_details"):
+                if isinstance(usage_chunk.prompt_tokens_details, dict):
+                    prompt_tokens_details = PromptTokensDetails(
+                        **usage_chunk.prompt_tokens_details
+                    )
+                elif isinstance(
+                    usage_chunk.prompt_tokens_details, PromptTokensDetails
+                ):
+                    prompt_tokens_details = usage_chunk.prompt_tokens_details

     try:
         response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
@@ -5700,6 +5733,11 @@ def stream_chunk_builder(
     if cache_read_input_tokens is not None:
         response["usage"]["cache_read_input_tokens"] = cache_read_input_tokens
+    if completion_tokens_details is not None:
+        response["usage"]["completion_tokens_details"] = completion_tokens_details
+    if prompt_tokens_details is not None:
+        response["usage"]["prompt_tokens_details"] = prompt_tokens_details

     return convert_to_model_response_object(
         response_object=response,
         model_response_object=model_response,
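
The aggregated details are attached only when something was found, so providers that never emit them keep the previous usage shape. An illustrative sketch of `response["usage"]` right before it is handed to `convert_to_model_response_object` (plain dicts stand in for the pydantic models; values made up):

```python
response = {"usage": {"prompt_tokens": 2048, "completion_tokens": 20, "total_tokens": 2068}}

prompt_tokens_details = {"cached_tokens": 1024}      # stand-in for PromptTokensDetails
completion_tokens_details = {"reasoning_tokens": 0}  # stand-in for CompletionTokensDetails

if completion_tokens_details is not None:
    response["usage"]["completion_tokens_details"] = completion_tokens_details
if prompt_tokens_details is not None:
    response["usage"]["prompt_tokens_details"] = prompt_tokens_details

# Providers that never set these fields skip both assignments, so their
# usage dict keeps its previous shape.
print(response["usage"])
```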


@@ -1,5 +1,5 @@
 model_list:
-  - model_name: gpt-4o-realtime-audio
+  - model_name: gpt-4o
     litellm_params:
       model: azure/gpt-4o-realtime-preview
       api_key: os.environ/AZURE_SWEDEN_API_KEY


@@ -11,7 +11,7 @@ from openai.types.completion_usage import (
     CompletionUsage,
     PromptTokensDetails,
 )
-from pydantic import ConfigDict, PrivateAttr
+from pydantic import BaseModel, ConfigDict, PrivateAttr
 from typing_extensions import Callable, Dict, Required, TypedDict, override

 from ..litellm_core_utils.core_helpers import map_finish_reason
@@ -677,6 +677,8 @@ class ModelResponse(OpenAIObject):
                     _new_choice = choice
                 elif isinstance(choice, dict):
                     _new_choice = StreamingChoices(**choice)
+                elif isinstance(choice, BaseModel):
+                    _new_choice = StreamingChoices(**choice.model_dump())
                 new_choices.append(_new_choice)
             choices = new_choices
         else:
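
The new `BaseModel` branch covers choices that are pydantic models other than litellm's own `StreamingChoices` (for example a chunk choice coming from another SDK); round-tripping through `model_dump()` lets the existing dict path do the conversion. A small sketch with a hypothetical stand-in model:

```python
from typing import Optional

from pydantic import BaseModel

from litellm.types.utils import StreamingChoices


class SomeOtherChoice(BaseModel):
    """Hypothetical stand-in for a foreign pydantic choice object."""

    index: int = 0
    finish_reason: Optional[str] = None


foreign = SomeOtherChoice(index=0, finish_reason="stop")
converted = StreamingChoices(**foreign.model_dump())
print(type(converted).__name__, converted.finish_reason)  # StreamingChoices stop
```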


@@ -7813,9 +7813,7 @@ class CustomStreamWrapper:
                     )
                 elif isinstance(response_obj["usage"], BaseModel):
                     model_response.usage = litellm.Usage(
-                        prompt_tokens=response_obj["usage"].prompt_tokens,
-                        completion_tokens=response_obj["usage"].completion_tokens,
-                        total_tokens=response_obj["usage"].total_tokens,
+                        **response_obj["usage"].model_dump()
                     )

             model_response.model = self.model
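
Spreading `model_dump()` instead of hand-copying the three counters is what lets `prompt_tokens_details` and `completion_tokens_details` survive into `litellm.Usage`, which is exactly what this hunk relies on. A hedged sketch using the openai SDK's `CompletionUsage` as the source object (values illustrative):

```python
import litellm
from openai.types.completion_usage import CompletionUsage, PromptTokensDetails

sdk_usage = CompletionUsage(
    prompt_tokens=2048,
    completion_tokens=10,
    total_tokens=2058,
    prompt_tokens_details=PromptTokensDetails(cached_tokens=1024),
)

# Before: only prompt/completion/total tokens were copied, so the details were dropped.
# After: everything the usage object carries is forwarded.
usage = litellm.Usage(**sdk_usage.model_dump())
print(usage.prompt_tokens_details)  # the cached_tokens=1024 detail should survive the copy
```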


@@ -5,6 +5,7 @@ import time
 import traceback

 import pytest
+from typing import List

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -12,7 +13,6 @@ sys.path.insert(
 import os

 import dotenv
-import pytest
 from openai import OpenAI

 import litellm
@@ -622,3 +622,46 @@ def test_stream_chunk_builder_multiple_tool_calls():
     assert (
         expected_response.choices == response.choices
     ), "\nGot={}\n, Expected={}\n".format(response.choices, expected_response.choices)
+
+
+def test_stream_chunk_builder_openai_prompt_caching():
+    from openai import OpenAI
+    from pydantic import BaseModel
+
+    client = OpenAI(
+        # This is the default and can be omitted
+        api_key=os.getenv("OPENAI_API_KEY"),
+    )
+
+    chat_completion = client.chat.completions.create(
+        messages=[
+            {
+                "role": "user",
+                "content": "Say this is a test",
+            }
+        ],
+        model="gpt-3.5-turbo",
+        stream=True,
+        stream_options={"include_usage": True},
+    )
+    chunks: List[litellm.ModelResponse] = []
+    usage_obj = None
+    for chunk in chat_completion:
+        chunks.append(litellm.ModelResponse(**chunk.model_dump(), stream=True))
+
+    print(f"chunks: {chunks}")
+
+    usage_obj: litellm.Usage = chunks[-1].usage  # type: ignore
+
+    response = stream_chunk_builder(chunks=chunks)
+    print(f"response: {response}")
+    print(f"response usage: {response.usage}")
+    for k, v in usage_obj.model_dump().items():
+        print(k, v)
+        response_usage_value = getattr(response.usage, k)  # type: ignore
+        print(f"response_usage_value: {response_usage_value}")
+        print(f"type: {type(response_usage_value)}")
+        if isinstance(response_usage_value, BaseModel):
+            assert response_usage_value.model_dump() == v
+        else:
+            assert response_usage_value == v
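
And the end-user view the test exercises indirectly: stream through litellm, rebuild the response, and read cached-token accounting off the aggregated usage. A hedged sketch assuming `OPENAI_API_KEY` is set; `gpt-4o-mini` is used only as an example of a model family with prompt caching, `cached_tokens` stays zero unless the same long (>= 1024-token) prompt was sent recently, and `prompt_tokens_details` can be None for models that do not report it:

```python
import litellm

stream = litellm.completion(
    model="gpt-4o-mini",  # example model; any OpenAI chat model works for the mechanics
    messages=[{"role": "user", "content": "Say this is a test"}],
    stream=True,
    stream_options={"include_usage": True},
)
chunks = list(stream)

rebuilt = litellm.stream_chunk_builder(chunks=chunks)

details = rebuilt.usage.prompt_tokens_details  # type: ignore[union-attr]
cached = (details.cached_tokens or 0) if details is not None else 0
print(f"cached prompt tokens: {cached}")
```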