From 54dacfdf61bcb5ec35bc08d319b906bcc85b2e5e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 4 Jun 2024 19:03:26 -0700
Subject: [PATCH] feat(utils.py): support 'stream_options' param across all providers

Closes https://github.com/BerriAI/litellm/issues/3553
---
 litellm/tests/test_streaming.py |  8 +++--
 litellm/utils.py                | 63 +++++++++++++++++++++++++++++++--
 2 files changed, 66 insertions(+), 5 deletions(-)

diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index b939d6299..53a7278bb 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1993,10 +1993,14 @@ def test_openai_chat_completion_complete_response_call():
 # test_openai_chat_completion_complete_response_call()
 
 
-def test_openai_stream_options_call():
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-3.5-turbo", "azure/chatgpt-v-2"],
+)
+def test_openai_stream_options_call(model):
     litellm.set_verbose = False
     response = litellm.completion(
-        model="gpt-3.5-turbo",
+        model=model,
         messages=[{"role": "system", "content": "say GM - we're going to make it "}],
         stream=True,
         stream_options={"include_usage": True},
diff --git a/litellm/utils.py b/litellm/utils.py
index 820e22afc..76aee1218 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1137,6 +1137,7 @@ class Logging:
     global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger
 
     custom_pricing: bool = False
+    stream_options = None
 
     def __init__(
         self,
@@ -1205,6 +1206,7 @@ class Logging:
         self.litellm_params = litellm_params
         self.logger_fn = litellm_params.get("logger_fn", None)
         print_verbose(f"self.optional_params: {self.optional_params}")
+
         self.model_call_details = {
             "model": self.model,
             "messages": self.messages,
@@ -1220,6 +1222,9 @@ class Logging:
             **additional_params,
         }
 
+        ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
+        if "stream_options" in additional_params:
+            self.stream_options = additional_params["stream_options"]
         ## check if custom pricing set ##
         if (
             litellm_params.get("input_cost_per_token") is not None
@@ -3035,6 +3040,7 @@ def function_setup(
             user="",
             optional_params={},
             litellm_params=litellm_params,
+            stream_options=kwargs.get("stream_options", None),
         )
         return logging_obj, kwargs
     except Exception as e:
@@ -5345,7 +5351,7 @@ def get_optional_params(
         unsupported_params = {}
         for k in non_default_params.keys():
             if k not in supported_params:
-                if k == "user":
+                if k == "user" or k == "stream_options":
                     continue
                 if k == "n" and n == 1:  # langchain sends n=1 as a default value
                     continue  # skip this param
@@ -10274,7 +10280,14 @@ class CustomStreamWrapper:
         self.response_id = None
         self.logging_loop = None
         self.rules = Rules()
-        self.stream_options = stream_options
+        self.stream_options = stream_options or getattr(
+            logging_obj, "stream_options", None
+        )
+        self.messages = getattr(logging_obj, "messages", None)
+        self.sent_stream_usage = False
+        self.chunks: List = (
+            []
+        )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
 
     def __iter__(self):
         return self
@@ -11389,6 +11402,7 @@ class CustomStreamWrapper:
                     and self.stream_options.get("include_usage", False) == True
                     and response_obj["usage"] is not None
                 ):
+                    self.sent_stream_usage = True
                     model_response.usage = litellm.Usage(
                         prompt_tokens=response_obj["usage"].prompt_tokens,
                         completion_tokens=response_obj["usage"].completion_tokens,
@@ -11405,6 +11419,7 @@ class CustomStreamWrapper:
                     and self.stream_options.get("include_usage", False) == True
                     and response_obj["usage"] is not None
                 ):
+                    self.sent_stream_usage = True
                     model_response.usage = litellm.Usage(
                         prompt_tokens=response_obj["usage"].prompt_tokens,
                         completion_tokens=response_obj["usage"].completion_tokens,
@@ -11468,6 +11483,7 @@ class CustomStreamWrapper:
                     and self.stream_options["include_usage"] == True
                     and response_obj["usage"] is not None
                 ):
+                    self.sent_stream_usage = True
                     model_response.usage = litellm.Usage(
                         prompt_tokens=response_obj["usage"].prompt_tokens,
                         completion_tokens=response_obj["usage"].completion_tokens,
@@ -11749,7 +11765,6 @@ class CustomStreamWrapper:
                 model_response.choices[0].finish_reason = "stop"
                 return model_response
 
-    ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
         try:
             while True:
@@ -11781,9 +11796,27 @@ class CustomStreamWrapper:
                         input=self.response_uptil_now, model=self.model
                     )
                     # RETURN RESULT
+                    self.chunks.append(response)
                     return response
         except StopIteration:
             if self.sent_last_chunk == True:
+                if (
+                    self.sent_stream_usage == False
+                    and self.stream_options is not None
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    # send the final chunk with stream options
+                    complete_streaming_response = litellm.stream_chunk_builder(
+                        chunks=self.chunks, messages=self.messages
+                    )
+                    response = self.model_response_creator()
+                    response.usage = complete_streaming_response.usage  # type: ignore
+                    ## LOGGING
+                    threading.Thread(
+                        target=self.logging_obj.success_handler, args=(response,)
+                    ).start()  # log response
+                    self.sent_stream_usage = True
+                    return response
                 raise  # Re-raise StopIteration
             else:
                 self.sent_last_chunk = True
@@ -11881,6 +11914,7 @@ class CustomStreamWrapper:
                         input=self.response_uptil_now, model=self.model
                     )
                     print_verbose(f"final returned processed chunk: {processed_chunk}")
+                    self.chunks.append(response)
                     return processed_chunk
                 raise StopAsyncIteration
             else:  # temporary patch for non-aiohttp async calls
@@ -11920,9 +11954,32 @@ class CustomStreamWrapper:
                         input=self.response_uptil_now, model=self.model
                     )
                     # RETURN RESULT
+                    self.chunks.append(response)
                     return processed_chunk
         except StopAsyncIteration:
             if self.sent_last_chunk == True:
+                if (
+                    self.sent_stream_usage == False
+                    and self.stream_options is not None
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    # send the final chunk with stream options
+                    complete_streaming_response = litellm.stream_chunk_builder(
+                        chunks=self.chunks
+                    )
+                    response = self.model_response_creator()
+                    response.usage = complete_streaming_response.usage
+                    ## LOGGING
+                    threading.Thread(
+                        target=self.logging_obj.success_handler, args=(processed_chunk,)
+                    ).start()  # log response
+                    asyncio.create_task(
+                        self.logging_obj.async_success_handler(
+                            processed_chunk,
+                        )
+                    )
+                    self.sent_stream_usage = True
+                    return response
                 raise  # Re-raise StopIteration
             else:
                 self.sent_last_chunk = True
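
For reference, a minimal caller-side sketch of the behaviour this patch enables, mirroring the parametrized test above. The model name, credentials, and the exact shape of the final usage-bearing chunk are assumptions for illustration, not part of this diff:

    # Illustrative sketch only - assumes OpenAI/Azure credentials are configured.
    import litellm

    response = litellm.completion(
        model="gpt-3.5-turbo",  # or "azure/chatgpt-v-2", as parametrized in the test
        messages=[{"role": "system", "content": "say GM - we're going to make it "}],
        stream=True,
        stream_options={"include_usage": True},
    )

    last_chunk = None
    for chunk in response:
        last_chunk = chunk

    # With include_usage set, the final chunk should carry token counts; for providers
    # that don't return usage natively, CustomStreamWrapper now builds it from the
    # chunks it tracked via litellm.stream_chunk_builder().
    print(getattr(last_chunk, "usage", None))

The design point in the diff: if no provider-supplied usage block was seen by the time the stream ends, the wrapper emits one extra chunk after StopIteration/StopAsyncIteration, with usage computed from the collected chunks, so include_usage behaves uniformly across providers.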