forked from phoenix/litellm-mirror
Merge pull request #4015 from BerriAI/litellm_stream_options_fix_2
feat(utils.py): Support `stream_options` param across all providers
Commit: d6f4233441

2 changed files with 129 additions and 26 deletions
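For orientation, here is a minimal caller-side sketch of the behavior this change enables, based on the new test in this diff (the model name and prompt are placeholders, and provider credentials such as OPENAI_API_KEY are assumed to be configured): with stream_options={"include_usage": True}, the final streamed chunk carries a usage block.

import litellm

# Sketch only: placeholder model/prompt; mirrors the test added in this PR.
chunks = []
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "say GM"}],
    stream=True,
    stream_options={"include_usage": True},
    max_tokens=10,
)
for chunk in response:
    chunks.append(chunk)

# With include_usage set, the last chunk should expose token counts.
print(chunks[-1].usage)  # e.g. Usage(prompt_tokens=..., completion_tokens=..., total_tokens=...)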
@@ -1993,20 +1993,46 @@ def test_openai_chat_completion_complete_response_call():


 # test_openai_chat_completion_complete_response_call()
-def test_openai_stream_options_call():
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-3.5-turbo", "azure/chatgpt-v-2"],
+)
+@pytest.mark.parametrize(
+    "sync",
+    [True, False],
+)
+@pytest.mark.asyncio
+async def test_openai_stream_options_call(model, sync):
     litellm.set_verbose = False
-    response = litellm.completion(
-        model="gpt-3.5-turbo",
-        messages=[{"role": "system", "content": "say GM - we're going to make it "}],
-        stream=True,
-        stream_options={"include_usage": True},
-        max_tokens=10,
-    )
     usage = None
     chunks = []
-    for chunk in response:
-        print("chunk: ", chunk)
-        chunks.append(chunk)
+    if sync:
+        response = litellm.completion(
+            model=model,
+            messages=[
+                {"role": "system", "content": "say GM - we're going to make it "}
+            ],
+            stream=True,
+            stream_options={"include_usage": True},
+            max_tokens=10,
+        )
+        for chunk in response:
+            print("chunk: ", chunk)
+            chunks.append(chunk)
+    else:
+        response = await litellm.acompletion(
+            model=model,
+            messages=[
+                {"role": "system", "content": "say GM - we're going to make it "}
+            ],
+            stream=True,
+            stream_options={"include_usage": True},
+            max_tokens=10,
+        )
+
+        async for chunk in response:
+            print("chunk: ", chunk)
+            chunks.append(chunk)

     last_chunk = chunks[-1]
     print("last chunk: ", last_chunk)
@@ -2018,12 +2044,24 @@ def test_openai_stream_options_call():
     """

     assert last_chunk.usage is not None
+    assert isinstance(last_chunk.usage, litellm.Usage)
     assert last_chunk.usage.total_tokens > 0
     assert last_chunk.usage.prompt_tokens > 0
     assert last_chunk.usage.completion_tokens > 0

     # assert all non last chunks have usage=None
-    assert all(chunk.usage is None for chunk in chunks[:-1])
+    # Improved assertion with detailed error message
+    non_last_chunks_with_usage = [
+        chunk
+        for chunk in chunks[:-1]
+        if hasattr(chunk, "usage") and chunk.usage is not None
+    ]
+    assert (
+        not non_last_chunks_with_usage
+    ), f"Non-last chunks with usage not None:\n" + "\n".join(
+        f"Chunk ID: {chunk.id}, Usage: {chunk.usage}, Content: {chunk.choices[0].delta.content}"
+        for chunk in non_last_chunks_with_usage
+    )


 def test_openai_stream_options_call_text_completion():
@@ -680,12 +680,6 @@ class ModelResponse(OpenAIObject):
             usage = usage
         elif stream is None or stream == False:
             usage = Usage()
-        elif (
-            stream == True
-            and stream_options is not None
-            and stream_options.get("include_usage") == True
-        ):
-            usage = Usage()
         if hidden_params:
             self._hidden_params = hidden_params

@@ -1143,6 +1137,7 @@ class Logging:
     global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger

     custom_pricing: bool = False
+    stream_options = None

     def __init__(
         self,
@@ -1211,6 +1206,7 @@ class Logging:
         self.litellm_params = litellm_params
         self.logger_fn = litellm_params.get("logger_fn", None)
         print_verbose(f"self.optional_params: {self.optional_params}")
+
         self.model_call_details = {
             "model": self.model,
             "messages": self.messages,
@@ -1226,6 +1222,9 @@ class Logging:
             **additional_params,
         }

+        ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
+        if "stream_options" in additional_params:
+            self.stream_options = additional_params["stream_options"]
         ## check if custom pricing set ##
         if (
             litellm_params.get("input_cost_per_token") is not None
@@ -3044,6 +3043,7 @@ def function_setup(
             user="",
             optional_params={},
             litellm_params=litellm_params,
+            stream_options=kwargs.get("stream_options", None),
         )
         return logging_obj, kwargs
     except Exception as e:
@@ -5354,7 +5354,7 @@ def get_optional_params(
     unsupported_params = {}
     for k in non_default_params.keys():
         if k not in supported_params:
-            if k == "user":
+            if k == "user" or k == "stream_options":
                 continue
             if k == "n" and n == 1:  # langchain sends n=1 as a default value
                 continue  # skip this param
@@ -10283,7 +10283,14 @@ class CustomStreamWrapper:
         self.response_id = None
         self.logging_loop = None
         self.rules = Rules()
-        self.stream_options = stream_options
+        self.stream_options = stream_options or getattr(
+            logging_obj, "stream_options", None
+        )
+        self.messages = getattr(logging_obj, "messages", None)
+        self.sent_stream_usage = False
+        self.chunks: List = (
+            []
+        )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options

     def __iter__(self):
         return self
@@ -11110,8 +11117,7 @@ class CustomStreamWrapper:
         model_response.system_fingerprint = self.system_fingerprint
         model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
         model_response._hidden_params["created_at"] = time.time()
-        model_response.choices = [StreamingChoices()]
-        model_response.choices[0].finish_reason = None
+        model_response.choices = [StreamingChoices(finish_reason=None)]
         return model_response

     def is_delta_empty(self, delta: Delta) -> bool:
@@ -11397,8 +11403,14 @@ class CustomStreamWrapper:
                 if (
                     self.stream_options
                     and self.stream_options.get("include_usage", False) == True
+                    and response_obj["usage"] is not None
                 ):
-                    model_response.usage = response_obj["usage"]
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"].prompt_tokens,
+                        completion_tokens=response_obj["usage"].completion_tokens,
+                        total_tokens=response_obj["usage"].total_tokens,
+                    )
             elif self.custom_llm_provider == "databricks":
                 response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11408,8 +11420,14 @@ class CustomStreamWrapper:
                 if (
                     self.stream_options
                     and self.stream_options.get("include_usage", False) == True
+                    and response_obj["usage"] is not None
                 ):
-                    model_response.usage = response_obj["usage"]
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"].prompt_tokens,
+                        completion_tokens=response_obj["usage"].completion_tokens,
+                        total_tokens=response_obj["usage"].total_tokens,
+                    )
             elif self.custom_llm_provider == "azure_text":
                 response_obj = self.handle_azure_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11466,8 +11484,14 @@ class CustomStreamWrapper:
                 if (
                     self.stream_options is not None
                     and self.stream_options["include_usage"] == True
+                    and response_obj["usage"] is not None
                 ):
-                    model_response.usage = response_obj["usage"]
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"].prompt_tokens,
+                        completion_tokens=response_obj["usage"].completion_tokens,
+                        total_tokens=response_obj["usage"].total_tokens,
+                    )

             model_response.model = self.model
             print_verbose(
@@ -11744,7 +11768,6 @@ class CustomStreamWrapper:
                 model_response.choices[0].finish_reason = "stop"
                 return model_response

-            ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
         try:
             while True:
@@ -11776,9 +11799,27 @@ class CustomStreamWrapper:
                         input=self.response_uptil_now, model=self.model
                     )
                     # RETURN RESULT
+                    self.chunks.append(response)
                     return response
         except StopIteration:
             if self.sent_last_chunk == True:
+                if (
+                    self.sent_stream_usage == False
+                    and self.stream_options is not None
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    # send the final chunk with stream options
+                    complete_streaming_response = litellm.stream_chunk_builder(
+                        chunks=self.chunks, messages=self.messages
+                    )
+                    response = self.model_response_creator()
+                    response.usage = complete_streaming_response.usage  # type: ignore
+                    ## LOGGING
+                    threading.Thread(
+                        target=self.logging_obj.success_handler, args=(response,)
+                    ).start()  # log response
+                    self.sent_stream_usage = True
+                    return response
                 raise  # Re-raise StopIteration
             else:
                 self.sent_last_chunk = True
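The fallback path above leans on litellm.stream_chunk_builder, which reassembles a complete ModelResponse (including usage) from the accumulated chunks. Here is a minimal sketch of that helper used directly, assuming chunks were collected from a streamed call as in the new test (model and prompt are placeholders):

import litellm

# Sketch only: collect chunks from a streamed call, then rebuild a full response.
messages = [{"role": "user", "content": "say GM"}]  # placeholder prompt
chunks = []
for chunk in litellm.completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=messages,
    stream=True,
    max_tokens=10,
):
    chunks.append(chunk)

# stream_chunk_builder reconstructs content and usage from the collected chunks;
# this is the same helper the wrapper calls when the provider sends no usage chunk.
complete_response = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
print(complete_response.usage)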
@@ -11876,6 +11917,7 @@ class CustomStreamWrapper:
                             input=self.response_uptil_now, model=self.model
                         )
                         print_verbose(f"final returned processed chunk: {processed_chunk}")
+                        self.chunks.append(processed_chunk)
                         return processed_chunk
                     raise StopAsyncIteration
                 else:  # temporary patch for non-aiohttp async calls
@@ -11915,9 +11957,32 @@ class CustomStreamWrapper:
                         input=self.response_uptil_now, model=self.model
                     )
                     # RETURN RESULT
+                    self.chunks.append(processed_chunk)
                     return processed_chunk
         except StopAsyncIteration:
             if self.sent_last_chunk == True:
+                if (
+                    self.sent_stream_usage == False
+                    and self.stream_options is not None
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    # send the final chunk with stream options
+                    complete_streaming_response = litellm.stream_chunk_builder(
+                        chunks=self.chunks, messages=self.messages
+                    )
+                    response = self.model_response_creator()
+                    response.usage = complete_streaming_response.usage
+                    ## LOGGING
+                    threading.Thread(
+                        target=self.logging_obj.success_handler, args=(response,)
+                    ).start()  # log response
+                    asyncio.create_task(
+                        self.logging_obj.async_success_handler(
+                            response,
+                        )
+                    )
+                    self.sent_stream_usage = True
+                    return response
                 raise  # Re-raise StopIteration
             else:
                 self.sent_last_chunk = True
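The async iterator gets the same treatment as the sync one. A caller-side sketch of the async path, mirroring the sync=False branch of the new test (placeholder model/prompt as above):

import asyncio

import litellm


async def main() -> None:
    # Sketch only: placeholder model/prompt; mirrors the async branch of the new test.
    chunks = []
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "say GM"}],
        stream=True,
        stream_options={"include_usage": True},
        max_tokens=10,
    )
    async for chunk in response:
        chunks.append(chunk)
    print(chunks[-1].usage)  # the final chunk carries usage when include_usage is set


asyncio.run(main())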