fix(utils.py): stream_options working across all providers

Krrish Dholakia 2024-07-03 20:40:46 -07:00
parent 8dbe0559dd
commit 2e5a81f280
5 changed files with 98 additions and 35 deletions
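The behaviour this change targets is the OpenAI-style stream_options flag: with include_usage requested, intermediate chunks should stream content without a usage block, and token counts should arrive once on the final chunk. A minimal sketch of that calling pattern, assuming litellm's completion API, a placeholder model id, and an API key configured in the environment (any streaming-capable provider should behave the same way after this fix):

import litellm

messages = [{"role": "user", "content": "hi"}]
chunks = []
for chunk in litellm.completion(
    model="gpt-3.5-turbo",  # placeholder; the point of the fix is that every provider behaves the same
    messages=messages,
    stream=True,
    stream_options={"include_usage": True},
):
    chunks.append(chunk)
    # Intermediate chunks carry delta content and no usage; only the last chunk reports tokens.
    delta = chunk.choices[0].delta.content if chunk.choices else None
    print(delta, getattr(chunk, "usage", None))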


@@ -8746,7 +8746,7 @@ class CustomStreamWrapper:
verbose_logger.debug(traceback.format_exc())
return ""
def model_response_creator(self):
def model_response_creator(self, chunk: Optional[dict] = None):
_model = self.model
_received_llm_provider = self.custom_llm_provider
_logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None) # type: ignore
@@ -8755,13 +8755,18 @@ class CustomStreamWrapper:
and _received_llm_provider != _logging_obj_llm_provider
):
_model = "{}/{}".format(_logging_obj_llm_provider, _model)
if chunk is None:
chunk = {}
else:
# pop model keyword
chunk.pop("model", None)
model_response = ModelResponse(
stream=True, model=_model, stream_options=self.stream_options
stream=True, model=_model, stream_options=self.stream_options, **chunk
)
if self.response_id is not None:
model_response.id = self.response_id
else:
self.response_id = model_response.id
self.response_id = model_response.id # type: ignore
if self.system_fingerprint is not None:
model_response.system_fingerprint = self.system_fingerprint
model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
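The new chunk parameter above lets the wrapper rebuild a response from an existing chunk's fields while keeping ownership of the model, stream, and stream_options keywords. A standalone sketch of that rebuild pattern, using a toy dataclass as a stand-in for litellm's ModelResponse (names here are illustrative, not litellm's API):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ToyResponse:  # stand-in for ModelResponse, illustration only
    model: str
    stream: bool = True
    choices: list = field(default_factory=list)
    usage: Optional[dict] = None

def rebuild(old: ToyResponse, model: str) -> ToyResponse:
    data = old.__dict__.copy()   # counterpart of response.dict() in the diff
    data.pop("model", None)      # drop the stale model keyword
    data.pop("stream", None)     # the wrapper supplies these itself
    return ToyResponse(model=model, stream=True, **data)

chunk = ToyResponse(model="old-name", usage={"total_tokens": 12})
print(rebuild(chunk, model="provider/new-name"))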
@@ -8790,26 +8795,33 @@ class CustomStreamWrapper:
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
anthropic_response_obj: GChunk = chunk
completion_obj["content"] = anthropic_response_obj["text"]
if anthropic_response_obj["is_finished"]:
self.received_finish_reason = anthropic_response_obj[
"finish_reason"
]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
and anthropic_response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["prompt_tokens"],
completion_tokens=response_obj["usage"]["completion_tokens"],
total_tokens=response_obj["usage"]["total_tokens"],
prompt_tokens=anthropic_response_obj["usage"]["prompt_tokens"],
completion_tokens=anthropic_response_obj["usage"][
"completion_tokens"
],
total_tokens=anthropic_response_obj["usage"]["total_tokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
if (
"tool_use" in anthropic_response_obj
and anthropic_response_obj["tool_use"] is not None
):
completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
response_obj = anthropic_response_obj
elif (
self.custom_llm_provider
and self.custom_llm_provider == "anthropic_text"
@@ -8918,7 +8930,6 @@ class CustomStreamWrapper:
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["prompt_tokens"],
completion_tokens=response_obj["usage"]["completion_tokens"],
@@ -9046,7 +9057,6 @@ class CustomStreamWrapper:
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
@@ -9118,7 +9128,6 @@ class CustomStreamWrapper:
and self.stream_options.get("include_usage", False) == True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"].prompt_tokens,
completion_tokens=response_obj["usage"].completion_tokens,
@@ -9137,7 +9146,6 @@ class CustomStreamWrapper:
and self.stream_options.get("include_usage", False) == True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"].prompt_tokens,
completion_tokens=response_obj["usage"].completion_tokens,
@@ -9154,7 +9162,6 @@ class CustomStreamWrapper:
and self.stream_options.get("include_usage", False) == True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"].prompt_tokens,
completion_tokens=response_obj["usage"].completion_tokens,
@@ -9218,7 +9225,6 @@ class CustomStreamWrapper:
and self.stream_options["include_usage"] == True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"].prompt_tokens,
completion_tokens=response_obj["usage"].completion_tokens,
@@ -9543,9 +9549,24 @@ class CustomStreamWrapper:
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
# RETURN RESULT
# HANDLE STREAM OPTIONS
self.chunks.append(response)
if hasattr(
response, "usage"
): # remove usage from chunk, only send on final chunk
# Convert the object to a dictionary
obj_dict = response.dict()
# Remove an attribute (e.g., 'attr2')
if "usage" in obj_dict:
del obj_dict["usage"]
# Create a new object without the removed attribute
response = self.model_response_creator(chunk=obj_dict)
# RETURN RESULT
return response
except StopIteration:
if self.sent_last_chunk == True:
if (
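The hunk above strips usage off each streamed chunk and buffers the chunk; together with the StopIteration handlers further down, token counts are then reported once at the end of the stream. A self-contained sketch of that end-to-end pattern with plain dicts, independent of litellm's classes:

def strip_intermediate_usage(chunks):
    """Yield content chunks without usage, then one final usage-only chunk."""
    totals = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    for chunk in chunks:
        usage = chunk.pop("usage", None)  # remove usage from the chunk itself
        if usage:
            for key in totals:
                totals[key] += usage.get(key, 0)
        yield chunk
    yield {"choices": [], "usage": totals}  # final chunk carries usage only

stream = [
    {"choices": [{"delta": {"content": "He"}}], "usage": None},
    {"choices": [{"delta": {"content": "llo"}}],
     "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5}},
]
for out in strip_intermediate_usage(stream):
    print(out)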
@@ -9673,6 +9694,18 @@ class CustomStreamWrapper:
)
print_verbose(f"final returned processed chunk: {processed_chunk}")
self.chunks.append(processed_chunk)
if hasattr(
processed_chunk, "usage"
): # remove usage from chunk, only send on final chunk
# Convert the object to a dictionary
obj_dict = processed_chunk.dict()
# Remove an attribute (e.g., 'attr2')
if "usage" in obj_dict:
del obj_dict["usage"]
# Create a new object without the removed attribute
processed_chunk = self.model_response_creator(chunk=obj_dict)
return processed_chunk
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls
@@ -9715,11 +9748,11 @@ class CustomStreamWrapper:
self.chunks.append(processed_chunk)
return processed_chunk
except StopAsyncIteration:
if self.sent_last_chunk == True:
if self.sent_last_chunk is True:
if (
self.sent_stream_usage == False
self.sent_stream_usage is False
and self.stream_options is not None
and self.stream_options.get("include_usage", False) == True
and self.stream_options.get("include_usage", False) is True
):
# send the final chunk with stream options
complete_streaming_response = litellm.stream_chunk_builder(
@@ -9753,7 +9786,29 @@ class CustomStreamWrapper:
)
return processed_chunk
except StopIteration:
if self.sent_last_chunk == True:
if self.sent_last_chunk is True:
if (
self.sent_stream_usage is False
and self.stream_options is not None
and self.stream_options.get("include_usage", False) is True
):
# send the final chunk with stream options
complete_streaming_response = litellm.stream_chunk_builder(
chunks=self.chunks, messages=self.messages
)
response = self.model_response_creator()
response.usage = complete_streaming_response.usage
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
response,
)
)
self.sent_stream_usage = True
return response
raise StopAsyncIteration
else:
self.sent_last_chunk = True
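When the stream is exhausted, the wrapper now rebuilds the complete response from the buffered chunks via litellm.stream_chunk_builder and attaches its usage to one final chunk. A hedged sketch of calling that helper directly on caller-collected chunks; the model id is a placeholder and an API key is assumed to be configured:

import litellm

messages = [{"role": "user", "content": "hi"}]
chunks = list(
    litellm.completion(
        model="gpt-3.5-turbo",  # placeholder model id
        messages=messages,
        stream=True,
        stream_options={"include_usage": True},
    )
)
# Aggregate the buffered chunks back into a full response, including token usage.
complete = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
print(complete.usage)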