fix(main.py): fix stream_chunk_builder usage calc

Closes https://github.com/BerriAI/litellm/issues/4496
Krrish Dholakia 2024-07-06 14:52:38 -07:00
parent 6cce966139
commit f89632f5ac
3 changed files with 57 additions and 12 deletions


@@ -5022,10 +5022,9 @@ def stream_chunk_builder(
     for chunk in chunks:
         if "usage" in chunk:
             if "prompt_tokens" in chunk["usage"]:
-                prompt_tokens += chunk["usage"].get("prompt_tokens", 0) or 0
+                prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
             if "completion_tokens" in chunk["usage"]:
-                completion_tokens += chunk["usage"].get("completion_tokens", 0) or 0
+                completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
     try:
         response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
             model=model, messages=messages

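Note (not part of the commit): a minimal sketch of why the `+=` accumulation over-counted, assuming a provider that reports its usage counts in more than one streamed chunk (cumulative totals repeated per chunk); the chunk dicts and token numbers below are made up for illustration.

# Illustrative only: each chunk repeats the cumulative usage totals (assumption).
chunks = [
    {"usage": {"prompt_tokens": 30, "completion_tokens": 5}},
    {"usage": {"prompt_tokens": 30, "completion_tokens": 12}},
    {"usage": {"prompt_tokens": 30, "completion_tokens": 20}},
]

buggy_prompt_tokens = 0  # old behavior: sums every chunk's count
fixed_prompt_tokens = 0  # new behavior: keeps the latest reported count
for chunk in chunks:
    if "usage" in chunk and "prompt_tokens" in chunk["usage"]:
        buggy_prompt_tokens += chunk["usage"].get("prompt_tokens", 0) or 0
        fixed_prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0

assert buggy_prompt_tokens == 90  # 3 chunks x 30, over-counted
assert fixed_prompt_tokens == 30  # matches the actual prompt size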

@@ -2,11 +2,8 @@ model_list:
   - model_name: "*"
     litellm_params:
       model: "openai/*"
-      mock_response: "Hello world!"
 litellm_settings:
   success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
 general_settings:
   alerting: ["slack"]


@@ -1,15 +1,22 @@
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
 import pytest
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-from litellm import completion, stream_chunk_builder
-import litellm
-import os, dotenv
-from openai import OpenAI
+import os
+import dotenv
 import pytest
+from openai import OpenAI
+import litellm
+from litellm import completion, stream_chunk_builder
 dotenv.load_dotenv()
@@ -147,3 +154,45 @@ def test_stream_chunk_builder_litellm_tool_call_regular_message():
 # test_stream_chunk_builder_litellm_tool_call_regular_message()
+
+
+def test_stream_chunk_builder_litellm_usage_chunks():
+    """
+    Checks if stream_chunk_builder is able to correctly rebuild with given metadata from streaming chunks
+    """
+    messages = [
+        {"role": "user", "content": "Tell me the funniest joke you know."},
+        {
+            "role": "assistant",
+            "content": "Why did the chicken cross the road?\nYou will not guess this one I bet\n",
+        },
+        {"role": "user", "content": "I do not know, why?"},
+        {"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
+        {"role": "user", "content": "\nI am waiting...\n\n...\n"},
+    ]
+    # make a regular gemini call
+    response = completion(
+        model="gemini/gemini-1.5-flash",
+        messages=messages,
+    )
+    usage: litellm.Usage = response.usage
+    gemini_pt = usage.prompt_tokens
+
+    # make a streaming gemini call
+    response = completion(
+        model="gemini/gemini-1.5-flash",
+        messages=messages,
+        stream=True,
+        complete_response=True,
+        stream_options={"include_usage": True},
+    )
+    usage: litellm.Usage = response.usage
+    stream_rebuilt_pt = usage.prompt_tokens
+
+    # assert prompt tokens are the same
+    assert gemini_pt == stream_rebuilt_pt
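
Usage note (a sketch, not from this commit): the same rebuild path can be exercised by collecting the streamed chunks manually and passing them to stream_chunk_builder; the model name is an assumption and a configured Gemini API key is required.

from litellm import completion, stream_chunk_builder

messages = [{"role": "user", "content": "Say hi in one word."}]

# Collect the raw streamed chunks ourselves instead of using complete_response=True.
chunks = []
for chunk in completion(
    model="gemini/gemini-1.5-flash",  # assumed model; any streaming-capable provider should work
    messages=messages,
    stream=True,
    stream_options={"include_usage": True},
):
    chunks.append(chunk)

rebuilt = stream_chunk_builder(chunks, messages=messages)
# With this fix, rebuilt.usage.prompt_tokens should match the non-streaming count.
print(rebuilt.usage)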