feat(utils.py): support 'stream_options' param across all providers

Closes https://github.com/BerriAI/litellm/issues/3553
Krrish Dholakia 2024-06-04 19:03:26 -07:00
parent 34f31a1994
commit 54dacfdf61
2 changed files with 66 additions and 5 deletions
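For context, a minimal usage sketch of the new parameter (the model name and prompt are placeholders, not taken from this commit): passing stream_options={"include_usage": True} asks the stream wrapper to deliver token usage on the final chunk, mirroring OpenAI's streaming behavior.

import litellm

# Hedged sketch: stream a completion and ask for usage on the last chunk.
# "gpt-3.5-turbo" and the prompt are illustrative placeholders.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in response:
    print(chunk)  # the final chunk is expected to carry a populated `usage` field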


@@ -1137,6 +1137,7 @@ class Logging:
global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger
custom_pricing: bool = False
stream_options = None
def __init__(
self,
@@ -1205,6 +1206,7 @@ class Logging:
self.litellm_params = litellm_params
self.logger_fn = litellm_params.get("logger_fn", None)
print_verbose(f"self.optional_params: {self.optional_params}")
self.model_call_details = {
"model": self.model,
"messages": self.messages,
@@ -1220,6 +1222,9 @@ class Logging:
**additional_params,
}
## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
if "stream_options" in additional_params:
self.stream_options = additional_params["stream_options"]
## check if custom pricing set ##
if (
litellm_params.get("input_cost_per_token") is not None
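Read in isolation, the pattern added above looks like the following hedged sketch (MiniLogging is a hypothetical stand-in for the real Logging class): the logging object captures stream_options from additional_params so that code holding only the logging object, such as CustomStreamWrapper further down, can recover it with getattr.

# Hedged, standalone sketch of the capture-and-read-back pattern.
class MiniLogging:
    stream_options = None  # class default, mirroring the attribute added above

    def __init__(self, **additional_params):
        # store stream_options only if the caller passed it through
        if "stream_options" in additional_params:
            self.stream_options = additional_params["stream_options"]

logging_obj = MiniLogging(stream_options={"include_usage": True})
print(getattr(logging_obj, "stream_options", None))  # -> {'include_usage': True}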
@@ -3035,6 +3040,7 @@ def function_setup(
user="",
optional_params={},
litellm_params=litellm_params,
stream_options=kwargs.get("stream_options", None),
)
return logging_obj, kwargs
except Exception as e:
@@ -5345,7 +5351,7 @@ def get_optional_params(
unsupported_params = {}
for k in non_default_params.keys():
if k not in supported_params:
if k == "user":
if k == "user" or k == "stream_options":
continue
if k == "n" and n == 1: # langchain sends n=1 as a default value
continue # skip this param
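The change above can be restated as a standalone, hedged sketch (find_unsupported is a hypothetical name, not from the commit): keys a provider does not declare as supported are normally collected and reported, but a small pass-through set, now including stream_options, is skipped so the stream wrapper can handle it later.

# Hedged sketch of the unsupported-parameter filter in get_optional_params.
def find_unsupported(non_default_params: dict, supported_params: set) -> dict:
    unsupported = {}
    for k, v in non_default_params.items():
        if k not in supported_params:
            if k in ("user", "stream_options"):
                continue  # tolerated for every provider, handled elsewhere
            unsupported[k] = v
    return unsupported

print(find_unsupported({"stream_options": {"include_usage": True}, "top_k": 3}, {"temperature"}))
# -> {'top_k': 3}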
@@ -10274,7 +10280,14 @@ class CustomStreamWrapper:
self.response_id = None
self.logging_loop = None
self.rules = Rules()
self.stream_options = stream_options
self.stream_options = stream_options or getattr(
logging_obj, "stream_options", None
)
self.messages = getattr(logging_obj, "messages", None)
self.sent_stream_usage = False
self.chunks: List = (
[]
) # keep track of the returned chunks - used for calculating the input/output tokens for stream options
def __iter__(self):
return self
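The precedence rule introduced above, as a tiny hedged sketch (SimpleNamespace stands in for the real logging object): an explicitly passed stream_options wins, otherwise the wrapper falls back to whatever the logging object captured, and every returned chunk is also appended to self.chunks so usage can be computed locally at the end of the stream.

from types import SimpleNamespace

# Hedged sketch of the fallback in __init__ above.
logging_obj = SimpleNamespace(stream_options={"include_usage": True})
explicit = None  # what the caller passed directly to the wrapper
stream_options = explicit or getattr(logging_obj, "stream_options", None)
print(stream_options)  # -> {'include_usage': True}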
@@ -11389,6 +11402,7 @@ class CustomStreamWrapper:
and self.stream_options.get("include_usage", False) == True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"].prompt_tokens,
completion_tokens=response_obj["usage"].completion_tokens,
@@ -11405,6 +11419,7 @@ class CustomStreamWrapper:
and self.stream_options.get("include_usage", False) == True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"].prompt_tokens,
completion_tokens=response_obj["usage"].completion_tokens,
@@ -11468,6 +11483,7 @@ class CustomStreamWrapper:
and self.stream_options["include_usage"] == True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"].prompt_tokens,
completion_tokens=response_obj["usage"].completion_tokens,
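The three hunks above share one guard, summarized here as a hedged sketch (function and argument names are illustrative, not from the commit): provider-reported usage is only copied onto the outgoing chunk when the caller asked for it, and sent_stream_usage is flipped so the StopIteration handler below does not emit a second, locally computed usage chunk.

# Hedged sketch of the shared guard; names are illustrative.
def maybe_attach_usage(wrapper, model_response, provider_usage):
    if (
        wrapper.stream_options is not None
        and wrapper.stream_options.get("include_usage", False) is True
        and provider_usage is not None
    ):
        wrapper.sent_stream_usage = True  # suppress the fallback usage chunk later
        model_response.usage = provider_usage
    return model_response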
@@ -11749,7 +11765,6 @@ class CustomStreamWrapper:
model_response.choices[0].finish_reason = "stop"
return model_response
## needs to handle the empty string case (even starting chunk can be an empty string)
def __next__(self):
try:
while True:
@@ -11781,9 +11796,27 @@ class CustomStreamWrapper:
input=self.response_uptil_now, model=self.model
)
# RETURN RESULT
self.chunks.append(response)
return response
except StopIteration:
if self.sent_last_chunk == True:
if (
self.sent_stream_usage == False
and self.stream_options is not None
and self.stream_options.get("include_usage", False) == True
):
# send the final chunk with stream options
complete_streaming_response = litellm.stream_chunk_builder(
chunks=self.chunks, messages=self.messages
)
response = self.model_response_creator()
response.usage = complete_streaming_response.usage # type: ignore
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
).start() # log response
self.sent_stream_usage = True
return response
raise # Re-raise StopIteration
else:
self.sent_last_chunk = True
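Restating the StopIteration branch above as a hedged sketch (build_final_usage_chunk is a hypothetical helper, not part of the commit): if usage was requested but never arrived from the provider, the chunks collected during iteration are handed to litellm.stream_chunk_builder, which reassembles them into one complete response, and only its usage block is forwarded as a single extra final chunk. In the diff, that extra chunk is also passed to the logging success handler on a background thread before being returned.

import litellm

# Hedged sketch of the fallback usage chunk built on StopIteration.
def build_final_usage_chunk(chunks, messages, model_response_creator):
    # reassemble the streamed chunks into one complete response, then reuse
    # only its usage block on a fresh, otherwise-empty final chunk
    complete = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
    final = model_response_creator()
    final.usage = complete.usage
    return final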
@@ -11881,6 +11914,7 @@ class CustomStreamWrapper:
input=self.response_uptil_now, model=self.model
)
print_verbose(f"final returned processed chunk: {processed_chunk}")
self.chunks.append(response)
return processed_chunk
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls
@@ -11920,9 +11954,32 @@ class CustomStreamWrapper:
input=self.response_uptil_now, model=self.model
)
# RETURN RESULT
self.chunks.append(response)
return processed_chunk
except StopAsyncIteration:
if self.sent_last_chunk == True:
if (
self.sent_stream_usage == False
and self.stream_options is not None
and self.stream_options.get("include_usage", False) == True
):
# send the final chunk with stream options
complete_streaming_response = litellm.stream_chunk_builder(
chunks=self.chunks
)
response = self.model_response_creator()
response.usage = complete_streaming_response.usage
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(processed_chunk,)
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
)
)
self.sent_stream_usage = True
return response
raise # Re-raise StopIteration
else:
self.sent_last_chunk = True
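For completeness, a hedged async usage sketch mirroring the __anext__ changes above (model and prompt are placeholders): acompletion with stream_options behaves the same way, with the usage block arriving on the last chunk of the async stream.

import asyncio
import litellm

async def main():
    # Placeholder model/prompt; the point is the stream_options pass-through.
    stream = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in stream:
        print(chunk)  # final chunk should carry `usage`

asyncio.run(main())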