Merge pull request #1618 from BerriAI/litellm_sagemaker_cost_tracking_fixes

fix(utils.py): fix sagemaker cost tracking for streaming
Krish Dholakia 2024-01-25 19:01:57 -08:00 committed by GitHub
commit 612f74a426
11 changed files with 282 additions and 79 deletions
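
For reviewers, the path this PR fixes is the one taken when a SageMaker completion is streamed and the collected chunks are later rebuilt into a full response for cost logging. A minimal sketch of that flow is below; the endpoint name is a placeholder and the snippet is only illustrative of how the change could be exercised, not code from this PR.

import litellm

# Hypothetical SageMaker endpoint name - replace with a real deployed endpoint.
messages = [{"role": "user", "content": "Hello, how are you?"}]
response = litellm.completion(
    model="sagemaker/my-llama2-endpoint",
    messages=messages,
    stream=True,
)

# Drain the stream; each item is a chunk produced by CustomStreamWrapper.
chunks = [chunk for chunk in response]

# Rebuild a complete response from the chunks (the same helper the logging
# callbacks use) so token usage and cost can be computed for the whole stream.
complete_response = litellm.stream_chunk_builder(chunks, messages=messages)
print(litellm.completion_cost(completion_response=complete_response))

With this change, handle_sagemaker_stream marks the final chunk with finish_reason "stop", so the rebuilt response terminates cleanly and SageMaker streaming cost tracking behaves like the other providers.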


@@ -1418,7 +1418,9 @@ class Logging:
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
verbose_logger.debug(
f"Async success callbacks: {litellm._async_success_callback}"
)
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
)
@@ -1427,7 +1429,7 @@
if self.stream:
if result.choices[0].finish_reason is not None: # if it's the last chunk
self.streaming_chunks.append(result)
# print_verbose(f"final set of received chunks: {self.streaming_chunks}")
# verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}")
try:
complete_streaming_response = litellm.stream_chunk_builder(
self.streaming_chunks,
@@ -1436,14 +1438,16 @@
end_time=end_time,
)
except Exception as e:
print_verbose(
verbose_logger.debug(
f"Error occurred building stream chunk: {traceback.format_exc()}"
)
complete_streaming_response = None
else:
self.streaming_chunks.append(result)
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
verbose_logger.debug(
"Async success callbacks: Got a complete streaming response"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
@@ -7152,6 +7156,7 @@ class CustomStreamWrapper:
"model_id": (_model_info.get("id", None))
} # returned as x-litellm-model-id response header in proxy
self.response_id = None
self.logging_loop = None
def __iter__(self):
return self
@@ -7722,6 +7727,27 @@ class CustomStreamWrapper:
}
return ""
def handle_sagemaker_stream(self, chunk):
if "data: [DONE]" in chunk:
text = ""
is_finished = True
finish_reason = "stop"
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
elif isinstance(chunk, dict):
if chunk["is_finished"] == True:
finish_reason = "stop"
else:
finish_reason = ""
return {
"text": chunk["text"],
"is_finished": chunk["is_finished"],
"finish_reason": finish_reason,
}
def chunk_creator(self, chunk):
model_response = ModelResponse(stream=True, model=self.model)
if self.response_id is not None:
@@ -7729,6 +7755,7 @@
else:
self.response_id = model_response.id
model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
model_response._hidden_params["created_at"] = time.time()
model_response.choices = [StreamingChoices()]
model_response.choices[0].finish_reason = None
response_obj = {}
@@ -7847,8 +7874,14 @@
]
self.sent_last_chunk = True
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
completion_obj["content"] = chunk
verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.sent_last_chunk = True
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.sent_last_chunk:
@@ -8024,6 +8057,27 @@
original_exception=e,
)
def set_logging_event_loop(self, loop):
self.logging_loop = loop
async def your_async_function(self):
# Your asynchronous code here
return "Your asynchronous code is running"
def run_success_logging_in_thread(self, processed_chunk):
# Create an event loop for the new thread
## ASYNC LOGGING
if self.logging_loop is not None:
future = asyncio.run_coroutine_threadsafe(
self.logging_obj.async_success_handler(processed_chunk),
loop=self.logging_loop,
)
result = future.result()
else:
asyncio.run(self.logging_obj.async_success_handler(processed_chunk))
## SYNC LOGGING
self.logging_obj.success_handler(processed_chunk)
## needs to handle the empty string case (even starting chunk can be an empty string)
def __next__(self):
try:
@@ -8042,8 +8096,9 @@
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
target=self.run_success_logging_in_thread, args=(response,)
).start() # log response
# RETURN RESULT
return response
except StopIteration:
@@ -8099,13 +8154,34 @@
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls
# example - boto3 bedrock llms
processed_chunk = next(self)
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
)
)
return processed_chunk
while True:
if isinstance(self.completion_stream, str) or isinstance(
self.completion_stream, bytes
):
chunk = self.completion_stream
else:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk = self.chunk_creator(chunk=chunk)
print_verbose(
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
)
if processed_chunk is None:
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk,),
).start() # log processed_chunk
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
)
)
# RETURN RESULT
return processed_chunk
except StopAsyncIteration:
raise
except StopIteration: