Merge pull request #1618 from BerriAI/litellm_sagemaker_cost_tracking_fixes
fix(utils.py): fix sagemaker cost tracking for streaming
This commit is contained in: commit 612f74a426
11 changed files with 282 additions and 79 deletions
litellm/utils.py (104 changed lines)
@@ -1418,7 +1418,9 @@ class Logging:
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
-        print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
+        verbose_logger.debug(
+            f"Async success callbacks: {litellm._async_success_callback}"
+        )
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
         )
@@ -1427,7 +1429,7 @@ class Logging:
         if self.stream:
             if result.choices[0].finish_reason is not None:  # if it's the last chunk
                 self.streaming_chunks.append(result)
-                # print_verbose(f"final set of received chunks: {self.streaming_chunks}")
+                # verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}")
                 try:
                     complete_streaming_response = litellm.stream_chunk_builder(
                         self.streaming_chunks,
@@ -1436,14 +1438,16 @@ class Logging:
                         end_time=end_time,
                     )
                 except Exception as e:
-                    print_verbose(
+                    verbose_logger.debug(
                         f"Error occurred building stream chunk: {traceback.format_exc()}"
                     )
                     complete_streaming_response = None
             else:
                 self.streaming_chunks.append(result)
         if complete_streaming_response is not None:
-            print_verbose("Async success callbacks: Got a complete streaming response")
+            verbose_logger.debug(
+                "Async success callbacks: Got a complete streaming response"
+            )
             self.model_call_details[
                 "complete_streaming_response"
             ] = complete_streaming_response
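The three hunks above migrate debug output from the print-based print_verbose helper to the module-level verbose_logger. A minimal sketch of the benefit, using only the standard logging module (the logger name and basicConfig wiring here are illustrative assumptions, not litellm's actual setup):

import logging

# Illustrative stand-in for litellm's verbose_logger: a named logger that
# participates in standard logging levels, handlers, and formatters.
verbose_logger = logging.getLogger("litellm")
logging.basicConfig(level=logging.DEBUG)  # the embedding app opts in to debug output

# Unlike a bare print(), this record can be filtered by level, routed to
# files or aggregators, or silenced entirely via logging configuration.
verbose_logger.debug("Async success callbacks: %s", ["langfuse", "s3"])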
@@ -7152,6 +7156,7 @@ class CustomStreamWrapper:
                 "model_id": (_model_info.get("id", None))
             }  # returned as x-litellm-model-id response header in proxy
         self.response_id = None
+        self.logging_loop = None

     def __iter__(self):
         return self
@@ -7722,6 +7727,27 @@ class CustomStreamWrapper:
             }
         return ""

+    def handle_sagemaker_stream(self, chunk):
+        if "data: [DONE]" in chunk:
+            text = ""
+            is_finished = True
+            finish_reason = "stop"
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        elif isinstance(chunk, dict):
+            if chunk["is_finished"] == True:
+                finish_reason = "stop"
+            else:
+                finish_reason = ""
+            return {
+                "text": chunk["text"],
+                "is_finished": chunk["is_finished"],
+                "finish_reason": finish_reason,
+            }
+
     def chunk_creator(self, chunk):
         model_response = ModelResponse(stream=True, model=self.model)
         if self.response_id is not None:
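A hedged usage sketch of the new parser's two branches (the chunk shapes below are inferred from the method body above, not from documented SageMaker payloads):

def handle_sagemaker_stream(chunk):
    # free-standing copy of the method above, for illustration only
    if "data: [DONE]" in chunk:  # terminal sentinel string
        return {"text": "", "is_finished": True, "finish_reason": "stop"}
    elif isinstance(chunk, dict):  # a parsed payload from the SageMaker handler
        finish_reason = "stop" if chunk["is_finished"] else ""
        return {
            "text": chunk["text"],
            "is_finished": chunk["is_finished"],
            "finish_reason": finish_reason,
        }

for chunk in [{"text": "Hello", "is_finished": False}, "data: [DONE]"]:
    print(handle_sagemaker_stream(chunk))
# {'text': 'Hello', 'is_finished': False, 'finish_reason': ''}
# {'text': '', 'is_finished': True, 'finish_reason': 'stop'}

Note the branch order matters: the membership test on a dict checks its keys, so dict chunks fall through safely to the elif.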
@@ -7729,6 +7755,7 @@ class CustomStreamWrapper:
         else:
             self.response_id = model_response.id
         model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
+        model_response._hidden_params["created_at"] = time.time()
         model_response.choices = [StreamingChoices()]
         model_response.choices[0].finish_reason = None
         response_obj = {}
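The new created_at hidden parameter timestamps each chunk as it is built. One plausible consumer, purely an assumption for illustration (this diff does not show who reads it), is inter-chunk latency measurement:

import time

class FakeChunk:
    """Toy stand-in for a streamed ModelResponse; only _hidden_params matters here."""
    def __init__(self):
        self._hidden_params = {"created_at": time.time()}

first = FakeChunk()
time.sleep(0.05)  # simulate time between chunks
second = FakeChunk()
# Hypothetical consumer: measure the gap between consecutive chunks.
print(second._hidden_params["created_at"] - first._hidden_params["created_at"])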
@@ -7847,8 +7874,14 @@ class CustomStreamWrapper:
                     ]
                     self.sent_last_chunk = True
             elif self.custom_llm_provider == "sagemaker":
-                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
-                completion_obj["content"] = chunk
+                verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
+                response_obj = self.handle_sagemaker_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
+                    self.sent_last_chunk = True
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.sent_last_chunk:
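This hunk is the heart of the cost-tracking fix: SageMaker chunks now flow through handle_sagemaker_stream, so the final chunk carries finish_reason = "stop". That is what the Logging hunks at the top key on (finish_reason is not None) to assemble complete_streaming_response, which downstream cost tracking consumes. A simplified model of that interaction (toy chunk dicts, not litellm's real stream_chunk_builder):

streaming_chunks = []
complete_text = None  # stand-in for complete_streaming_response

def on_chunk(chunk):
    """Buffer chunks; build the full response only once the last chunk arrives."""
    global complete_text
    streaming_chunks.append(chunk)
    if chunk["finish_reason"] is not None:  # last-chunk detection
        complete_text = "".join(c["text"] for c in streaming_chunks)

for c in [{"text": "Hi", "finish_reason": None},
          {"text": " there", "finish_reason": "stop"}]:
    on_chunk(c)

# Before this fix, sagemaker streams never set finish_reason, so this
# assembly step (and therefore cost tracking) never fired.
assert complete_text == "Hi there"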
@@ -8024,6 +8057,27 @@ class CustomStreamWrapper:
                 original_exception=e,
             )

+    def set_logging_event_loop(self, loop):
+        self.logging_loop = loop
+
+    async def your_async_function(self):
+        # Your asynchronous code here
+        return "Your asynchronous code is running"
+
+    def run_success_logging_in_thread(self, processed_chunk):
+        # Create an event loop for the new thread
+        ## ASYNC LOGGING
+        if self.logging_loop is not None:
+            future = asyncio.run_coroutine_threadsafe(
+                self.logging_obj.async_success_handler(processed_chunk),
+                loop=self.logging_loop,
+            )
+            result = future.result()
+        else:
+            asyncio.run(self.logging_obj.async_success_handler(processed_chunk))
+        ## SYNC LOGGING
+        self.logging_obj.success_handler(processed_chunk)
+
     ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
         try:
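run_success_logging_in_thread uses a standard asyncio pattern: a plain worker thread hands a coroutine to an event loop owned by another thread via asyncio.run_coroutine_threadsafe, falling back to asyncio.run when no loop was registered. A self-contained sketch of that pattern (the handler below is illustrative, not litellm's API):

import asyncio
import threading

async def async_success_handler(chunk):
    # illustrative stand-in for the real async logging callback
    print(f"logged {chunk!r} on the event-loop thread")

def worker(loop):
    # Runs in a plain thread: schedule the coroutine on the captured loop
    # and block until it finishes, as run_success_logging_in_thread does.
    future = asyncio.run_coroutine_threadsafe(async_success_handler("chunk-1"), loop)
    future.result()

async def main():
    loop = asyncio.get_running_loop()  # what set_logging_event_loop captures
    t = threading.Thread(target=worker, args=(loop,))
    t.start()
    await asyncio.sleep(0.1)  # yield so the scheduled coroutine can run
    t.join()

asyncio.run(main())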
@@ -8042,8 +8096,9 @@ class CustomStreamWrapper:
                     continue
                 ## LOGGING
                 threading.Thread(
-                    target=self.logging_obj.success_handler, args=(response,)
+                    target=self.run_success_logging_in_thread, args=(response,)
                 ).start()  # log response
+
                 # RETURN RESULT
                 return response
         except StopIteration:
@@ -8099,13 +8154,34 @@ class CustomStreamWrapper:
                 raise StopAsyncIteration
             else:  # temporary patch for non-aiohttp async calls
                 # example - boto3 bedrock llms
-                processed_chunk = next(self)
-                asyncio.create_task(
-                    self.logging_obj.async_success_handler(
-                        processed_chunk,
-                    )
-                )
-                return processed_chunk
+                while True:
+                    if isinstance(self.completion_stream, str) or isinstance(
+                        self.completion_stream, bytes
+                    ):
+                        chunk = self.completion_stream
+                    else:
+                        chunk = next(self.completion_stream)
+                    if chunk is not None and chunk != b"":
+                        print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
+                        processed_chunk = self.chunk_creator(chunk=chunk)
+                        print_verbose(
+                            f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
+                        )
+                        if processed_chunk is None:
+                            continue
+                        ## LOGGING
+                        threading.Thread(
+                            target=self.logging_obj.success_handler,
+                            args=(processed_chunk,),
+                        ).start()  # log processed_chunk
+                        asyncio.create_task(
+                            self.logging_obj.async_success_handler(
+                                processed_chunk,
+                            )
+                        )
+
+                        # RETURN RESULT
+                        return processed_chunk
         except StopAsyncIteration:
             raise
         except StopIteration:
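The rewritten fallback loop above pulls synchronously from completion_stream inside __anext__ (the "temporary patch" for providers without native async streaming, e.g. boto3-based calls), skips empty chunks instead of terminating, and fires async logging as a task. A toy reduction of that shape (not litellm's classes):

import asyncio

async def log_chunk(chunk):
    print(f"logged: {chunk!r}")  # stand-in for async_success_handler

class SyncBackedStream:
    """Async iterator over a synchronous chunk source, mirroring the
    'temporary patch for non-aiohttp async calls' branch above."""
    def __init__(self, chunks):
        self._it = iter(chunks)

    def __aiter__(self):
        return self

    async def __anext__(self):
        while True:
            try:
                chunk = next(self._it)  # blocking pull, as in the patched branch
            except StopIteration:
                raise StopAsyncIteration
            if chunk is None or chunk == "":
                continue  # an empty chunk is skipped, not treated as end-of-stream
            asyncio.create_task(log_chunk(chunk))  # fire-and-forget logging
            return chunk

async def main():
    async for chunk in SyncBackedStream(["a", "", "b"]):
        print("chunk:", chunk)
    await asyncio.sleep(0)  # let the pending logging tasks run

asyncio.run(main())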