fix(utils.py): fix anthropic streaming usage calculation

Fixes https://github.com/BerriAI/litellm/issues/4965
This commit is contained in:
Krrish Dholakia 2024-08-01 14:45:54 -07:00
parent aad0bbb08c
commit ca0a0bed46
3 changed files with 59 additions and 29 deletions

View file

@ -5191,17 +5191,24 @@ def stream_chunk_builder(
prompt_tokens = 0
completion_tokens = 0
for chunk in chunks:
usage_chunk: Optional[Usage] = None
if "usage" in chunk:
if "prompt_tokens" in chunk["usage"]:
prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
if "completion_tokens" in chunk["usage"]:
completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
usage_chunk = chunk.usage
elif hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
usage_chunk = chunk._hidden_params["usage"]
if usage_chunk is not None:
if "prompt_tokens" in usage_chunk:
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
if "completion_tokens" in usage_chunk:
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
try:
response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
model=model, messages=messages
)
except: # don't allow this failing to block a complete streaming response from being returned
print_verbose(f"token_counter failed, assuming prompt tokens is 0")
except (
Exception
): # don't allow this failing to block a complete streaming response from being returned
print_verbose("token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = completion_tokens or token_counter(
model=model,

View file

@ -3096,6 +3096,7 @@ def test_completion_claude_3_function_call_with_streaming():
elif idx == 1 and chunk.choices[0].finish_reason is None:
validate_second_streaming_function_calling_chunk(chunk=chunk)
elif chunk.choices[0].finish_reason is not None: # last chunk
assert "usage" in chunk._hidden_params
validate_final_streaming_function_calling_chunk(chunk=chunk)
idx += 1
# raise Exception("it worked!")

View file

@ -8378,6 +8378,28 @@ def get_secret(
######## Streaming Class ############################
# wraps the completion stream to return the correct format for the model
# replicate/anthropic/cohere
def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
"""Assume most recent usage chunk has total usage uptil then."""
prompt_tokens: int = 0
completion_tokens: int = 0
for chunk in chunks:
if "usage" in chunk:
if "prompt_tokens" in chunk["usage"]:
prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
if "completion_tokens" in chunk["usage"]:
completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
returned_usage_chunk = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return returned_usage_chunk
class CustomStreamWrapper:
def __init__(
self,
@ -9267,7 +9289,9 @@ class CustomStreamWrapper:
verbose_logger.debug(traceback.format_exc())
return ""
def model_response_creator(self, chunk: Optional[dict] = None):
def model_response_creator(
self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
):
_model = self.model
_received_llm_provider = self.custom_llm_provider
_logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None) # type: ignore
@ -9281,6 +9305,7 @@ class CustomStreamWrapper:
else:
# pop model keyword
chunk.pop("model", None)
model_response = ModelResponse(
stream=True, model=_model, stream_options=self.stream_options, **chunk
)
@ -9290,6 +9315,8 @@ class CustomStreamWrapper:
self.response_id = model_response.id # type: ignore
if self.system_fingerprint is not None:
model_response.system_fingerprint = self.system_fingerprint
if hidden_params is not None:
model_response._hidden_params = hidden_params
model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
model_response._hidden_params["created_at"] = time.time()
@ -9344,11 +9371,7 @@ class CustomStreamWrapper:
"finish_reason"
]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and anthropic_response_obj["usage"] is not None
):
if anthropic_response_obj["usage"] is not None:
model_response.usage = litellm.Usage(
prompt_tokens=anthropic_response_obj["usage"]["prompt_tokens"],
completion_tokens=anthropic_response_obj["usage"][
@ -9884,19 +9907,6 @@ class CustomStreamWrapper:
## RETURN ARG
if (
"content" in completion_obj
and isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) == 0
and hasattr(model_response, "usage")
and hasattr(model_response.usage, "prompt_tokens")
):
if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
print_verbose(f"returning model_response: {model_response}")
return model_response
elif (
"content" in completion_obj
and (
isinstance(completion_obj["content"], str)
@ -9991,6 +10001,7 @@ class CustomStreamWrapper:
model_response.choices[0].finish_reason = map_finish_reason(
finish_reason=self.received_finish_reason
) # ensure consistent output to openai
self.sent_last_chunk = True
return model_response
@ -10003,6 +10014,8 @@ class CustomStreamWrapper:
self.sent_first_chunk = True
return model_response
else:
if hasattr(model_response, "usage"):
self.chunks.append(model_response)
return
except StopIteration:
raise StopIteration
@ -10119,17 +10132,22 @@ class CustomStreamWrapper:
del obj_dict["usage"]
# Create a new object without the removed attribute
response = self.model_response_creator(chunk=obj_dict)
response = self.model_response_creator(
chunk=obj_dict, hidden_params=response._hidden_params
)
# add usage as hidden param
if self.sent_last_chunk is True and self.stream_options is None:
usage = calculate_total_usage(chunks=self.chunks)
response._hidden_params["usage"] = usage
# RETURN RESULT
return response
except StopIteration:
if self.sent_last_chunk is True:
if (
self.sent_stream_usage == False
self.sent_stream_usage is False
and self.stream_options is not None
and self.stream_options.get("include_usage", False) == True
and self.stream_options.get("include_usage", False) is True
):
# send the final chunk with stream options
complete_streaming_response = litellm.stream_chunk_builder(
@ -10137,6 +10155,7 @@ class CustomStreamWrapper:
)
response = self.model_response_creator()
response.usage = complete_streaming_response.usage # type: ignore
response._hidden_params["usage"] = complete_streaming_response.usage # type: ignore
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
@ -10148,6 +10167,9 @@ class CustomStreamWrapper:
else:
self.sent_last_chunk = True
processed_chunk = self.finish_reason_handler()
if self.stream_options is None: # add usage as hidden param
usage = calculate_total_usage(chunks=self.chunks)
setattr(processed_chunk, "usage", usage)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,