fix(utils.py): fix sagemaker async logging for sync streaming

https://github.com/BerriAI/litellm/issues/1592
This commit is contained in:
Krrish Dholakia 2024-01-25 12:49:45 -08:00
parent 39d5407e67
commit 09ec6d6458
10 changed files with 247 additions and 64 deletions

View file

@ -1417,7 +1417,9 @@ class Logging:
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
verbose_logger.debug(
f"Async success callbacks: {litellm._async_success_callback}"
)
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
)
@ -1426,7 +1428,7 @@ class Logging:
if self.stream:
if result.choices[0].finish_reason is not None: # if it's the last chunk
self.streaming_chunks.append(result)
# print_verbose(f"final set of received chunks: {self.streaming_chunks}")
# verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}")
try:
complete_streaming_response = litellm.stream_chunk_builder(
self.streaming_chunks,
@ -1435,14 +1437,16 @@ class Logging:
end_time=end_time,
)
except Exception as e:
print_verbose(
verbose_logger.debug(
f"Error occurred building stream chunk: {traceback.format_exc()}"
)
complete_streaming_response = None
else:
self.streaming_chunks.append(result)
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
verbose_logger.debug(
"Async success callbacks: Got a complete streaming response"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
@ -7682,6 +7686,27 @@ class CustomStreamWrapper:
}
return ""
def handle_sagemaker_stream(self, chunk):
    """
    Normalize one SageMaker stream chunk into a uniform response dict.

    Args:
        chunk: Either a dict carrying "text" and "is_finished" keys, or a
            raw str/bytes chunk. A chunk containing the "data: [DONE]"
            sentinel marks the end of the stream.

    Returns:
        dict with the keys "text", "is_finished" and "finish_reason".
        "finish_reason" is "stop" for a finished chunk and "" otherwise.

    Raises:
        TypeError: if the chunk is not a dict, str or bytes.
        KeyError: if a dict chunk lacks "text" or "is_finished".
    """
    if isinstance(chunk, bytes):
        # `"data: [DONE]" in <bytes>` raises TypeError, so normalize to str first.
        # NOTE(review): assumes byte chunks are UTF-8 encoded - confirm upstream.
        chunk = chunk.decode("utf-8", errors="replace")

    if not isinstance(chunk, (str, dict)):
        # Fail loudly here instead of implicitly returning None and letting
        # the caller crash on `response_obj["text"]`.
        raise TypeError(
            f"Unsupported sagemaker stream chunk type: {type(chunk).__name__}"
        )

    # End-of-stream sentinel (substring test for str, key test for dict).
    if "data: [DONE]" in chunk:
        return {
            "text": "",
            "is_finished": True,
            "finish_reason": "stop",
        }

    if isinstance(chunk, dict):
        is_finished = chunk["is_finished"]
        return {
            "text": chunk["text"],
            "is_finished": is_finished,
            # Truthiness (not `== True`) so this agrees with callers that
            # branch on `if response_obj["is_finished"]`.
            "finish_reason": "stop" if is_finished else "",
        }

    # Plain text chunk: pass it through as content of an unfinished stream
    # rather than falling off the end of the function.
    return {
        "text": chunk,
        "is_finished": False,
        "finish_reason": "",
    }
def chunk_creator(self, chunk):
model_response = ModelResponse(stream=True, model=self.model)
if self.response_id is not None:
@ -7807,8 +7832,14 @@ class CustomStreamWrapper:
]
self.sent_last_chunk = True
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
completion_obj["content"] = chunk
verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.sent_last_chunk = True
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.sent_last_chunk:
@ -7984,6 +8015,19 @@ class CustomStreamWrapper:
original_exception=e,
)
def run_success_logging_in_thread(self, processed_chunk):
    """
    Run both the async and the sync success loggers for one streamed chunk.

    Intended as a `threading.Thread` target: `asyncio.run` creates a fresh
    event loop, runs `self.logging_obj.async_success_handler` to completion
    and closes the loop; afterwards `self.logging_obj.success_handler` is
    called with the same chunk.

    Args:
        processed_chunk: the chunk passed to both logging handlers.

    Raises:
        Whatever the handlers raise. A failure in the async handler still
        propagates, but only after the sync handler has been given a chance
        to run.
    """
    try:
        ## ASYNC LOGGING
        # asyncio.run() creates (and tears down) an event loop for this call.
        asyncio.run(
            self.logging_obj.async_success_handler(
                processed_chunk,
            )
        )
    finally:
        ## SYNC LOGGING
        # In `finally` so a failing async callback cannot suppress sync logging.
        self.logging_obj.success_handler(processed_chunk)
## needs to handle the empty string case (even starting chunk can be an empty string)
def __next__(self):
try:
@ -8002,8 +8046,9 @@ class CustomStreamWrapper:
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
target=self.run_success_logging_in_thread, args=(response,)
).start() # log response
# RETURN RESULT
return response
except StopIteration:
@ -8059,13 +8104,34 @@ class CustomStreamWrapper:
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls
# example - boto3 bedrock llms
processed_chunk = next(self)
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
)
)
return processed_chunk
while True:
if isinstance(self.completion_stream, str) or isinstance(
self.completion_stream, bytes
):
chunk = self.completion_stream
else:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk = self.chunk_creator(chunk=chunk)
print_verbose(
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
)
if processed_chunk is None:
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk,),
).start() # log processed_chunk
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
)
)
# RETURN RESULT
return processed_chunk
except StopAsyncIteration:
raise
except StopIteration: