forked from phoenix/litellm-mirror
fix(utils.py): return openai streaming prompt caching tokens (#6051)
* fix(utils.py): return openai streaming prompt caching tokens

  Closes https://github.com/BerriAI/litellm/issues/6038

* fix(main.py): fix error in finish_reason updates
parent 04ae095860
commit 09f0c09ba4
5 changed files with 91 additions and 10 deletions
@@ -5,6 +5,7 @@ import time
import traceback

import pytest
from typing import List

sys.path.insert(
    0, os.path.abspath("../..")
@@ -12,7 +13,6 @@ sys.path.insert(
import os

import dotenv
import pytest
from openai import OpenAI

import litellm
@@ -622,3 +622,46 @@ def test_stream_chunk_builder_multiple_tool_calls():
    assert (
        expected_response.choices == response.choices
    ), "\nGot={}\n, Expected={}\n".format(response.choices, expected_response.choices)


def test_stream_chunk_builder_openai_prompt_caching():
    from openai import OpenAI
    from pydantic import BaseModel

    client = OpenAI(
        # This is the default and can be omitted
        api_key=os.getenv("OPENAI_API_KEY"),
    )

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        model="gpt-3.5-turbo",
        stream=True,
        stream_options={"include_usage": True},
    )
    chunks: List[litellm.ModelResponse] = []
    usage_obj = None
    for chunk in chat_completion:
        chunks.append(litellm.ModelResponse(**chunk.model_dump(), stream=True))

    print(f"chunks: {chunks}")

    usage_obj: litellm.Usage = chunks[-1].usage  # type: ignore

    response = stream_chunk_builder(chunks=chunks)
    print(f"response: {response}")
    print(f"response usage: {response.usage}")
    for k, v in usage_obj.model_dump().items():
        print(k, v)
        response_usage_value = getattr(response.usage, k)  # type: ignore
        print(f"response_usage_value: {response_usage_value}")
        print(f"type: {type(response_usage_value)}")
        if isinstance(response_usage_value, BaseModel):
            assert response_usage_value.model_dump() == v
        else:
            assert response_usage_value == v
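The test above rebuilds a response from the streamed chunks and checks that every field of the final chunk's usage object, including any nested prompt-caching details, survives aggregation. As a rough illustration of what the fix enables for callers going through litellm directly, here is a minimal sketch, not part of the commit, assuming litellm.completion forwards stream_options and that the aggregated Usage mirrors OpenAI's prompt_tokens_details.cached_tokens field; the model name and prompt are placeholders, and short prompts will simply report 0 cached tokens.

# Sketch only (assumptions noted above): aggregate streamed chunks and
# inspect the rebuilt usage, including the prompt-caching detail.
import litellm

messages = [{"role": "user", "content": "Say this is a test"}]

chunks = []
for chunk in litellm.completion(
    model="gpt-4o-mini",  # placeholder model
    messages=messages,
    stream=True,
    stream_options={"include_usage": True},  # final chunk carries the usage block
):
    chunks.append(chunk)

response = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
usage = response.usage

print("prompt tokens:", usage.prompt_tokens)
details = getattr(usage, "prompt_tokens_details", None)  # may be absent or None
if details is not None:
    # Expect 0 for short prompts; providers only cache above a size threshold.
    print("cached prompt tokens:", details.cached_tokens)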