fix(utils.py): stream_options working across all providers

Krrish Dholakia 2024-07-03 20:40:46 -07:00
parent 8dbe0559dd
commit 2e5a81f280
5 changed files with 98 additions and 35 deletions
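
For context, the behavior this commit targets can be exercised straight from the SDK. The snippet below is a minimal sketch based on the updated test further down (model name and prompt are taken from that test; the usage-on-final-chunk behavior is what this commit fixes):

import litellm

# Request usage reporting on a streamed completion. After this fix,
# intermediate chunks no longer carry a usage block; it arrives once,
# on the final chunk, regardless of provider.
response = litellm.completion(
    model="claude-3-haiku-20240307",
    messages=[{"role": "user", "content": "say GM - we're going to make it "}],
    stream=True,
    stream_options={"include_usage": True},
    max_tokens=10,
)

usage = None
for chunk in response:
    if getattr(chunk, "usage", None) is not None:
        usage = chunk.usage  # only the last chunk should populate this

print(usage)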

View file

@@ -1279,7 +1279,9 @@ def anthropic_messages_pt(messages: list):
             )
         else:
             raise Exception(
-                "Invalid first message. Should always start with 'role'='user' for Anthropic. System prompt is sent separately for Anthropic. set 'litellm.modify_params = True' or 'litellm_settings:modify_params = True' on proxy, to insert a placeholder user message - '.' as the first message, "
+                "Invalid first message={}. Should always start with 'role'='user' for Anthropic. System prompt is sent separately for Anthropic. set 'litellm.modify_params = True' or 'litellm_settings:modify_params = True' on proxy, to insert a placeholder user message - '.' as the first message, ".format(
+                    new_messages
+                )
             )
     if new_messages[-1]["role"] == "assistant":

View file

@@ -4946,14 +4946,23 @@ def stream_chunk_builder(
     else:
         completion_output = ""
     # # Update usage information if needed
+    prompt_tokens = 0
+    completion_tokens = 0
+    for chunk in chunks:
+        if "usage" in chunk:
+            if "prompt_tokens" in chunk["usage"]:
+                prompt_tokens += chunk["usage"].get("prompt_tokens", 0) or 0
+            if "completion_tokens" in chunk["usage"]:
+                completion_tokens += chunk["usage"].get("completion_tokens", 0) or 0
+
     try:
-        response["usage"]["prompt_tokens"] = token_counter(
+        response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
             model=model, messages=messages
         )
     except:  # don't allow this failing to block a complete streaming response from being returned
         print_verbose(f"token_counter failed, assuming prompt tokens is 0")
         response["usage"]["prompt_tokens"] = 0
-    response["usage"]["completion_tokens"] = token_counter(
+    response["usage"]["completion_tokens"] = completion_tokens or token_counter(
         model=model,
         text=completion_output,
         count_response_tokens=True,  # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
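
In other words, stream_chunk_builder now prefers provider-reported usage summed across the streamed chunks and only falls back to token_counter when no chunk carried usage. A standalone sketch of that aggregation (helper name and plain-dict chunk shape are illustrative, not part of the codebase):

from typing import List, Optional


def aggregate_reported_usage(chunks: List[dict]) -> Optional[dict]:
    """Sum provider-reported token usage across streamed chunks.

    Returns None when no chunk carried usage, signalling the caller to
    fall back to local token counting (token_counter in litellm).
    """
    prompt_tokens = 0
    completion_tokens = 0
    saw_usage = False
    for chunk in chunks:
        usage = chunk.get("usage") or {}
        if usage:
            saw_usage = True
        prompt_tokens += usage.get("prompt_tokens", 0) or 0
        completion_tokens += usage.get("completion_tokens", 0) or 0
    if not saw_usage:
        return None
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }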

View file

@@ -18,7 +18,6 @@ model_list:
   - model_name: fake-openai-endpoint
     litellm_params:
       model: predibase/llama-3-8b-instruct
-      api_base: "http://0.0.0.0:8081"
       api_key: os.environ/PREDIBASE_API_KEY
       tenant_id: os.environ/PREDIBASE_TENANT_ID
       max_new_tokens: 256

View file

@@ -2020,7 +2020,7 @@ def test_openai_chat_completion_complete_response_call():
 # test_openai_chat_completion_complete_response_call()
 @pytest.mark.parametrize(
     "model",
-    ["gpt-3.5-turbo", "azure/chatgpt-v-2"],
+    ["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307"],  #
 )
 @pytest.mark.parametrize(
     "sync",
@@ -2028,14 +2028,14 @@ def test_openai_chat_completion_complete_response_call():
 )
 @pytest.mark.asyncio
 async def test_openai_stream_options_call(model, sync):
-    litellm.set_verbose = False
+    litellm.set_verbose = True
     usage = None
     chunks = []
     if sync:
         response = litellm.completion(
             model=model,
             messages=[
-                {"role": "system", "content": "say GM - we're going to make it "}
+                {"role": "user", "content": "say GM - we're going to make it "},
             ],
             stream=True,
             stream_options={"include_usage": True},
@@ -2047,9 +2047,7 @@ async def test_openai_stream_options_call(model, sync):
     else:
         response = await litellm.acompletion(
             model=model,
-            messages=[
-                {"role": "system", "content": "say GM - we're going to make it "}
-            ],
+            messages=[{"role": "user", "content": "say GM - we're going to make it "}],
             stream=True,
             stream_options={"include_usage": True},
             max_tokens=10,

View file

@@ -8746,7 +8746,7 @@ class CustomStreamWrapper:
             verbose_logger.debug(traceback.format_exc())
             return ""

-    def model_response_creator(self):
+    def model_response_creator(self, chunk: Optional[dict] = None):
        _model = self.model
        _received_llm_provider = self.custom_llm_provider
        _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None)  # type: ignore
@@ -8755,13 +8755,18 @@ class CustomStreamWrapper:
            and _received_llm_provider != _logging_obj_llm_provider
        ):
            _model = "{}/{}".format(_logging_obj_llm_provider, _model)
+        if chunk is None:
+            chunk = {}
+        else:
+            # pop model keyword
+            chunk.pop("model", None)
        model_response = ModelResponse(
-            stream=True, model=_model, stream_options=self.stream_options
+            stream=True, model=_model, stream_options=self.stream_options, **chunk
        )
        if self.response_id is not None:
            model_response.id = self.response_id
        else:
-            self.response_id = model_response.id
+            self.response_id = model_response.id  # type: ignore
        if self.system_fingerprint is not None:
            model_response.system_fingerprint = self.system_fingerprint
        model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
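
The new chunk parameter lets callers rebuild a response from an existing chunk's dict while popping the "model" key, since model is already passed explicitly and a duplicate keyword would raise a TypeError. A hedged, self-contained illustration of that pattern (the class below is a stand-in, not litellm's ModelResponse):

from typing import Optional


class ResponseStub:
    """Stand-in for a response object, for illustration only."""

    def __init__(self, stream: bool, model: str, **kwargs):
        self.stream = stream
        self.model = model
        for key, value in kwargs.items():
            setattr(self, key, value)


def rebuild_response(model: str, chunk: Optional[dict] = None) -> ResponseStub:
    chunk = dict(chunk or {})
    # "model" is passed explicitly below; leaving it in the dict would raise
    # TypeError: got multiple values for keyword argument 'model'.
    chunk.pop("model", None)
    return ResponseStub(stream=True, model=model, **chunk)
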
@@ -8790,26 +8795,33 @@ class CustomStreamWrapper:
                if self.received_finish_reason is not None:
                    raise StopIteration
-                response_obj: GChunk = chunk
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
+                anthropic_response_obj: GChunk = chunk
+                completion_obj["content"] = anthropic_response_obj["text"]
+                if anthropic_response_obj["is_finished"]:
+                    self.received_finish_reason = anthropic_response_obj[
+                        "finish_reason"
+                    ]

                if (
                    self.stream_options
                    and self.stream_options.get("include_usage", False) is True
-                    and response_obj["usage"] is not None
+                    and anthropic_response_obj["usage"] is not None
                ):
-                    self.sent_stream_usage = True
                    model_response.usage = litellm.Usage(
-                        prompt_tokens=response_obj["usage"]["prompt_tokens"],
-                        completion_tokens=response_obj["usage"]["completion_tokens"],
-                        total_tokens=response_obj["usage"]["total_tokens"],
+                        prompt_tokens=anthropic_response_obj["usage"]["prompt_tokens"],
+                        completion_tokens=anthropic_response_obj["usage"][
+                            "completion_tokens"
+                        ],
+                        total_tokens=anthropic_response_obj["usage"]["total_tokens"],
                    )

-                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
-                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
+                if (
+                    "tool_use" in anthropic_response_obj
+                    and anthropic_response_obj["tool_use"] is not None
+                ):
+                    completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
+
+                response_obj = anthropic_response_obj
            elif (
                self.custom_llm_provider
                and self.custom_llm_provider == "anthropic_text"
@@ -8918,7 +8930,6 @@ class CustomStreamWrapper:
                    and self.stream_options.get("include_usage", False) is True
                    and response_obj["usage"] is not None
                ):
-                    self.sent_stream_usage = True
                    model_response.usage = litellm.Usage(
                        prompt_tokens=response_obj["usage"]["prompt_tokens"],
                        completion_tokens=response_obj["usage"]["completion_tokens"],
@@ -9046,7 +9057,6 @@ class CustomStreamWrapper:
                    and self.stream_options.get("include_usage", False) is True
                    and response_obj["usage"] is not None
                ):
-                    self.sent_stream_usage = True
                    model_response.usage = litellm.Usage(
                        prompt_tokens=response_obj["usage"]["inputTokens"],
                        completion_tokens=response_obj["usage"]["outputTokens"],
@@ -9118,7 +9128,6 @@ class CustomStreamWrapper:
                    and self.stream_options.get("include_usage", False) == True
                    and response_obj["usage"] is not None
                ):
-                    self.sent_stream_usage = True
                    model_response.usage = litellm.Usage(
                        prompt_tokens=response_obj["usage"].prompt_tokens,
                        completion_tokens=response_obj["usage"].completion_tokens,
@@ -9137,7 +9146,6 @@ class CustomStreamWrapper:
                    and self.stream_options.get("include_usage", False) == True
                    and response_obj["usage"] is not None
                ):
-                    self.sent_stream_usage = True
                    model_response.usage = litellm.Usage(
                        prompt_tokens=response_obj["usage"].prompt_tokens,
                        completion_tokens=response_obj["usage"].completion_tokens,
@@ -9154,7 +9162,6 @@ class CustomStreamWrapper:
                    and self.stream_options.get("include_usage", False) == True
                    and response_obj["usage"] is not None
                ):
-                    self.sent_stream_usage = True
                    model_response.usage = litellm.Usage(
                        prompt_tokens=response_obj["usage"].prompt_tokens,
                        completion_tokens=response_obj["usage"].completion_tokens,
@@ -9218,7 +9225,6 @@ class CustomStreamWrapper:
                    and self.stream_options["include_usage"] == True
                    and response_obj["usage"] is not None
                ):
-                    self.sent_stream_usage = True
                    model_response.usage = litellm.Usage(
                        prompt_tokens=response_obj["usage"].prompt_tokens,
                        completion_tokens=response_obj["usage"].completion_tokens,
@@ -9543,9 +9549,24 @@ class CustomStreamWrapper:
                    self.rules.post_call_rules(
                        input=self.response_uptil_now, model=self.model
                    )
-                # RETURN RESULT
+                # HANDLE STREAM OPTIONS
                self.chunks.append(response)
+                if hasattr(
+                    response, "usage"
+                ):  # remove usage from chunk, only send on final chunk
+                    # Convert the object to a dictionary
+                    obj_dict = response.dict()
+
+                    # Remove an attribute (e.g., 'attr2')
+                    if "usage" in obj_dict:
+                        del obj_dict["usage"]
+
+                    # Create a new object without the removed attribute
+                    response = self.model_response_creator(chunk=obj_dict)
+
+                # RETURN RESULT
                return response
        except StopIteration:
            if self.sent_last_chunk == True:
                if (
@@ -9673,6 +9694,18 @@ class CustomStreamWrapper:
                        )
                        print_verbose(f"final returned processed chunk: {processed_chunk}")
                        self.chunks.append(processed_chunk)
+                        if hasattr(
+                            processed_chunk, "usage"
+                        ):  # remove usage from chunk, only send on final chunk
+                            # Convert the object to a dictionary
+                            obj_dict = processed_chunk.dict()
+
+                            # Remove an attribute (e.g., 'attr2')
+                            if "usage" in obj_dict:
+                                del obj_dict["usage"]
+
+                            # Create a new object without the removed attribute
+                            processed_chunk = self.model_response_creator(chunk=obj_dict)
                        return processed_chunk
                    raise StopAsyncIteration
                else:  # temporary patch for non-aiohttp async calls
@@ -9715,11 +9748,11 @@ class CustomStreamWrapper:
                self.chunks.append(processed_chunk)
                return processed_chunk
        except StopAsyncIteration:
-            if self.sent_last_chunk == True:
+            if self.sent_last_chunk is True:
                if (
-                    self.sent_stream_usage == False
+                    self.sent_stream_usage is False
                    and self.stream_options is not None
-                    and self.stream_options.get("include_usage", False) == True
+                    and self.stream_options.get("include_usage", False) is True
                ):
                    # send the final chunk with stream options
                    complete_streaming_response = litellm.stream_chunk_builder(
except StopIteration: except StopIteration:
if self.sent_last_chunk == True: if self.sent_last_chunk is True:
if (
self.sent_stream_usage is False
and self.stream_options is not None
and self.stream_options.get("include_usage", False) is True
):
# send the final chunk with stream options
complete_streaming_response = litellm.stream_chunk_builder(
chunks=self.chunks, messages=self.messages
)
response = self.model_response_creator()
response.usage = complete_streaming_response.usage
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
response,
)
)
self.sent_stream_usage = True
return response
raise StopAsyncIteration raise StopAsyncIteration
else: else:
self.sent_last_chunk = True self.sent_last_chunk = True
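
Taken together, the streaming changes behave like a wrapper over the provider stream: usage is stripped from every intermediate chunk, totals are recovered via stream_chunk_builder, and a single usage-bearing chunk is emitted after the stream is exhausted. A minimal generator sketch of that flow (plain dicts instead of ModelResponse objects, for illustration only):

from typing import Dict, Iterable, Iterator


def strip_usage_until_final(chunks: Iterable[Dict]) -> Iterator[Dict]:
    """Yield chunks without usage, then one trailing usage-only chunk."""
    prompt_tokens = 0
    completion_tokens = 0
    for chunk in chunks:
        usage = chunk.pop("usage", None)  # intermediate chunks lose usage
        if usage:
            prompt_tokens += usage.get("prompt_tokens", 0) or 0
            completion_tokens += usage.get("completion_tokens", 0) or 0
        yield chunk
    # final chunk carries the aggregated usage, mirroring include_usage=True
    yield {
        "choices": [],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }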