From 09ec6d645851fcc62b2851eb4b421a2a77e89468 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 25 Jan 2024 12:49:45 -0800 Subject: [PATCH] fix(utils.py): fix sagemaker async logging for sync streaming https://github.com/BerriAI/litellm/issues/1592 --- .circleci/config.yml | 3 + litellm/llms/sagemaker.py | 35 +++++--- litellm/main.py | 11 +-- litellm/proxy/proxy_server.py | 3 + litellm/proxy/utils.py | 5 +- litellm/tests/test_custom_callback_input.py | 41 +++++++++ litellm/tests/test_streaming.py | 70 ++++++++------- litellm/utils.py | 94 ++++++++++++++++++--- proxy_server_config.yaml | 4 + tests/test_keys.py | 45 +++++++++- 10 files changed, 247 insertions(+), 64 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1de72a156..e0e6f5743 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -147,6 +147,9 @@ jobs: -e AZURE_API_KEY=$AZURE_API_KEY \ -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \ -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \ + -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ + -e AWS_REGION_NAME=$AWS_REGION_NAME \ --name my-app \ -v $(pwd)/proxy_server_config.yaml:/app/config.yaml \ my-app:latest \ diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 1608f7a0f..78aafe7f7 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -34,22 +34,35 @@ class TokenIterator: self.byte_iterator = iter(stream) self.buffer = io.BytesIO() self.read_pos = 0 + self.end_of_data = False def __iter__(self): return self def __next__(self): - while True: - self.buffer.seek(self.read_pos) - line = self.buffer.readline() - if line and line[-1] == ord("\n"): - self.read_pos += len(line) + 1 - full_line = line[:-1].decode("utf-8") - line_data = json.loads(full_line.lstrip("data:").rstrip("/n")) - return line_data["token"]["text"] - chunk = next(self.byte_iterator) - self.buffer.seek(0, io.SEEK_END) - self.buffer.write(chunk["PayloadPart"]["Bytes"]) + try: + while True: + self.buffer.seek(self.read_pos) + line = self.buffer.readline() + if line and line[-1] == ord("\n"): + response_obj = {"text": "", "is_finished": False} + self.read_pos += len(line) + 1 + full_line = line[:-1].decode("utf-8") + line_data = json.loads(full_line.lstrip("data:").rstrip("/n")) + if line_data.get("generated_text", None) is not None: + self.end_of_data = True + response_obj["is_finished"] = True + response_obj["text"] = line_data["token"]["text"] + return response_obj + chunk = next(self.byte_iterator) + self.buffer.seek(0, io.SEEK_END) + self.buffer.write(chunk["PayloadPart"]["Bytes"]) + except StopIteration as e: + if self.end_of_data == True: + raise e # Re-raise StopIteration + else: + self.end_of_data = True + return "data: [DONE]" class SagemakerConfig: diff --git a/litellm/main.py b/litellm/main.py index 6b9a0bb18..fca3bd2b2 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1514,11 +1514,6 @@ def completion( if ( "stream" in optional_params and optional_params["stream"] == True ): ## [BETA] - # sagemaker does not support streaming as of now so we're faking streaming: - # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611 - # "SageMaker is currently not supporting streaming responses." 
- - # fake streaming for sagemaker print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER") from .llms.sagemaker import TokenIterator @@ -1529,6 +1524,12 @@ def completion( custom_llm_provider="sagemaker", logging_obj=logging, ) + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) return response ## RESPONSE OBJECT diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index b53484b86..493ad9731 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -690,6 +690,9 @@ async def update_database( existing_spend_obj = await custom_db_client.get_data( key=id, table_name="user" ) + verbose_proxy_logger.debug( + f"Updating existing_spend_obj: {existing_spend_obj}" + ) if existing_spend_obj is None: existing_spend = 0 existing_spend_obj = LiteLLM_UserTable( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index faa73d70b..728716886 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -409,7 +409,9 @@ class PrismaClient: hashed_token = token if token.startswith("sk-"): hashed_token = self.hash_token(token=token) - print_verbose("PrismaClient: find_unique") + verbose_proxy_logger.debug( + f"PrismaClient: find_unique for token: {hashed_token}" + ) if query_type == "find_unique": response = await self.db.litellm_verificationtoken.find_unique( where={"token": hashed_token} @@ -716,7 +718,6 @@ class PrismaClient: Batch write update queries """ batcher = self.db.batch_() - verbose_proxy_logger.debug(f"data list for user table: {data_list}") for idx, user in enumerate(data_list): try: data_json = self.jsonify_object(data=user.model_dump()) diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 556628d82..a61cc843e 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -556,6 +556,47 @@ async def test_async_chat_bedrock_stream(): # asyncio.run(test_async_chat_bedrock_stream()) + +## Test Sagemaker + Async +@pytest.mark.asyncio +async def test_async_chat_sagemaker_stream(): + try: + customHandler = CompletionCustomHandler() + litellm.callbacks = [customHandler] + response = await litellm.acompletion( + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}], + ) + # test streaming + response = await litellm.acompletion( + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}], + stream=True, + ) + print(f"response: {response}") + async for chunk in response: + print(f"chunk: {chunk}") + continue + ## test failure callback + try: + response = await litellm.acompletion( + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}], + aws_region_name="my-bad-key", + stream=True, + ) + async for chunk in response: + continue + except: + pass + time.sleep(1) + print(f"customHandler.errors: {customHandler.errors}") + assert len(customHandler.errors) == 0 + litellm.callbacks = [] + except Exception as e: + pytest.fail(f"An exception occurred: {str(e)}") + + # Text Completion diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 14b1a7210..d9f99bece 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -872,41 +872,53 @@ async def test_sagemaker_streaming_async(): ) # Add any assertions here to check the 
response + print(response) complete_response = "" + has_finish_reason = False + # Add any assertions here to check the response + idx = 0 async for chunk in response: - complete_response += chunk.choices[0].delta.content or "" - print(f"complete_response: {complete_response}") - assert len(complete_response) > 0 + # print + chunk, finished = streaming_format_tests(idx, chunk) + has_finish_reason = finished + complete_response += chunk + if finished: + break + idx += 1 + if has_finish_reason is False: + raise Exception("finish reason not set for last chunk") + if complete_response.strip() == "": + raise Exception("Empty response received") + print(f"completion_response: {complete_response}") except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") -# def test_completion_sagemaker_stream(): -# try: -# response = completion( -# model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b", -# messages=messages, -# temperature=0.2, -# max_tokens=80, -# stream=True, -# ) -# complete_response = "" -# has_finish_reason = False -# # Add any assertions here to check the response -# for idx, chunk in enumerate(response): -# chunk, finished = streaming_format_tests(idx, chunk) -# has_finish_reason = finished -# if finished: -# break -# complete_response += chunk -# if has_finish_reason is False: -# raise Exception("finish reason not set for last chunk") -# if complete_response.strip() == "": -# raise Exception("Empty response received") -# except InvalidRequestError as e: -# pass -# except Exception as e: -# pytest.fail(f"Error occurred: {e}") +def test_completion_sagemaker_stream(): + try: + response = completion( + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + messages=messages, + temperature=0.2, + max_tokens=80, + stream=True, + ) + complete_response = "" + has_finish_reason = False + # Add any assertions here to check the response + for idx, chunk in enumerate(response): + chunk, finished = streaming_format_tests(idx, chunk) + has_finish_reason = finished + if finished: + break + complete_response += chunk + if has_finish_reason is False: + raise Exception("finish reason not set for last chunk") + if complete_response.strip() == "": + raise Exception("Empty response received") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + # test_completion_sagemaker_stream() diff --git a/litellm/utils.py b/litellm/utils.py index 0e12463b9..fb3210b1d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1417,7 +1417,9 @@ class Logging: """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. 
""" - print_verbose(f"Async success callbacks: {litellm._async_success_callback}") + verbose_logger.debug( + f"Async success callbacks: {litellm._async_success_callback}" + ) start_time, end_time, result = self._success_handler_helper_fn( start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit ) @@ -1426,7 +1428,7 @@ class Logging: if self.stream: if result.choices[0].finish_reason is not None: # if it's the last chunk self.streaming_chunks.append(result) - # print_verbose(f"final set of received chunks: {self.streaming_chunks}") + # verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}") try: complete_streaming_response = litellm.stream_chunk_builder( self.streaming_chunks, @@ -1435,14 +1437,16 @@ class Logging: end_time=end_time, ) except Exception as e: - print_verbose( + verbose_logger.debug( f"Error occurred building stream chunk: {traceback.format_exc()}" ) complete_streaming_response = None else: self.streaming_chunks.append(result) if complete_streaming_response is not None: - print_verbose("Async success callbacks: Got a complete streaming response") + verbose_logger.debug( + "Async success callbacks: Got a complete streaming response" + ) self.model_call_details[ "complete_streaming_response" ] = complete_streaming_response @@ -7682,6 +7686,27 @@ class CustomStreamWrapper: } return "" + def handle_sagemaker_stream(self, chunk): + if "data: [DONE]" in chunk: + text = "" + is_finished = True + finish_reason = "stop" + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + elif isinstance(chunk, dict): + if chunk["is_finished"] == True: + finish_reason = "stop" + else: + finish_reason = "" + return { + "text": chunk["text"], + "is_finished": chunk["is_finished"], + "finish_reason": finish_reason, + } + def chunk_creator(self, chunk): model_response = ModelResponse(stream=True, model=self.model) if self.response_id is not None: @@ -7807,8 +7832,14 @@ class CustomStreamWrapper: ] self.sent_last_chunk = True elif self.custom_llm_provider == "sagemaker": - print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") - completion_obj["content"] = chunk + verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") + response_obj = self.handle_sagemaker_stream(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] + self.sent_last_chunk = True elif self.custom_llm_provider == "petals": if len(self.completion_stream) == 0: if self.sent_last_chunk: @@ -7984,6 +8015,19 @@ class CustomStreamWrapper: original_exception=e, ) + def run_success_logging_in_thread(self, processed_chunk): + # Create an event loop for the new thread + ## ASYNC LOGGING + # Run the asynchronous function in the new thread's event loop + asyncio.run( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) + + ## SYNC LOGGING + self.logging_obj.success_handler(processed_chunk) + ## needs to handle the empty string case (even starting chunk can be an empty string) def __next__(self): try: @@ -8002,8 +8046,9 @@ class CustomStreamWrapper: continue ## LOGGING threading.Thread( - target=self.logging_obj.success_handler, args=(response,) + target=self.run_success_logging_in_thread, args=(response,) ).start() # log response + # RETURN RESULT return response except StopIteration: @@ -8059,13 +8104,34 @@ class CustomStreamWrapper: raise StopAsyncIteration else: # temporary patch for non-aiohttp async calls # 
example - boto3 bedrock llms - processed_chunk = next(self) - asyncio.create_task( - self.logging_obj.async_success_handler( - processed_chunk, - ) - ) - return processed_chunk + while True: + if isinstance(self.completion_stream, str) or isinstance( + self.completion_stream, bytes + ): + chunk = self.completion_stream + else: + chunk = next(self.completion_stream) + if chunk is not None and chunk != b"": + print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") + processed_chunk = self.chunk_creator(chunk=chunk) + print_verbose( + f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}" + ) + if processed_chunk is None: + continue + ## LOGGING + threading.Thread( + target=self.logging_obj.success_handler, + args=(processed_chunk,), + ).start() # log processed_chunk + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) + + # RETURN RESULT + return processed_chunk except StopAsyncIteration: raise except StopIteration: diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index dfa8e1151..2c123d156 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -11,6 +11,10 @@ model_list: api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + - model_name: sagemaker-completion-model + litellm_params: + model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4 + input_cost_per_second: 0.000420 - model_name: gpt-4 litellm_params: model: azure/gpt-turbo diff --git a/tests/test_keys.py b/tests/test_keys.py index f05204c03..cb06e1f7e 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -13,17 +13,21 @@ sys.path.insert( import litellm -async def generate_key(session, i, budget=None, budget_duration=None): +async def generate_key( + session, i, budget=None, budget_duration=None, models=["azure-models", "gpt-4"] +): url = "http://0.0.0.0:4000/key/generate" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = { - "models": ["azure-models", "gpt-4"], + "models": models, "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": None, "max_budget": budget, "budget_duration": budget_duration, } + print(f"data: {data}") + async with session.post(url, headers=headers, json=data) as response: status = response.status response_text = await response.text() @@ -293,7 +297,7 @@ async def test_key_info_spend_values(): rounded_response_cost = round(response_cost, 8) rounded_key_info_spend = round(key_info["info"]["spend"], 8) assert rounded_response_cost == rounded_key_info_spend - ## streaming + ## streaming - azure key_gen = await generate_key(session=session, i=0) new_key = key_gen["key"] prompt_tokens, completion_tokens = await chat_completion_streaming( @@ -318,6 +322,41 @@ async def test_key_info_spend_values(): assert rounded_response_cost == rounded_key_info_spend +@pytest.mark.asyncio +async def test_key_info_spend_values_sagemaker(): + """ + Tests the sync streaming loop to ensure spend is correctly calculated. 
+ - create key + - make completion call + - assert cost is expected value + """ + async with aiohttp.ClientSession() as session: + ## streaming - sagemaker + key_gen = await generate_key(session=session, i=0, models=[]) + new_key = key_gen["key"] + prompt_tokens, completion_tokens = await chat_completion_streaming( + session=session, key=new_key, model="sagemaker-completion-model" + ) + # print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}") + # prompt_cost, completion_cost = litellm.cost_per_token( + # model="azure/gpt-35-turbo", + # prompt_tokens=prompt_tokens, + # completion_tokens=completion_tokens, + # ) + # response_cost = prompt_cost + completion_cost + await asyncio.sleep(5) # allow db log to be updated + key_info = await get_key_info( + session=session, get_key=new_key, call_key=new_key + ) + # print( + # f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}" + # ) + # rounded_response_cost = round(response_cost, 8) + rounded_key_info_spend = round(key_info["info"]["spend"], 8) + assert rounded_key_info_spend > 0 + # assert rounded_response_cost == rounded_key_info_spend + + @pytest.mark.asyncio async def test_key_with_budgets(): """
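
Note on the SageMaker streaming change in litellm/llms/sagemaker.py above: the patched TokenIterator now returns {"text", "is_finished"} dicts instead of bare strings, flags the chunk carrying "generated_text" as the final one, and falls back to a "data: [DONE]" sentinel when the byte iterator is exhausted before a finish marker arrives. The snippet below is a minimal, self-contained sketch of that line-buffered parsing outside the SageMaker SDK; iter_sagemaker_tokens and fake_payloads are illustrative names, and the byte payloads are made up so the sketch runs without boto3 or a live endpoint. It also uses str.removeprefix rather than the patch's character-set lstrip("data:"), which is a safer way to drop the prefix.

    # Minimal sketch of the line-buffered parsing the patched TokenIterator performs.
    # `fake_payloads` stands in for the `PayloadPart.Bytes` chunks a real endpoint streams.
    import json

    def iter_sagemaker_tokens(byte_chunks):
        """Yield {"text", "is_finished"} dicts from newline-delimited JSON chunks."""
        buffer = b""
        finished = False
        for chunk in byte_chunks:
            buffer += chunk
            # A payload part may contain zero, one, or several complete lines.
            while b"\n" in buffer:
                line, buffer = buffer.split(b"\n", 1)
                if not line.strip():
                    continue
                data = json.loads(line.decode("utf-8").removeprefix("data:"))
                finished = data.get("generated_text") is not None
                yield {"text": data["token"]["text"], "is_finished": finished}
        if not finished:
            # Mirror the patch's fallback: signal end-of-stream explicitly when the
            # byte iterator runs out before a finish marker is seen.
            yield {"text": "", "is_finished": True}

    fake_payloads = [
        b'data:{"token": {"text": "Hello"}}\n',
        b'data:{"token": {"text": " world"}, "generated_text": "Hello world"}\n',
    ]

    for token in iter_sagemaker_tokens(fake_payloads):
        print(token)

Running the sketch prints two token dicts, the second with is_finished set to True, which is the shape CustomStreamWrapper.handle_sagemaker_stream expects.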
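Note on the CustomStreamWrapper logging change in litellm/utils.py above: because the faked SageMaker stream is consumed synchronously, the patch adds run_success_logging_in_thread, which moves logging onto a worker thread and uses asyncio.run to give the async success callback its own event loop before invoking the sync handler. The sketch below shows that pattern in isolation; DemoLogger is a hypothetical stand-in for litellm's Logging object, not its real API.

    # Self-contained sketch of the "run async + sync success handlers from a worker
    # thread" pattern used in CustomStreamWrapper.__next__.
    import asyncio
    import threading

    class DemoLogger:
        def success_handler(self, chunk):
            print(f"[sync] logged chunk: {chunk!r}")

        async def async_success_handler(self, chunk):
            await asyncio.sleep(0)  # pretend to await a DB / HTTP call
            print(f"[async] logged chunk: {chunk!r}")

    def run_success_logging_in_thread(logging_obj, processed_chunk):
        # asyncio.run creates a fresh event loop inside this thread, so the async
        # callback can run even though the caller is a plain synchronous iterator.
        asyncio.run(logging_obj.async_success_handler(processed_chunk))
        logging_obj.success_handler(processed_chunk)

    logger = DemoLogger()
    for chunk in ["Hello", " world", ""]:  # stand-in for a sync token stream
        threading.Thread(
            target=run_success_logging_in_thread, args=(logger, chunk)
        ).start()  # fire-and-forget, mirroring the patch

The same idea drives the reworked __anext__ fallback path: chunks pulled from a synchronous provider stream are handed to threading.Thread for the sync handler and to asyncio.create_task for the async one before being returned to the caller.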