Merge pull request #4015 from BerriAI/litellm_stream_options_fix_2

feat(utils.py): Support `stream_options` param across all providers
Krish Dholakia 2024-06-04 20:59:39 -07:00 committed by GitHub
commit d6f4233441
2 changed files with 129 additions and 26 deletions


@@ -1993,20 +1993,46 @@ def test_openai_chat_completion_complete_response_call():
 # test_openai_chat_completion_complete_response_call()
-def test_openai_stream_options_call():
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-3.5-turbo", "azure/chatgpt-v-2"],
+)
+@pytest.mark.parametrize(
+    "sync",
+    [True, False],
+)
+@pytest.mark.asyncio
+async def test_openai_stream_options_call(model, sync):
     litellm.set_verbose = False
-    response = litellm.completion(
-        model="gpt-3.5-turbo",
-        messages=[{"role": "system", "content": "say GM - we're going to make it "}],
-        stream=True,
-        stream_options={"include_usage": True},
-        max_tokens=10,
-    )
     usage = None
     chunks = []
-    for chunk in response:
-        print("chunk: ", chunk)
-        chunks.append(chunk)
+    if sync:
+        response = litellm.completion(
+            model=model,
+            messages=[
+                {"role": "system", "content": "say GM - we're going to make it "}
+            ],
+            stream=True,
+            stream_options={"include_usage": True},
+            max_tokens=10,
+        )
+        for chunk in response:
+            print("chunk: ", chunk)
+            chunks.append(chunk)
+    else:
+        response = await litellm.acompletion(
+            model=model,
+            messages=[
+                {"role": "system", "content": "say GM - we're going to make it "}
+            ],
+            stream=True,
+            stream_options={"include_usage": True},
+            max_tokens=10,
+        )
+        async for chunk in response:
+            print("chunk: ", chunk)
+            chunks.append(chunk)

     last_chunk = chunks[-1]
     print("last chunk: ", last_chunk)
@@ -2018,12 +2044,24 @@ def test_openai_stream_options_call():
     """
     assert last_chunk.usage is not None
+    assert isinstance(last_chunk.usage, litellm.Usage)
     assert last_chunk.usage.total_tokens > 0
     assert last_chunk.usage.prompt_tokens > 0
     assert last_chunk.usage.completion_tokens > 0

     # assert all non last chunks have usage=None
-    assert all(chunk.usage is None for chunk in chunks[:-1])
+    # Improved assertion with detailed error message
+    non_last_chunks_with_usage = [
+        chunk
+        for chunk in chunks[:-1]
+        if hasattr(chunk, "usage") and chunk.usage is not None
+    ]
+    assert (
+        not non_last_chunks_with_usage
+    ), f"Non-last chunks with usage not None:\n" + "\n".join(
+        f"Chunk ID: {chunk.id}, Usage: {chunk.usage}, Content: {chunk.choices[0].delta.content}"
+        for chunk in non_last_chunks_with_usage
+    )


 def test_openai_stream_options_call_text_completion():
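For reference, this is the caller-facing behaviour the test pins down: with stream_options={"include_usage": True}, the final chunk of a supported provider's stream carries populated usage, and every earlier chunk leaves it unset. A minimal sketch of the synchronous path, assuming litellm is installed and an OPENAI_API_KEY is configured (the model name is just an example):

    import litellm

    # Stream a short completion and collect every chunk.
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "say GM"}],
        stream=True,
        stream_options={"include_usage": True},
        max_tokens=10,
    )

    chunks = [chunk for chunk in response]

    # Only the trailing chunk is expected to carry token counts.
    usage = chunks[-1].usage
    print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)

The async path (litellm.acompletion plus async for) behaves the same way, which is what the sync=False parametrization above exercises.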


@@ -680,12 +680,6 @@ class ModelResponse(OpenAIObject):
             usage = usage
         elif stream is None or stream == False:
             usage = Usage()
-        elif (
-            stream == True
-            and stream_options is not None
-            and stream_options.get("include_usage") == True
-        ):
-            usage = Usage()

         if hidden_params:
             self._hidden_params = hidden_params
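With the elif removed, a streaming ModelResponse no longer gets a placeholder Usage() attached at construction just because include_usage was requested; real token counts are attached later by CustomStreamWrapper. A simplified sketch of the remaining constructor behaviour, using hypothetical names rather than the actual class:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Usage:
        prompt_tokens: int = 0
        completion_tokens: int = 0
        total_tokens: int = 0

    @dataclass
    class SimplifiedModelResponse:
        stream: Optional[bool] = None
        usage: Optional[Usage] = None

        def __post_init__(self):
            # Non-streaming responses still get a default Usage() placeholder;
            # streaming responses leave usage unset until the stream wrapper
            # fills in real token counts on the final chunk.
            if (self.stream is None or self.stream is False) and self.usage is None:
                self.usage = Usage()

    assert SimplifiedModelResponse(stream=False).usage is not None
    assert SimplifiedModelResponse(stream=True).usage is None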
@@ -1143,6 +1137,7 @@ class Logging:
     global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger

     custom_pricing: bool = False
+    stream_options = None

     def __init__(
         self,
@@ -1211,6 +1206,7 @@
         self.litellm_params = litellm_params
         self.logger_fn = litellm_params.get("logger_fn", None)
         print_verbose(f"self.optional_params: {self.optional_params}")
+
         self.model_call_details = {
             "model": self.model,
             "messages": self.messages,
@@ -1226,6 +1222,9 @@
             **additional_params,
         }

+        ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
+        if "stream_options" in additional_params:
+            self.stream_options = additional_params["stream_options"]
         ## check if custom pricing set ##
         if (
             litellm_params.get("input_cost_per_token") is not None
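The Logging class now carries a class-level stream_options default of None and copies the value out of additional_params when the call details are recorded, so the stream wrapper can later discover it with getattr. A stripped-down illustration; the class and method names here are hypothetical, not the real Logging API:

    class MiniLogging:
        stream_options = None  # class-level default, as in the diff

        def record_call_details(self, additional_params: dict) -> None:
            # Stash stream_options on the logging object so downstream
            # consumers (e.g. the stream wrapper) can read it via getattr.
            if "stream_options" in additional_params:
                self.stream_options = additional_params["stream_options"]

    logger = MiniLogging()
    logger.record_call_details({"stream_options": {"include_usage": True}})
    assert logger.stream_options == {"include_usage": True}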
@@ -3044,6 +3043,7 @@ def function_setup(
             user="",
             optional_params={},
             litellm_params=litellm_params,
+            stream_options=kwargs.get("stream_options", None),
         )
         return logging_obj, kwargs
     except Exception as e:
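function_setup is the entry point that threads the user's stream_options kwarg into the Logging object, which in turn is what the CustomStreamWrapper fallback below reads. A minimal sketch of that hand-off; make_logging_obj is a hypothetical stand-in for the real Logging constructor call:

    def make_logging_obj(stream_options=None, **_ignored):
        # Hypothetical stand-in for constructing litellm's Logging object.
        class _LoggingStub:
            pass

        obj = _LoggingStub()
        obj.stream_options = stream_options
        return obj

    def function_setup_sketch(**kwargs):
        # Mirrors the diff: stream_options is pulled straight from the caller's kwargs.
        logging_obj = make_logging_obj(
            stream_options=kwargs.get("stream_options", None),
        )
        return logging_obj, kwargs

    logging_obj, _ = function_setup_sketch(
        model="gpt-3.5-turbo",
        stream=True,
        stream_options={"include_usage": True},
    )
    assert logging_obj.stream_options == {"include_usage": True}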
@@ -5354,7 +5354,7 @@ def get_optional_params(
         unsupported_params = {}
         for k in non_default_params.keys():
             if k not in supported_params:
-                if k == "user":
+                if k == "user" or k == "stream_options":
                     continue
                 if k == "n" and n == 1:  # langchain sends n=1 as a default value
                     continue  # skip this param
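In get_optional_params, stream_options joins user as a parameter that is quietly skipped rather than collected into unsupported_params when a provider does not list it, since usage can now be synthesized by the stream wrapper instead. A small sketch of that filtering rule with hypothetical names:

    ALWAYS_SKIPPED = {"user", "stream_options"}  # never reported as unsupported

    def collect_unsupported(non_default_params: dict, supported_params: set) -> dict:
        unsupported = {}
        for k, v in non_default_params.items():
            if k in supported_params:
                continue
            if k in ALWAYS_SKIPPED:
                continue  # silently dropped instead of raising an error
            if k == "n" and v == 1:
                continue  # langchain sends n=1 as a default value
            unsupported[k] = v
        return unsupported

    assert collect_unsupported(
        {"stream_options": {"include_usage": True}, "top_k": 5},
        supported_params={"temperature", "max_tokens"},
    ) == {"top_k": 5}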
@@ -10283,7 +10283,14 @@ class CustomStreamWrapper:
         self.response_id = None
         self.logging_loop = None
         self.rules = Rules()
-        self.stream_options = stream_options
+        self.stream_options = stream_options or getattr(
+            logging_obj, "stream_options", None
+        )
+        self.messages = getattr(logging_obj, "messages", None)
+        self.sent_stream_usage = False
+        self.chunks: List = (
+            []
+        )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options

     def __iter__(self):
         return self
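CustomStreamWrapper now resolves stream_options from either its constructor argument or the Logging object, keeps the original messages for token counting, and accumulates every chunk it hands back to the caller. A condensed sketch of that initialization (a toy wrapper, not the real class):

    from typing import Any, List, Optional

    class MiniStreamWrapper:
        def __init__(self, completion_stream, logging_obj, stream_options: Optional[dict] = None):
            self.completion_stream = completion_stream
            self.logging_obj = logging_obj
            # The explicit argument wins; otherwise fall back to whatever the
            # logging object captured from the original request kwargs.
            self.stream_options = stream_options or getattr(logging_obj, "stream_options", None)
            self.messages = getattr(logging_obj, "messages", None)
            self.sent_stream_usage = False
            # Every chunk returned to the caller is kept so usage can be
            # reconstructed once the underlying stream is exhausted.
            self.chunks: List[Any] = []

    class _FakeLogging:
        stream_options = {"include_usage": True}
        messages = [{"role": "user", "content": "hi"}]

    wrapper = MiniStreamWrapper(completion_stream=iter(()), logging_obj=_FakeLogging())
    assert wrapper.stream_options == {"include_usage": True}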
@@ -11110,8 +11117,7 @@
         model_response.system_fingerprint = self.system_fingerprint
         model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
         model_response._hidden_params["created_at"] = time.time()
-        model_response.choices = [StreamingChoices()]
-        model_response.choices[0].finish_reason = None
+        model_response.choices = [StreamingChoices(finish_reason=None)]
         return model_response

     def is_delta_empty(self, delta: Delta) -> bool:
@@ -11397,8 +11403,14 @@
                 if (
                     self.stream_options
                     and self.stream_options.get("include_usage", False) == True
+                    and response_obj["usage"] is not None
                 ):
-                    model_response.usage = response_obj["usage"]
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"].prompt_tokens,
+                        completion_tokens=response_obj["usage"].completion_tokens,
+                        total_tokens=response_obj["usage"].total_tokens,
+                    )
             elif self.custom_llm_provider == "databricks":
                 response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11408,8 +11420,14 @@
                 if (
                     self.stream_options
                     and self.stream_options.get("include_usage", False) == True
+                    and response_obj["usage"] is not None
                 ):
-                    model_response.usage = response_obj["usage"]
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"].prompt_tokens,
+                        completion_tokens=response_obj["usage"].completion_tokens,
+                        total_tokens=response_obj["usage"].total_tokens,
+                    )
             elif self.custom_llm_provider == "azure_text":
                 response_obj = self.handle_azure_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11466,8 +11484,14 @@
                 if (
                     self.stream_options is not None
                     and self.stream_options["include_usage"] == True
+                    and response_obj["usage"] is not None
                 ):
-                    model_response.usage = response_obj["usage"]
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"].prompt_tokens,
+                        completion_tokens=response_obj["usage"].completion_tokens,
+                        total_tokens=response_obj["usage"].total_tokens,
+                    )
                 model_response.model = self.model
                 print_verbose(
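The three provider branches above (OpenAI-compatible, Databricks, Azure text completions) now share one pattern: copy usage onto the outgoing chunk only when the provider actually reported it, normalize it into a litellm.Usage, and flag sent_stream_usage so the end-of-stream fallback below is skipped. A generic sketch of that step; names outside the diff are assumptions:

    from types import SimpleNamespace
    from typing import Optional

    def attach_usage_if_present(model_response, response_obj: dict, stream_options: Optional[dict], state) -> None:
        # Require include_usage AND a real usage payload from the provider
        # before touching model_response.usage.
        if (
            stream_options
            and stream_options.get("include_usage", False) is True
            and response_obj.get("usage") is not None
        ):
            state.sent_stream_usage = True
            provider_usage = response_obj["usage"]
            model_response.usage = {  # stand-in for litellm.Usage(...)
                "prompt_tokens": provider_usage.prompt_tokens,
                "completion_tokens": provider_usage.completion_tokens,
                "total_tokens": provider_usage.total_tokens,
            }

    state = SimpleNamespace(sent_stream_usage=False)
    model_response = SimpleNamespace(usage=None)
    provider_usage = SimpleNamespace(prompt_tokens=3, completion_tokens=5, total_tokens=8)
    attach_usage_if_present(model_response, {"usage": provider_usage}, {"include_usage": True}, state)
    assert state.sent_stream_usage and model_response.usage["total_tokens"] == 8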
@@ -11744,7 +11768,6 @@
                 model_response.choices[0].finish_reason = "stop"
                 return model_response

-    ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
         try:
             while True:
@@ -11776,9 +11799,27 @@
                         input=self.response_uptil_now, model=self.model
                     )
                 # RETURN RESULT
+                self.chunks.append(response)
                 return response
         except StopIteration:
             if self.sent_last_chunk == True:
+                if (
+                    self.sent_stream_usage == False
+                    and self.stream_options is not None
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    # send the final chunk with stream options
+                    complete_streaming_response = litellm.stream_chunk_builder(
+                        chunks=self.chunks, messages=self.messages
+                    )
+                    response = self.model_response_creator()
+                    response.usage = complete_streaming_response.usage  # type: ignore
+                    ## LOGGING
+                    threading.Thread(
+                        target=self.logging_obj.success_handler, args=(response,)
+                    ).start()  # log response
+                    self.sent_stream_usage = True
+                    return response
                 raise  # Re-raise StopIteration
             else:
                 self.sent_last_chunk = True
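This StopIteration handler is the fallback for providers whose streams never include a usage payload: once the wrapped iterator is exhausted, stream_chunk_builder reassembles the accumulated chunks (together with the original messages) into a complete response, and its usage is returned on one extra, trailing chunk before the exception is re-raised. The same "one trailing usage chunk" idea as a standalone generator sketch, with token counting left as a pluggable callable since the real builder lives inside litellm:

    from typing import Callable, Iterable, Iterator

    def with_trailing_usage(chunks: Iterable[dict], count_usage: Callable[[list], dict]) -> Iterator[dict]:
        """Yield every chunk, then one final usage-only chunk."""
        seen = []
        for chunk in chunks:
            seen.append(chunk)
            yield chunk
        # After the source is exhausted, synthesize a last chunk that carries
        # only usage, mirroring OpenAI's include_usage behaviour.
        yield {"choices": [], "usage": count_usage(seen)}

    fake_stream = [{"choices": [{"delta": {"content": "GM"}}]}]
    out = list(with_trailing_usage(fake_stream, count_usage=lambda seen: {"total_tokens": len(seen)}))
    assert out[-1]["usage"] == {"total_tokens": 1}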
@@ -11876,6 +11917,7 @@
                             input=self.response_uptil_now, model=self.model
                         )
                         print_verbose(f"final returned processed chunk: {processed_chunk}")
+                        self.chunks.append(processed_chunk)
                         return processed_chunk
                 raise StopAsyncIteration
             else:  # temporary patch for non-aiohttp async calls
@@ -11915,9 +11957,32 @@
                             input=self.response_uptil_now, model=self.model
                         )
                     # RETURN RESULT
+                    self.chunks.append(processed_chunk)
                     return processed_chunk
         except StopAsyncIteration:
             if self.sent_last_chunk == True:
+                if (
+                    self.sent_stream_usage == False
+                    and self.stream_options is not None
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    # send the final chunk with stream options
+                    complete_streaming_response = litellm.stream_chunk_builder(
+                        chunks=self.chunks, messages=self.messages
+                    )
+                    response = self.model_response_creator()
+                    response.usage = complete_streaming_response.usage
+                    ## LOGGING
+                    threading.Thread(
+                        target=self.logging_obj.success_handler, args=(response,)
+                    ).start()  # log response
+                    asyncio.create_task(
+                        self.logging_obj.async_success_handler(
+                            response,
+                        )
+                    )
+                    self.sent_stream_usage = True
+                    return response
                 raise  # Re-raise StopIteration
             else:
                 self.sent_last_chunk = True
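The async iterator gets the same end-of-stream fallback, and additionally schedules the async success handler so async logging callbacks also see the synthesized usage. From the caller's side the pattern mirrors the sync case; a short consumption sketch, assuming an OpenAI key is configured and the model name is illustrative:

    import asyncio

    import litellm

    async def main() -> None:
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "say GM"}],
            stream=True,
            stream_options={"include_usage": True},
            max_tokens=10,
        )
        last = None
        async for chunk in response:
            last = chunk  # usage is only populated on the final chunk
        if last is not None and last.usage is not None:
            print("total tokens:", last.usage.total_tokens)

    if __name__ == "__main__":
        asyncio.run(main())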