diff --git a/litellm/utils.py b/litellm/utils.py
index df58db29c..64a644f15 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject):
         system_fingerprint=None,
         usage=None,
         stream=None,
+        stream_options=None,
         response_ms=None,
         hidden_params=None,
         **params,
@@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject):
             usage = usage
         elif stream is None or stream == False:
             usage = Usage()
+        elif (
+            stream == True
+            and stream_options is not None
+            and stream_options.get("include_usage") == True
+        ):
+            usage = Usage()
 
         if hidden_params:
             self._hidden_params = hidden_params
@@ -4839,6 +4846,7 @@ def get_optional_params(
     top_p=None,
     n=None,
     stream=False,
+    stream_options=None,
     stop=None,
     max_tokens=None,
     presence_penalty=None,
@@ -4908,6 +4916,7 @@ def get_optional_params(
         "top_p": None,
         "n": None,
         "stream": None,
+        "stream_options": None,
         "stop": None,
         "max_tokens": None,
         "presence_penalty": None,
@@ -5779,6 +5788,8 @@ def get_optional_params(
             optional_params["n"] = n
         if stream is not None:
             optional_params["stream"] = stream
+        if stream_options is not None:
+            optional_params["stream_options"] = stream_options
         if stop is not None:
             optional_params["stop"] = stop
         if max_tokens is not None:
@@ -6049,6 +6060,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "top_p",
             "n",
             "stream",
+            "stream_options",
             "stop",
             "max_tokens",
             "presence_penalty",
@@ -9466,7 +9478,12 @@ def get_secret(
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
     def __init__(
-        self, completion_stream, model, custom_llm_provider=None, logging_obj=None
+        self,
+        completion_stream,
+        model,
+        custom_llm_provider=None,
+        logging_obj=None,
+        stream_options=None,
     ):
         self.model = model
         self.custom_llm_provider = custom_llm_provider
@@ -9492,6 +9509,7 @@ class CustomStreamWrapper:
         self.response_id = None
         self.logging_loop = None
         self.rules = Rules()
+        self.stream_options = stream_options
 
     def __iter__(self):
         return self
@@ -9932,6 +9950,7 @@ class CustomStreamWrapper:
             is_finished = False
             finish_reason = None
             logprobs = None
+            usage = None
             original_chunk = None  # this is used for function/tool calling
             if len(str_line.choices) > 0:
                 if (
@@ -9966,12 +9985,15 @@ class CustomStreamWrapper:
                 else:
                     logprobs = None
 
+            usage = getattr(str_line, "usage", None)
+
             return {
                 "text": text,
                 "is_finished": is_finished,
                 "finish_reason": finish_reason,
                 "logprobs": logprobs,
                 "original_chunk": str_line,
+                "usage": usage,
             }
         except Exception as e:
             traceback.print_exc()
@@ -10274,7 +10296,9 @@ class CustomStreamWrapper:
             raise e
 
     def model_response_creator(self):
-        model_response = ModelResponse(stream=True, model=self.model)
+        model_response = ModelResponse(
+            stream=True, model=self.model, stream_options=self.stream_options
+        )
         if self.response_id is not None:
             model_response.id = self.response_id
         else:
@@ -10594,6 +10618,12 @@ class CustomStreamWrapper:
                     if response_obj["logprobs"] is not None:
                         model_response.choices[0].logprobs = response_obj["logprobs"]
 
+                    if (
+                        self.stream_options is not None
+                        and self.stream_options["include_usage"] == True
+                    ):
+                        model_response.usage = response_obj["usage"]
+
                     model_response.model = self.model
                     print_verbose(
                         f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
@@ -10681,6 +10711,11 @@ class CustomStreamWrapper:
                 except Exception as e:
                     model_response.choices[0].delta = Delta()
             else:
+                if (
+                    self.stream_options is not None
+                    and self.stream_options["include_usage"] == True
+                ):
+                    return model_response
                 return
             print_verbose(
                 f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
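
For reference, a minimal usage sketch of the behavior this patch enables, mirroring OpenAI's stream_options semantics. It assumes the patch above is applied; the model name and prompt are illustrative, and the shape of the usage-bearing chunk (empty choices, populated usage) follows the OpenAI convention rather than anything this diff itself guarantees:

import litellm

# With stream_options={"include_usage": True}, CustomStreamWrapper now
# forwards the provider's usage chunk instead of discarding it on `return`.
response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in response:
    # Ordinary chunks carry content deltas.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    # The usage-bearing chunk carries the token counts for the request.
    usage = getattr(chunk, "usage", None)
    if usage is not None:
        print(f"\nprompt_tokens={usage.prompt_tokens}, "
              f"completion_tokens={usage.completion_tokens}")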