feat(utils.py): support 'stream_options' param across all providers
Closes https://github.com/BerriAI/litellm/issues/3553
parent 34f31a1994
commit 54dacfdf61

2 changed files with 66 additions and 5 deletions
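
For context, a minimal caller-side sketch (not part of the diff) of how the new parameter is exercised. It assumes the OpenAI-style stream_options={"include_usage": True} shape and an already-configured provider key; the exact content of the final usage chunk depends on the provider.

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
    stream_options={"include_usage": True},  # forwarded to the provider by this change
)

for chunk in response:
    # with include_usage set, the final chunk carries aggregate token counts
    usage = getattr(chunk, "usage", None)
    if usage is not None:
        print(usage)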
@@ -1137,6 +1137,7 @@ class Logging:
     global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger

     custom_pricing: bool = False
+    stream_options = None

     def __init__(
         self,
@@ -1205,6 +1206,7 @@ class Logging:
         self.litellm_params = litellm_params
         self.logger_fn = litellm_params.get("logger_fn", None)
         print_verbose(f"self.optional_params: {self.optional_params}")

         self.model_call_details = {
             "model": self.model,
             "messages": self.messages,
@@ -1220,6 +1222,9 @@ class Logging:
             **additional_params,
         }

+        ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
+        if "stream_options" in additional_params:
+            self.stream_options = additional_params["stream_options"]
         ## check if custom pricing set ##
         if (
             litellm_params.get("input_cost_per_token") is not None
@@ -3035,6 +3040,7 @@ def function_setup(
             user="",
             optional_params={},
             litellm_params=litellm_params,
+            stream_options=kwargs.get("stream_options", None),
         )
         return logging_obj, kwargs
     except Exception as e:
@@ -5345,7 +5351,7 @@ def get_optional_params(
     unsupported_params = {}
     for k in non_default_params.keys():
         if k not in supported_params:
-            if k == "user":
+            if k == "user" or k == "stream_options":
                 continue
             if k == "n" and n == 1:  # langchain sends n=1 as a default value
                 continue  # skip this param
@@ -10274,7 +10280,14 @@ class CustomStreamWrapper:
         self.response_id = None
         self.logging_loop = None
         self.rules = Rules()
-        self.stream_options = stream_options
+        self.stream_options = stream_options or getattr(
+            logging_obj, "stream_options", None
+        )
+        self.messages = getattr(logging_obj, "messages", None)
+        self.sent_stream_usage = False
+        self.chunks: List = (
+            []
+        )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options

     def __iter__(self):
         return self
@@ -11389,6 +11402,7 @@ class CustomStreamWrapper:
                 and self.stream_options.get("include_usage", False) == True
                 and response_obj["usage"] is not None
             ):
+                self.sent_stream_usage = True
                 model_response.usage = litellm.Usage(
                     prompt_tokens=response_obj["usage"].prompt_tokens,
                     completion_tokens=response_obj["usage"].completion_tokens,
@@ -11405,6 +11419,7 @@ class CustomStreamWrapper:
                 and self.stream_options.get("include_usage", False) == True
                 and response_obj["usage"] is not None
             ):
+                self.sent_stream_usage = True
                 model_response.usage = litellm.Usage(
                     prompt_tokens=response_obj["usage"].prompt_tokens,
                     completion_tokens=response_obj["usage"].completion_tokens,
@@ -11468,6 +11483,7 @@ class CustomStreamWrapper:
                 and self.stream_options["include_usage"] == True
                 and response_obj["usage"] is not None
             ):
+                self.sent_stream_usage = True
                 model_response.usage = litellm.Usage(
                     prompt_tokens=response_obj["usage"].prompt_tokens,
                     completion_tokens=response_obj["usage"].completion_tokens,
@@ -11749,7 +11765,6 @@ class CustomStreamWrapper:
                 model_response.choices[0].finish_reason = "stop"
                 return model_response

     ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
         try:
             while True:
@@ -11781,9 +11796,27 @@ class CustomStreamWrapper:
                     input=self.response_uptil_now, model=self.model
                 )
                 # RETURN RESULT
+                self.chunks.append(response)
                 return response
         except StopIteration:
             if self.sent_last_chunk == True:
+                if (
+                    self.sent_stream_usage == False
+                    and self.stream_options is not None
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    # send the final chunk with stream options
+                    complete_streaming_response = litellm.stream_chunk_builder(
+                        chunks=self.chunks, messages=self.messages
+                    )
+                    response = self.model_response_creator()
+                    response.usage = complete_streaming_response.usage  # type: ignore
+                    ## LOGGING
+                    threading.Thread(
+                        target=self.logging_obj.success_handler, args=(response,)
+                    ).start()  # log response
+                    self.sent_stream_usage = True
+                    return response
                 raise  # Re-raise StopIteration
             else:
                 self.sent_last_chunk = True
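
As a hedged restatement of the StopIteration branch above (a sketch, not the library's public API): the chunks buffered during iteration are stitched back into a complete response with stream_chunk_builder, the same helper the diff calls, and its usage is surfaced as one extra final chunk.

import litellm

def usage_from_buffered_chunks(chunks, messages):
    # rebuild a full response from streamed chunks; messages are used for prompt token counts
    complete = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
    return complete.usage  # prompt_tokens / completion_tokens / total_tokens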
@@ -11881,6 +11914,7 @@ class CustomStreamWrapper:
                         input=self.response_uptil_now, model=self.model
                     )
                     print_verbose(f"final returned processed chunk: {processed_chunk}")
+                    self.chunks.append(response)
                     return processed_chunk
             raise StopAsyncIteration
         else:  # temporary patch for non-aiohttp async calls
@@ -11920,9 +11954,32 @@ class CustomStreamWrapper:
                     input=self.response_uptil_now, model=self.model
                 )
                 # RETURN RESULT
+                self.chunks.append(response)
                 return processed_chunk
         except StopAsyncIteration:
             if self.sent_last_chunk == True:
+                if (
+                    self.sent_stream_usage == False
+                    and self.stream_options is not None
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    # send the final chunk with stream options
+                    complete_streaming_response = litellm.stream_chunk_builder(
+                        chunks=self.chunks
+                    )
+                    response = self.model_response_creator()
+                    response.usage = complete_streaming_response.usage
+                    ## LOGGING
+                    threading.Thread(
+                        target=self.logging_obj.success_handler, args=(processed_chunk,)
+                    ).start()  # log response
+                    asyncio.create_task(
+                        self.logging_obj.async_success_handler(
+                            processed_chunk,
+                        )
+                    )
+                    self.sent_stream_usage = True
+                    return response
                 raise  # Re-raise StopIteration
             else:
                 self.sent_last_chunk = True
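
The async path mirrors the sync one. A minimal consumer sketch, assuming litellm.acompletion accepts the same stream_options parameter as completion and that the StopAsyncIteration handler above emits the usage-bearing final chunk:

import asyncio
import litellm

async def main():
    stream = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in stream:
        # the final chunk carries aggregate token counts when include_usage is set
        usage = getattr(chunk, "usage", None)
        if usage is not None:
            print(usage)

asyncio.run(main())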