diff --git a/litellm/main.py b/litellm/main.py
index 429efb6c0..bb8a1305e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -5191,17 +5191,24 @@ def stream_chunk_builder(
     prompt_tokens = 0
     completion_tokens = 0
     for chunk in chunks:
+        usage_chunk: Optional[Usage] = None
         if "usage" in chunk:
-            if "prompt_tokens" in chunk["usage"]:
-                prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
-            if "completion_tokens" in chunk["usage"]:
-                completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
+            usage_chunk = chunk.usage
+        elif hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
+            usage_chunk = chunk._hidden_params["usage"]
+        if usage_chunk is not None:
+            if "prompt_tokens" in usage_chunk:
+                prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
+            if "completion_tokens" in usage_chunk:
+                completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
     try:
         response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
             model=model, messages=messages
         )
-    except:  # don't allow this failing to block a complete streaming response from being returned
-        print_verbose(f"token_counter failed, assuming prompt tokens is 0")
+    except (
+        Exception
+    ):  # don't allow this failing to block a complete streaming response from being returned
+        print_verbose("token_counter failed, assuming prompt tokens is 0")
         response["usage"]["prompt_tokens"] = 0
     response["usage"]["completion_tokens"] = completion_tokens or token_counter(
         model=model,
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index bd2d889e3..9c53d5cfb 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -3096,6 +3096,7 @@ def test_completion_claude_3_function_call_with_streaming():
         elif idx == 1 and chunk.choices[0].finish_reason is None:
             validate_second_streaming_function_calling_chunk(chunk=chunk)
         elif chunk.choices[0].finish_reason is not None:  # last chunk
+            assert "usage" in chunk._hidden_params
             validate_final_streaming_function_calling_chunk(chunk=chunk)
         idx += 1
     # raise Exception("it worked!")
diff --git a/litellm/utils.py b/litellm/utils.py
index 0e1573784..b6011cb94 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -8378,6 +8378,28 @@ def get_secret(
 ######## Streaming Class ############################
 # wraps the completion stream to return the correct format for the model
 # replicate/anthropic/cohere
+
+
+def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
+    """Assume most recent usage chunk has total usage uptil then."""
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    for chunk in chunks:
+        if "usage" in chunk:
+            if "prompt_tokens" in chunk["usage"]:
+                prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
+            if "completion_tokens" in chunk["usage"]:
+                completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
+
+    returned_usage_chunk = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+
+    return returned_usage_chunk
+
+
 class CustomStreamWrapper:
     def __init__(
         self,
@@ -9267,7 +9289,9 @@ class CustomStreamWrapper:
             verbose_logger.debug(traceback.format_exc())
             return ""

-    def model_response_creator(self, chunk: Optional[dict] = None):
+    def model_response_creator(
+        self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
+    ):
         _model = self.model
         _received_llm_provider = self.custom_llm_provider
         _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None)  # type: ignore
@@ -9281,6 +9305,7 @@ class CustomStreamWrapper:
         else:
             # pop model keyword
             chunk.pop("model", None)
+
         model_response = ModelResponse(
             stream=True, model=_model, stream_options=self.stream_options, **chunk
         )
@@ -9290,6 +9315,8 @@ class CustomStreamWrapper:
             self.response_id = model_response.id  # type: ignore
         if self.system_fingerprint is not None:
             model_response.system_fingerprint = self.system_fingerprint
+        if hidden_params is not None:
+            model_response._hidden_params = hidden_params
         model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
         model_response._hidden_params["created_at"] = time.time()

@@ -9344,11 +9371,7 @@ class CustomStreamWrapper:
                         "finish_reason"
                     ]

-                if (
-                    self.stream_options
-                    and self.stream_options.get("include_usage", False) is True
-                    and anthropic_response_obj["usage"] is not None
-                ):
+                if anthropic_response_obj["usage"] is not None:
                     model_response.usage = litellm.Usage(
                         prompt_tokens=anthropic_response_obj["usage"]["prompt_tokens"],
                         completion_tokens=anthropic_response_obj["usage"][
@@ -9884,19 +9907,6 @@ class CustomStreamWrapper:

             ## RETURN ARG
             if (
-                "content" in completion_obj
-                and isinstance(completion_obj["content"], str)
-                and len(completion_obj["content"]) == 0
-                and hasattr(model_response, "usage")
-                and hasattr(model_response.usage, "prompt_tokens")
-            ):
-                if self.sent_first_chunk is False:
-                    completion_obj["role"] = "assistant"
-                    self.sent_first_chunk = True
-                model_response.choices[0].delta = Delta(**completion_obj)
-                print_verbose(f"returning model_response: {model_response}")
-                return model_response
-            elif (
                 "content" in completion_obj
                 and (
                     isinstance(completion_obj["content"], str)
@@ -9991,6 +10001,7 @@ class CustomStreamWrapper:
                     model_response.choices[0].finish_reason = map_finish_reason(
                         finish_reason=self.received_finish_reason
                     )  # ensure consistent output to openai
+                    self.sent_last_chunk = True

                     return model_response
@@ -10003,6 +10014,8 @@ class CustomStreamWrapper:
                         self.sent_first_chunk = True
                     return model_response
                 else:
+                    if hasattr(model_response, "usage"):
+                        self.chunks.append(model_response)
                     return
         except StopIteration:
             raise StopIteration
@@ -10119,17 +10132,22 @@ class CustomStreamWrapper:
                         del obj_dict["usage"]

                         # Create a new object without the removed attribute
-                        response = self.model_response_creator(chunk=obj_dict)
-
+                        response = self.model_response_creator(
+                            chunk=obj_dict, hidden_params=response._hidden_params
+                        )
+                # add usage as hidden param
+                if self.sent_last_chunk is True and self.stream_options is None:
+                    usage = calculate_total_usage(chunks=self.chunks)
+                    response._hidden_params["usage"] = usage
                 # RETURN RESULT
                 return response
         except StopIteration:
             if self.sent_last_chunk is True:
                 if (
-                    self.sent_stream_usage == False
+                    self.sent_stream_usage is False
                     and self.stream_options is not None
-                    and self.stream_options.get("include_usage", False) == True
+                    and self.stream_options.get("include_usage", False) is True
                 ):
                     # send the final chunk with stream options
                     complete_streaming_response = litellm.stream_chunk_builder(
@@ -10137,6 +10155,7 @@ class CustomStreamWrapper:
                     )
                     response = self.model_response_creator()
                     response.usage = complete_streaming_response.usage  # type: ignore
+                    response._hidden_params["usage"] = complete_streaming_response.usage  # type: ignore
                     ## LOGGING
                     threading.Thread(
                         target=self.logging_obj.success_handler,
@@ -10148,6 +10167,9 @@ class CustomStreamWrapper:
                 else:
                     self.sent_last_chunk = True
                     processed_chunk = self.finish_reason_handler()
+                    if self.stream_options is None:  # add usage as hidden param
+                        usage = calculate_total_usage(chunks=self.chunks)
+                        setattr(processed_chunk, "usage", usage)
                     ## LOGGING
                     threading.Thread(
                         target=self.logging_obj.success_handler,
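
A minimal consumer-side sketch of what this patch enables, not part of the diff itself: when stream_options is unset, the final streamed chunk now carries aggregated usage in _hidden_params, and stream_chunk_builder can read it back when rebuilding the full response. The model name and prompt below are placeholders and assume a configured provider key.

    # sketch only: placeholders for model/prompt; requires provider credentials
    import litellm

    messages = [{"role": "user", "content": "Hello"}]
    chunks = []
    for chunk in litellm.completion(
        model="claude-3-haiku-20240307", messages=messages, stream=True
    ):
        chunks.append(chunk)
        if chunk.choices[0].finish_reason is not None:  # last chunk
            # usage attached as a hidden param when stream_options is None
            hidden_usage = getattr(chunk, "_hidden_params", {}).get("usage")
            print("usage from hidden params:", hidden_usage)

    # stream_chunk_builder now also checks _hidden_params["usage"] when totaling tokens
    rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
    print(rebuilt.usage)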