forked from phoenix/litellm-mirror
fix(utils.py): fix anthropic streaming usage calculation
Fixes https://github.com/BerriAI/litellm/issues/4965
This commit is contained in:
parent aad0bbb08c
commit ca0a0bed46

3 changed files with 59 additions and 29 deletions
@@ -5191,17 +5191,24 @@ def stream_chunk_builder(
     prompt_tokens = 0
     completion_tokens = 0
     for chunk in chunks:
+        usage_chunk: Optional[Usage] = None
         if "usage" in chunk:
-            if "prompt_tokens" in chunk["usage"]:
-                prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
-            if "completion_tokens" in chunk["usage"]:
-                completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
+            usage_chunk = chunk.usage
+        elif hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
+            usage_chunk = chunk._hidden_params["usage"]
+        if usage_chunk is not None:
+            if "prompt_tokens" in usage_chunk:
+                prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
+            if "completion_tokens" in usage_chunk:
+                completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
     try:
         response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
             model=model, messages=messages
         )
-    except:  # don't allow this failing to block a complete streaming response from being returned
-        print_verbose(f"token_counter failed, assuming prompt tokens is 0")
+    except (
+        Exception
+    ):  # don't allow this failing to block a complete streaming response from being returned
+        print_verbose("token_counter failed, assuming prompt tokens is 0")
         response["usage"]["prompt_tokens"] = 0
     response["usage"]["completion_tokens"] = completion_tokens or token_counter(
         model=model,
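For reference, a minimal standalone sketch of the lookup order this hunk introduces: stream_chunk_builder now reads usage from a chunk's own usage attribute first and falls back to _hidden_params["usage"]. The StubChunk class and sample data below are illustrative stand-ins, not litellm's real ModelResponse/Usage types.

from typing import Optional, Tuple


class StubChunk:
    """Illustrative stand-in for a streamed response chunk."""

    def __init__(self, usage: Optional[dict] = None, hidden_usage: Optional[dict] = None):
        self.usage = usage
        self._hidden_params = {"usage": hidden_usage} if hidden_usage is not None else {}

    def __contains__(self, key: str) -> bool:
        # mimic the dict-style membership test used above: `"usage" in chunk`
        return key == "usage" and self.usage is not None


def extract_usage(chunks: list) -> Tuple[int, int]:
    prompt_tokens, completion_tokens = 0, 0
    for chunk in chunks:
        usage_chunk: Optional[dict] = None
        if "usage" in chunk:  # usage reported directly on the chunk
            usage_chunk = chunk.usage
        elif "usage" in getattr(chunk, "_hidden_params", {}):  # fallback: hidden params
            usage_chunk = chunk._hidden_params["usage"]
        if usage_chunk is not None:
            if "prompt_tokens" in usage_chunk:
                prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
            if "completion_tokens" in usage_chunk:
                completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
    return prompt_tokens, completion_tokens


chunks = [
    StubChunk(),  # ordinary content chunk, no usage
    StubChunk(hidden_usage={"prompt_tokens": 17, "completion_tokens": 42}),
]
assert extract_usage(chunks) == (17, 42)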
@@ -3096,6 +3096,7 @@ def test_completion_claude_3_function_call_with_streaming():
         elif idx == 1 and chunk.choices[0].finish_reason is None:
             validate_second_streaming_function_calling_chunk(chunk=chunk)
         elif chunk.choices[0].finish_reason is not None:  # last chunk
+            assert "usage" in chunk._hidden_params
             validate_final_streaming_function_calling_chunk(chunk=chunk)
         idx += 1
     # raise Exception("it worked!")
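The new assertion above relies on the final streamed chunk carrying usage in _hidden_params when the caller did not set stream_options. A hedged caller-side sketch, assuming litellm.completion(..., stream=True) and a usage object exposing prompt_tokens/completion_tokens fields; the model name is only illustrative.

import litellm


def stream_and_report(prompt: str) -> None:
    last_chunk = None
    for chunk in litellm.completion(
        model="claude-3-haiku-20240307",  # illustrative model name
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    ):
        last_chunk = chunk  # consume the stream; keep the final chunk

    if last_chunk is not None and "usage" in last_chunk._hidden_params:
        usage = last_chunk._hidden_params["usage"]
        print(f"prompt={usage.prompt_tokens} completion={usage.completion_tokens}")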
@@ -8378,6 +8378,28 @@ def get_secret(
 ######## Streaming Class ############################
 # wraps the completion stream to return the correct format for the model
 # replicate/anthropic/cohere
+
+
+def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
+    """Assume the most recent usage chunk has the total usage up until then."""
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    for chunk in chunks:
+        if "usage" in chunk:
+            if "prompt_tokens" in chunk["usage"]:
+                prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
+            if "completion_tokens" in chunk["usage"]:
+                completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
+
+    returned_usage_chunk = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+
+    return returned_usage_chunk
+
+
 class CustomStreamWrapper:
     def __init__(
         self,
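A toy illustration of the "most recent usage chunk wins" behaviour documented in the docstring above, using plain dicts in place of ModelResponse. Anthropic-style streams report prompt tokens near the start of the stream and completion tokens near the end, so the last value seen for each field is taken as the running total. This helper is a sketch, not litellm's implementation.

def total_usage(chunks: list) -> dict:
    prompt_tokens = 0
    completion_tokens = 0
    for chunk in chunks:
        usage = chunk.get("usage") or {}
        if "prompt_tokens" in usage:
            prompt_tokens = usage.get("prompt_tokens", 0) or 0
        if "completion_tokens" in usage:
            completion_tokens = usage.get("completion_tokens", 0) or 0
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }


assert total_usage([
    {"usage": {"prompt_tokens": 21}},       # start-of-stream event
    {"delta": "Hello"},                     # content chunk, no usage
    {"usage": {"completion_tokens": 5}},    # end-of-stream usage update
]) == {"prompt_tokens": 21, "completion_tokens": 5, "total_tokens": 26}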
@@ -9267,7 +9289,9 @@ class CustomStreamWrapper:
             verbose_logger.debug(traceback.format_exc())
             return ""

-    def model_response_creator(self, chunk: Optional[dict] = None):
+    def model_response_creator(
+        self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
+    ):
         _model = self.model
         _received_llm_provider = self.custom_llm_provider
         _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None)  # type: ignore
@@ -9281,6 +9305,7 @@ class CustomStreamWrapper:
             else:
                 # pop model keyword
                 chunk.pop("model", None)
+
             model_response = ModelResponse(
                 stream=True, model=_model, stream_options=self.stream_options, **chunk
             )
@@ -9290,6 +9315,8 @@ class CustomStreamWrapper:
             self.response_id = model_response.id  # type: ignore
         if self.system_fingerprint is not None:
             model_response.system_fingerprint = self.system_fingerprint
+        if hidden_params is not None:
+            model_response._hidden_params = hidden_params
         model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
         model_response._hidden_params["created_at"] = time.time()

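The signature change and the pass-through added above let model_response_creator start from hidden params gathered on an earlier response object, so metadata such as usage survives when a chunk is rebuilt. A minimal sketch of that pass-through pattern, using plain dicts and illustrative names rather than litellm's internals:

import time
from typing import Optional


def make_chunk(payload: dict, hidden_params: Optional[dict] = None) -> dict:
    chunk = dict(payload)
    # start from the caller-supplied hidden params (if any), then stamp the
    # provider metadata last so it always reflects the current call
    chunk["_hidden_params"] = dict(hidden_params) if hidden_params is not None else {}
    chunk["_hidden_params"]["custom_llm_provider"] = "anthropic"  # illustrative
    chunk["_hidden_params"]["created_at"] = time.time()
    return chunk


rebuilt = make_chunk({"delta": "Hi"}, hidden_params={"usage": {"prompt_tokens": 9}})
assert rebuilt["_hidden_params"]["usage"] == {"prompt_tokens": 9}
assert rebuilt["_hidden_params"]["custom_llm_provider"] == "anthropic"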
@@ -9344,11 +9371,7 @@ class CustomStreamWrapper:
                         "finish_reason"
                     ]

-                if (
-                    self.stream_options
-                    and self.stream_options.get("include_usage", False) is True
-                    and anthropic_response_obj["usage"] is not None
-                ):
+                if anthropic_response_obj["usage"] is not None:
                     model_response.usage = litellm.Usage(
                         prompt_tokens=anthropic_response_obj["usage"]["prompt_tokens"],
                         completion_tokens=anthropic_response_obj["usage"][
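The hunk above drops the stream_options gate: usage reported by Anthropic is now recorded on the chunk whenever it is present, and later code decides whether to surface it openly (include_usage) or as a hidden param. A before/after sketch with illustrative names, not litellm's real handler:

from typing import Optional


def attach_usage(model_response: dict, provider_usage: Optional[dict],
                 stream_options: Optional[dict]) -> None:
    # old behaviour (removed above): only record usage when the caller opted in via
    #   stream_options.get("include_usage", False) is True
    # new behaviour: always record what the provider reported
    if provider_usage is not None:
        model_response["usage"] = {
            "prompt_tokens": provider_usage.get("prompt_tokens", 0),
            "completion_tokens": provider_usage.get("completion_tokens", 0),
        }


resp: dict = {}
attach_usage(resp, {"prompt_tokens": 3, "completion_tokens": 7}, stream_options=None)
assert resp["usage"] == {"prompt_tokens": 3, "completion_tokens": 7}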
@@ -9884,19 +9907,6 @@ class CustomStreamWrapper:

                 ## RETURN ARG
                 if (
-                    "content" in completion_obj
-                    and isinstance(completion_obj["content"], str)
-                    and len(completion_obj["content"]) == 0
-                    and hasattr(model_response, "usage")
-                    and hasattr(model_response.usage, "prompt_tokens")
-                ):
-                    if self.sent_first_chunk is False:
-                        completion_obj["role"] = "assistant"
-                        self.sent_first_chunk = True
-                    model_response.choices[0].delta = Delta(**completion_obj)
-                    print_verbose(f"returning model_response: {model_response}")
-                    return model_response
-                elif (
                     "content" in completion_obj
                     and (
                         isinstance(completion_obj["content"], str)
@@ -9991,6 +10001,7 @@ class CustomStreamWrapper:
                     model_response.choices[0].finish_reason = map_finish_reason(
                         finish_reason=self.received_finish_reason
                     )  # ensure consistent output to openai

                     self.sent_last_chunk = True
+
                     return model_response
@@ -10003,6 +10014,8 @@ class CustomStreamWrapper:
                         self.sent_first_chunk = True
                     return model_response
             else:
+                if hasattr(model_response, "usage"):
+                    self.chunks.append(model_response)
                 return
         except StopIteration:
             raise StopIteration
@@ -10119,17 +10132,22 @@ class CustomStreamWrapper:
                     del obj_dict["usage"]

                     # Create a new object without the removed attribute
-                    response = self.model_response_creator(chunk=obj_dict)
+                    response = self.model_response_creator(
+                        chunk=obj_dict, hidden_params=response._hidden_params
+                    )
+                # add usage as hidden param
+                if self.sent_last_chunk is True and self.stream_options is None:
+                    usage = calculate_total_usage(chunks=self.chunks)
+                    response._hidden_params["usage"] = usage
                 # RETURN RESULT
                 return response

         except StopIteration:
             if self.sent_last_chunk is True:
                 if (
-                    self.sent_stream_usage == False
+                    self.sent_stream_usage is False
                     and self.stream_options is not None
-                    and self.stream_options.get("include_usage", False) == True
+                    and self.stream_options.get("include_usage", False) is True
                 ):
                     # send the final chunk with stream options
                     complete_streaming_response = litellm.stream_chunk_builder(
@@ -10137,6 +10155,7 @@ class CustomStreamWrapper:
                     )
                     response = self.model_response_creator()
                     response.usage = complete_streaming_response.usage  # type: ignore
+                    response._hidden_params["usage"] = complete_streaming_response.usage  # type: ignore
                     ## LOGGING
                     threading.Thread(
                         target=self.logging_obj.success_handler,
@@ -10148,6 +10167,9 @@ class CustomStreamWrapper:
             else:
                 self.sent_last_chunk = True
                 processed_chunk = self.finish_reason_handler()
+                if self.stream_options is None:  # add usage as hidden param
+                    usage = calculate_total_usage(chunks=self.chunks)
+                    setattr(processed_chunk, "usage", usage)
                 ## LOGGING
                 threading.Thread(
                     target=self.logging_obj.success_handler,
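Taken together, the CustomStreamWrapper changes above follow one pattern: buffer every usage-bearing chunk while iterating, and when the stream ends without stream_options, aggregate that buffer and attach the result to the final chunk so callers and the cost-tracking path still see token counts. A self-contained sketch of that pattern under those assumptions (it omits the stream_options branch for brevity and does not use litellm's classes):

from dataclasses import dataclass, field
from typing import Iterator, List, Optional


@dataclass
class Chunk:
    text: str = ""
    usage: Optional[dict] = None
    hidden_params: dict = field(default_factory=dict)


class UsageTrackingStream:
    """Wrap a chunk iterator; attach aggregated usage to the final chunk."""

    def __init__(self, stream: Iterator[Chunk]):
        self._stream = iter(stream)
        self._usage_chunks: List[Chunk] = []

    def _note_usage(self, chunk: Chunk) -> None:
        if chunk.usage is not None:
            self._usage_chunks.append(chunk)  # remember usage-bearing chunks

    def __iter__(self) -> Iterator[Chunk]:
        try:
            previous = next(self._stream)
        except StopIteration:
            return
        self._note_usage(previous)
        for current in self._stream:
            yield previous  # safe to emit: it is not the last chunk
            self._note_usage(current)
            previous = current
        # `previous` is the final chunk: fold the buffered usage into hidden params
        prompt = completion = 0
        for c in self._usage_chunks:
            prompt = c.usage.get("prompt_tokens", prompt)
            completion = c.usage.get("completion_tokens", completion)
        previous.hidden_params["usage"] = {
            "prompt_tokens": prompt,
            "completion_tokens": completion,
            "total_tokens": prompt + completion,
        }
        yield previous


stream = UsageTrackingStream(iter([
    Chunk(text="Hel", usage={"prompt_tokens": 12}),
    Chunk(text="lo"),
    Chunk(text="", usage={"completion_tokens": 4}),
]))
*_, last = list(stream)
assert last.hidden_params["usage"] == {
    "prompt_tokens": 12, "completion_tokens": 4, "total_tokens": 16
}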