fix(utils.py): fix anthropic streaming usage calculation
Fixes https://github.com/BerriAI/litellm/issues/4965
parent aad0bbb08c, commit ca0a0bed46

3 changed files with 59 additions and 29 deletions
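As context for the hunks below, an illustrative sketch of how the fix is meant to surface to callers: stream an Anthropic completion, rebuild it with litellm.stream_chunk_builder, and read token usage from the result. This is not part of the commit; it assumes litellm is installed and ANTHROPIC_API_KEY is set, and the model name is only an example.

    import litellm

    messages = [{"role": "user", "content": "Say hi"}]

    # collect the streamed chunks, then rebuild a single response object
    chunks = list(
        litellm.completion(
            model="claude-3-haiku-20240307",  # example model only
            messages=messages,
            stream=True,
        )
    )
    rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)

    # with this fix, the rebuilt response should carry real token counts
    print(rebuilt.usage)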
@@ -5191,17 +5191,24 @@ def stream_chunk_builder(
     prompt_tokens = 0
     completion_tokens = 0
     for chunk in chunks:
+        usage_chunk: Optional[Usage] = None
         if "usage" in chunk:
-            if "prompt_tokens" in chunk["usage"]:
-                prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
-            if "completion_tokens" in chunk["usage"]:
-                completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
+            usage_chunk = chunk.usage
+        elif hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
+            usage_chunk = chunk._hidden_params["usage"]
+        if usage_chunk is not None:
+            if "prompt_tokens" in usage_chunk:
+                prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
+            if "completion_tokens" in usage_chunk:
+                completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
     try:
         response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
             model=model, messages=messages
         )
-    except:  # don't allow this failing to block a complete streaming response from being returned
-        print_verbose(f"token_counter failed, assuming prompt tokens is 0")
+    except (
+        Exception
+    ):  # don't allow this failing to block a complete streaming response from being returned
+        print_verbose("token_counter failed, assuming prompt tokens is 0")
         response["usage"]["prompt_tokens"] = 0
     response["usage"]["completion_tokens"] = completion_tokens or token_counter(
         model=model,
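With the change above, stream_chunk_builder accepts usage either as a top-level "usage" field on a chunk or tucked away in chunk._hidden_params. A rough stand-in sketch of that lookup order; FakeChunk and extract_usage are illustrative names, not litellm code:

    from typing import Optional


    class FakeChunk(dict):
        """Toy stand-in for a streamed chunk; not a litellm class."""

        def __init__(self, usage=None, hidden_usage=None):
            super().__init__({"usage": usage} if usage is not None else {})
            self.usage = usage
            self._hidden_params = {"usage": hidden_usage} if hidden_usage is not None else {}


    def extract_usage(chunk) -> Optional[dict]:
        # same precedence as the patched loop: visible usage first, hidden usage second
        if "usage" in chunk:
            return chunk.usage
        if hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
            return chunk._hidden_params["usage"]
        return None


    print(extract_usage(FakeChunk(usage={"prompt_tokens": 12, "completion_tokens": 4})))
    print(extract_usage(FakeChunk(hidden_usage={"prompt_tokens": 12, "completion_tokens": 4})))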
@@ -3096,6 +3096,7 @@ def test_completion_claude_3_function_call_with_streaming():
             elif idx == 1 and chunk.choices[0].finish_reason is None:
                 validate_second_streaming_function_calling_chunk(chunk=chunk)
             elif chunk.choices[0].finish_reason is not None:  # last chunk
+                assert "usage" in chunk._hidden_params
                 validate_final_streaming_function_calling_chunk(chunk=chunk)
             idx += 1
         # raise Exception("it worked!")
@@ -8378,6 +8378,28 @@ def get_secret(
 ######## Streaming Class ############################
 # wraps the completion stream to return the correct format for the model
 # replicate/anthropic/cohere
+
+
+def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
+    """Assume most recent usage chunk has total usage uptil then."""
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    for chunk in chunks:
+        if "usage" in chunk:
+            if "prompt_tokens" in chunk["usage"]:
+                prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
+            if "completion_tokens" in chunk["usage"]:
+                completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
+
+    returned_usage_chunk = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+
+    return returned_usage_chunk
+
+
 class CustomStreamWrapper:
     def __init__(
         self,
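The new calculate_total_usage helper assumes each usage-bearing chunk reports cumulative totals, so the most recent one wins. A small self-contained illustration of that behaviour, with plain dicts standing in for ModelResponse chunks:

    # plain-dict stand-ins for streamed chunks; only the shape matters here
    chunks = [
        {"choices": [{"delta": {"content": "Hel"}}]},               # no usage yet
        {"usage": {"prompt_tokens": 12, "completion_tokens": 3}},   # running totals
        {"usage": {"prompt_tokens": 12, "completion_tokens": 9}},   # final totals
    ]

    prompt_tokens = 0
    completion_tokens = 0
    for chunk in chunks:
        if "usage" in chunk:
            prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
            completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0

    print(prompt_tokens, completion_tokens, prompt_tokens + completion_tokens)  # 12 9 21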
@@ -9267,7 +9289,9 @@ class CustomStreamWrapper:
             verbose_logger.debug(traceback.format_exc())
             return ""

-    def model_response_creator(self, chunk: Optional[dict] = None):
+    def model_response_creator(
+        self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
+    ):
         _model = self.model
         _received_llm_provider = self.custom_llm_provider
         _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None)  # type: ignore
@@ -9281,6 +9305,7 @@ class CustomStreamWrapper:
         else:
             # pop model keyword
             chunk.pop("model", None)
+
         model_response = ModelResponse(
             stream=True, model=_model, stream_options=self.stream_options, **chunk
         )
@@ -9290,6 +9315,8 @@ class CustomStreamWrapper:
             self.response_id = model_response.id  # type: ignore
         if self.system_fingerprint is not None:
             model_response.system_fingerprint = self.system_fingerprint
+        if hidden_params is not None:
+            model_response._hidden_params = hidden_params
         model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
         model_response._hidden_params["created_at"] = time.time()
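Threading hidden_params through model_response_creator lets the wrapper rebuild a chunk without dropping metadata that was already attached to it. A toy illustration of the idea; SimpleResponse is a stand-in, not litellm's ModelResponse:

    class SimpleResponse:
        """Toy stand-in for a response object that carries hidden metadata."""

        def __init__(self, chunk=None, hidden_params=None):
            self.data = dict(chunk or {})
            self._hidden_params = dict(hidden_params) if hidden_params else {}


    original = SimpleResponse(chunk={"id": "abc", "usage": {"prompt_tokens": 7}})
    original._hidden_params["custom_llm_provider"] = "anthropic"

    # strip usage from the visible payload but keep it (and the provider) hidden
    payload = dict(original.data)
    hidden = dict(original._hidden_params)
    hidden["usage"] = payload.pop("usage")

    rebuilt = SimpleResponse(chunk=payload, hidden_params=hidden)
    print(rebuilt.data)            # {'id': 'abc'}
    print(rebuilt._hidden_params)  # provider plus the preserved usage block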
@@ -9344,11 +9371,7 @@ class CustomStreamWrapper:
                     "finish_reason"
                 ]

-                if (
-                    self.stream_options
-                    and self.stream_options.get("include_usage", False) is True
-                    and anthropic_response_obj["usage"] is not None
-                ):
+                if anthropic_response_obj["usage"] is not None:
                     model_response.usage = litellm.Usage(
                         prompt_tokens=anthropic_response_obj["usage"]["prompt_tokens"],
                         completion_tokens=anthropic_response_obj["usage"][
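This hunk is the core of the fix: usage reported by Anthropic is now recorded whenever it is present, rather than only when the caller passed stream_options with include_usage. An illustrative way to read it with stream options enabled; stream_options is a real litellm parameter, but the model name and the exact chunk carrying usage are assumptions:

    import litellm

    stream = litellm.completion(
        model="claude-3-haiku-20240307",  # example model only
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )

    for chunk in stream:
        usage = getattr(chunk, "usage", None)
        if usage is not None:
            print(usage)  # prompt/completion/total token counts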
@@ -9884,19 +9907,6 @@ class CustomStreamWrapper:

             ## RETURN ARG
-            if (
-                "content" in completion_obj
-                and isinstance(completion_obj["content"], str)
-                and len(completion_obj["content"]) == 0
-                and hasattr(model_response, "usage")
-                and hasattr(model_response.usage, "prompt_tokens")
-            ):
-                if self.sent_first_chunk is False:
-                    completion_obj["role"] = "assistant"
-                    self.sent_first_chunk = True
-                model_response.choices[0].delta = Delta(**completion_obj)
-                print_verbose(f"returning model_response: {model_response}")
-                return model_response
             elif (
                 "content" in completion_obj
                 and (
                     isinstance(completion_obj["content"], str)
@@ -9991,6 +10001,7 @@ class CustomStreamWrapper:
                 model_response.choices[0].finish_reason = map_finish_reason(
                     finish_reason=self.received_finish_reason
                 )  # ensure consistent output to openai
+
                 self.sent_last_chunk = True

                 return model_response
@@ -10003,6 +10014,8 @@ class CustomStreamWrapper:
                     self.sent_first_chunk = True
                 return model_response
             else:
+                if hasattr(model_response, "usage"):
+                    self.chunks.append(model_response)
                 return
         except StopIteration:
             raise StopIteration
@@ -10119,17 +10132,22 @@ class CustomStreamWrapper:
                         del obj_dict["usage"]

                     # Create a new object without the removed attribute
-                    response = self.model_response_creator(chunk=obj_dict)
-
+                    response = self.model_response_creator(
+                        chunk=obj_dict, hidden_params=response._hidden_params
+                    )
+                    # add usage as hidden param
+                    if self.sent_last_chunk is True and self.stream_options is None:
+                        usage = calculate_total_usage(chunks=self.chunks)
+                        response._hidden_params["usage"] = usage
                 # RETURN RESULT
                 return response

         except StopIteration:
             if self.sent_last_chunk is True:
                 if (
-                    self.sent_stream_usage == False
+                    self.sent_stream_usage is False
                     and self.stream_options is not None
-                    and self.stream_options.get("include_usage", False) == True
+                    and self.stream_options.get("include_usage", False) is True
                 ):
                     # send the final chunk with stream options
                     complete_streaming_response = litellm.stream_chunk_builder(
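Together with calculate_total_usage, this means a stream consumed without stream_options should still end with aggregated usage attached to the final chunk as a hidden param, which is what the updated test asserts. A hedged end-user sketch; the model name is illustrative and an Anthropic key is assumed:

    import litellm

    last_chunk = None
    for chunk in litellm.completion(
        model="claude-3-haiku-20240307",  # example model only
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
    ):
        last_chunk = chunk

    # aggregated usage travels as hidden metadata rather than in the visible payload
    if last_chunk is not None and "usage" in last_chunk._hidden_params:
        print(last_chunk._hidden_params["usage"])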
@@ -10137,6 +10155,7 @@ class CustomStreamWrapper:
                     )
                     response = self.model_response_creator()
                     response.usage = complete_streaming_response.usage  # type: ignore
+                    response._hidden_params["usage"] = complete_streaming_response.usage  # type: ignore
                     ## LOGGING
                     threading.Thread(
                         target=self.logging_obj.success_handler,
@@ -10148,6 +10167,9 @@ class CustomStreamWrapper:
             else:
                 self.sent_last_chunk = True
                 processed_chunk = self.finish_reason_handler()
+                if self.stream_options is None:  # add usage as hidden param
+                    usage = calculate_total_usage(chunks=self.chunks)
+                    setattr(processed_chunk, "usage", usage)
                 ## LOGGING
                 threading.Thread(
                     target=self.logging_obj.success_handler,
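In this fallback path, where the wrapper synthesizes the final finish-reason chunk itself, the aggregated Usage object is attached directly with setattr. A minimal stand-in sketch of that pattern; FinalChunk is illustrative, litellm.Usage is real but the field values are made up:

    import litellm


    class FinalChunk:
        """Toy stand-in for the synthesized finish-reason chunk; not a litellm type."""

        def __init__(self):
            self.choices = [{"finish_reason": "stop", "delta": {}}]


    usage = litellm.Usage(prompt_tokens=12, completion_tokens=9, total_tokens=21)

    processed_chunk = FinalChunk()
    setattr(processed_chunk, "usage", usage)  # mirrors the setattr call in the hunk above

    print(processed_chunk.usage)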