fix(utils.py): return openai streaming prompt caching tokens (#6051)

* fix(utils.py): return openai streaming prompt caching tokens

Closes https://github.com/BerriAI/litellm/issues/6038

* fix(main.py): fix error in finish_reason updates
Krish Dholakia 2024-10-03 22:20:13 -04:00 committed by GitHub
parent 04ae095860
commit 09f0c09ba4
5 changed files with 91 additions and 10 deletions
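
For context: when a streaming request sets stream_options={"include_usage": True}, OpenAI sends one final chunk whose usage block carries prompt_tokens_details (including cached_tokens). This fix propagates those nested details through litellm's stream chunk builder instead of dropping them. Below is a minimal sketch of where the field shows up, assuming a valid OPENAI_API_KEY; the model name and prompt are illustrative, and short prompts will report 0 cached tokens since OpenAI only caches sufficiently long prompts.

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say this is a test"}],
    stream=True,
    stream_options={"include_usage": True},  # request the final usage chunk
)

for chunk in stream:
    # every chunk except the last has usage=None; the last carries the totals
    if chunk.usage is not None:
        details = chunk.usage.prompt_tokens_details
        cached = (details.cached_tokens or 0) if details else 0
        print("cached prompt tokens:", cached)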

@@ -5,6 +5,7 @@ import time
import traceback
import pytest
from typing import List
sys.path.insert(
    0, os.path.abspath("../..")
@@ -12,7 +13,6 @@ sys.path.insert(
import os
import dotenv
import pytest
from openai import OpenAI
import litellm
@@ -622,3 +622,46 @@ def test_stream_chunk_builder_multiple_tool_calls():
    assert (
        expected_response.choices == response.choices
    ), "\nGot={}\n, Expected={}\n".format(response.choices, expected_response.choices)


def test_stream_chunk_builder_openai_prompt_caching():
    from openai import OpenAI
    from pydantic import BaseModel

    client = OpenAI(
        # This is the default and can be omitted
        api_key=os.getenv("OPENAI_API_KEY"),
    )

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        model="gpt-3.5-turbo",
        stream=True,
        stream_options={"include_usage": True},
    )

    chunks: List[litellm.ModelResponse] = []
    usage_obj = None
    for chunk in chat_completion:
        chunks.append(litellm.ModelResponse(**chunk.model_dump(), stream=True))

    print(f"chunks: {chunks}")

    usage_obj: litellm.Usage = chunks[-1].usage  # type: ignore

    response = stream_chunk_builder(chunks=chunks)
    print(f"response: {response}")
    print(f"response usage: {response.usage}")
    for k, v in usage_obj.model_dump().items():
        print(k, v)
        response_usage_value = getattr(response.usage, k)  # type: ignore
        print(f"response_usage_value: {response_usage_value}")
        print(f"type: {type(response_usage_value)}")
        if isinstance(response_usage_value, BaseModel):
            assert response_usage_value.model_dump() == v
        else:
            assert response_usage_value == v
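
The assertion loop above compares every usage field, including nested pydantic objects such as prompt_tokens_details, between the final streamed chunk and the rebuilt response. Conceptually, the fix amounts to carrying the final chunk's usage over verbatim. The sketch below illustrates that idea only; combine_usage is a hypothetical helper, not litellm's actual implementation, which lives in utils.py and handles more cases.

from typing import List, Optional

import litellm


def combine_usage(chunks: List[litellm.ModelResponse]) -> Optional[litellm.Usage]:
    # OpenAI streams report usage totals (including prompt_tokens_details
    # with cached_tokens) only on the final chunk, so the rebuilt response
    # can adopt that object unchanged.
    for chunk in reversed(chunks):
        usage = getattr(chunk, "usage", None)
        if usage is not None:
            return usage
    return None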