diff --git a/.circleci/config.yml b/.circleci/config.yml index 1de72a156..e0e6f5743 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -147,6 +147,9 @@ jobs: -e AZURE_API_KEY=$AZURE_API_KEY \ -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \ -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \ + -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ + -e AWS_REGION_NAME=$AWS_REGION_NAME \ --name my-app \ -v $(pwd)/proxy_server_config.yaml:/app/config.yaml \ my-app:latest \ diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 1608f7a0f..78aafe7f7 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -34,22 +34,35 @@ class TokenIterator: self.byte_iterator = iter(stream) self.buffer = io.BytesIO() self.read_pos = 0 + self.end_of_data = False def __iter__(self): return self def __next__(self): - while True: - self.buffer.seek(self.read_pos) - line = self.buffer.readline() - if line and line[-1] == ord("\n"): - self.read_pos += len(line) + 1 - full_line = line[:-1].decode("utf-8") - line_data = json.loads(full_line.lstrip("data:").rstrip("/n")) - return line_data["token"]["text"] - chunk = next(self.byte_iterator) - self.buffer.seek(0, io.SEEK_END) - self.buffer.write(chunk["PayloadPart"]["Bytes"]) + try: + while True: + self.buffer.seek(self.read_pos) + line = self.buffer.readline() + if line and line[-1] == ord("\n"): + response_obj = {"text": "", "is_finished": False} + self.read_pos += len(line) + 1 + full_line = line[:-1].decode("utf-8") + line_data = json.loads(full_line.lstrip("data:").rstrip("/n")) + if line_data.get("generated_text", None) is not None: + self.end_of_data = True + response_obj["is_finished"] = True + response_obj["text"] = line_data["token"]["text"] + return response_obj + chunk = next(self.byte_iterator) + self.buffer.seek(0, io.SEEK_END) + self.buffer.write(chunk["PayloadPart"]["Bytes"]) + except StopIteration as e: + if self.end_of_data == True: + raise e # Re-raise StopIteration + else: + self.end_of_data = True + return "data: [DONE]" class SagemakerConfig: diff --git a/litellm/main.py b/litellm/main.py index 36d764b73..f9f1139f6 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -15,7 +15,7 @@ import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy import httpx import litellm - +from ._logging import verbose_logger from litellm import ( # type: ignore client, exception_type, @@ -274,14 +274,10 @@ async def acompletion( else: # Call the synchronous function using run_in_executor response = await loop.run_in_executor(None, func_with_context) # type: ignore - # if kwargs.get("stream", False): # return an async generator - # return _async_streaming( - # response=response, - # model=model, - # custom_llm_provider=custom_llm_provider, - # args=args, - # ) - # else: + if isinstance(response, CustomStreamWrapper): + response.set_logging_event_loop( + loop=loop + ) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls) return response except Exception as e: custom_llm_provider = custom_llm_provider or "openai" @@ -1520,11 +1516,6 @@ def completion( if ( "stream" in optional_params and optional_params["stream"] == True ): ## [BETA] - # sagemaker does not support streaming as of now so we're faking streaming: - # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611 - # "SageMaker is currently not supporting streaming responses." 
- - # fake streaming for sagemaker print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER") from .llms.sagemaker import TokenIterator @@ -1535,6 +1526,12 @@ def completion( custom_llm_provider="sagemaker", logging_obj=logging, ) + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) return response ## RESPONSE OBJECT @@ -3348,6 +3345,16 @@ def stream_chunk_builder( chunks: list, messages: Optional[list] = None, start_time=None, end_time=None ): model_response = litellm.ModelResponse() + ### SORT CHUNKS BASED ON CREATED ORDER ## + print_verbose("Goes into checking if chunk has hiddden created at param") + if chunks[0]._hidden_params.get("created_at", None): + print_verbose("Chunks have a created at hidden param") + # Sort chunks based on created_at in ascending order + chunks = sorted( + chunks, key=lambda x: x._hidden_params.get("created_at", float("inf")) + ) + print_verbose("Chunks sorted") + # set hidden params from chunk to model_response if model_response is not None and hasattr(model_response, "_hidden_params"): model_response._hidden_params = chunks[0].get("_hidden_params", {}) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 5dc1d0da6..8aa7e79fa 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -718,6 +718,9 @@ async def update_database( existing_spend_obj = await custom_db_client.get_data( key=id, table_name="user" ) + verbose_proxy_logger.debug( + f"Updating existing_spend_obj: {existing_spend_obj}" + ) if existing_spend_obj is None: existing_spend = 0 existing_spend_obj = LiteLLM_UserTable( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 43c61e397..12605cf40 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -467,7 +467,9 @@ class PrismaClient: hashed_token = token if token.startswith("sk-"): hashed_token = self.hash_token(token=token) - print_verbose("PrismaClient: find_unique") + verbose_proxy_logger.debug( + f"PrismaClient: find_unique for token: {hashed_token}" + ) if query_type == "find_unique": response = await self.db.litellm_verificationtoken.find_unique( where={"token": hashed_token} @@ -774,7 +776,6 @@ class PrismaClient: Batch write update queries """ batcher = self.db.batch_() - verbose_proxy_logger.debug(f"data list for user table: {data_list}") for idx, user in enumerate(data_list): try: data_json = self.jsonify_object(data=user.model_dump()) diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 556628d82..a61cc843e 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -556,6 +556,47 @@ async def test_async_chat_bedrock_stream(): # asyncio.run(test_async_chat_bedrock_stream()) + +## Test Sagemaker + Async +@pytest.mark.asyncio +async def test_async_chat_sagemaker_stream(): + try: + customHandler = CompletionCustomHandler() + litellm.callbacks = [customHandler] + response = await litellm.acompletion( + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}], + ) + # test streaming + response = await litellm.acompletion( + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}], + stream=True, + ) + print(f"response: {response}") + async for chunk in response: + print(f"chunk: {chunk}") + continue + ## test failure callback + try: + response = await 
litellm.acompletion( + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}], + aws_region_name="my-bad-key", + stream=True, + ) + async for chunk in response: + continue + except: + pass + time.sleep(1) + print(f"customHandler.errors: {customHandler.errors}") + assert len(customHandler.errors) == 0 + litellm.callbacks = [] + except Exception as e: + pytest.fail(f"An exception occurred: {str(e)}") + + # Text Completion diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py index 565df5b25..e1c532a88 100644 --- a/litellm/tests/test_custom_logger.py +++ b/litellm/tests/test_custom_logger.py @@ -206,12 +206,12 @@ def test_azure_completion_stream(): # checks if the model response available in the async + stream callbacks is equal to the received response customHandler2 = MyCustomHandler() litellm.callbacks = [customHandler2] - litellm.set_verbose = False + litellm.set_verbose = True messages = [ {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", - "content": "write 1 sentence about litellm being amazing", + "content": f"write 1 sentence about litellm being amazing {time.time()}", }, ] complete_streaming_response = "" diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 14b1a7210..fda640c96 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -847,9 +847,13 @@ def test_sagemaker_weird_response(): logging_obj=logging_obj, ) complete_response = "" - for chunk in response: - print(chunk) - complete_response += chunk["choices"][0]["delta"]["content"] + for idx, chunk in enumerate(response): + # print + chunk, finished = streaming_format_tests(idx, chunk) + has_finish_reason = finished + complete_response += chunk + if finished: + break assert len(complete_response) > 0 except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") @@ -872,41 +876,53 @@ async def test_sagemaker_streaming_async(): ) # Add any assertions here to check the response + print(response) complete_response = "" + has_finish_reason = False + # Add any assertions here to check the response + idx = 0 async for chunk in response: - complete_response += chunk.choices[0].delta.content or "" - print(f"complete_response: {complete_response}") - assert len(complete_response) > 0 + # print + chunk, finished = streaming_format_tests(idx, chunk) + has_finish_reason = finished + complete_response += chunk + if finished: + break + idx += 1 + if has_finish_reason is False: + raise Exception("finish reason not set for last chunk") + if complete_response.strip() == "": + raise Exception("Empty response received") + print(f"completion_response: {complete_response}") except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") -# def test_completion_sagemaker_stream(): -# try: -# response = completion( -# model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b", -# messages=messages, -# temperature=0.2, -# max_tokens=80, -# stream=True, -# ) -# complete_response = "" -# has_finish_reason = False -# # Add any assertions here to check the response -# for idx, chunk in enumerate(response): -# chunk, finished = streaming_format_tests(idx, chunk) -# has_finish_reason = finished -# if finished: -# break -# complete_response += chunk -# if has_finish_reason is False: -# raise Exception("finish reason not set for last chunk") -# if complete_response.strip() == "": -# raise Exception("Empty response received") -# except 
InvalidRequestError as e: -# pass -# except Exception as e: -# pytest.fail(f"Error occurred: {e}") +def test_completion_sagemaker_stream(): + try: + response = completion( + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + messages=messages, + temperature=0.2, + max_tokens=80, + stream=True, + ) + complete_response = "" + has_finish_reason = False + # Add any assertions here to check the response + for idx, chunk in enumerate(response): + chunk, finished = streaming_format_tests(idx, chunk) + has_finish_reason = finished + if finished: + break + complete_response += chunk + if has_finish_reason is False: + raise Exception("finish reason not set for last chunk") + if complete_response.strip() == "": + raise Exception("Empty response received") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + # test_completion_sagemaker_stream() diff --git a/litellm/utils.py b/litellm/utils.py index c062f4a22..b0e48bbc6 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1418,7 +1418,9 @@ class Logging: """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. """ - print_verbose(f"Async success callbacks: {litellm._async_success_callback}") + verbose_logger.debug( + f"Async success callbacks: {litellm._async_success_callback}" + ) start_time, end_time, result = self._success_handler_helper_fn( start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit ) @@ -1427,7 +1429,7 @@ class Logging: if self.stream: if result.choices[0].finish_reason is not None: # if it's the last chunk self.streaming_chunks.append(result) - # print_verbose(f"final set of received chunks: {self.streaming_chunks}") + # verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}") try: complete_streaming_response = litellm.stream_chunk_builder( self.streaming_chunks, @@ -1436,14 +1438,16 @@ class Logging: end_time=end_time, ) except Exception as e: - print_verbose( + verbose_logger.debug( f"Error occurred building stream chunk: {traceback.format_exc()}" ) complete_streaming_response = None else: self.streaming_chunks.append(result) if complete_streaming_response is not None: - print_verbose("Async success callbacks: Got a complete streaming response") + verbose_logger.debug( + "Async success callbacks: Got a complete streaming response" + ) self.model_call_details[ "complete_streaming_response" ] = complete_streaming_response @@ -7152,6 +7156,7 @@ class CustomStreamWrapper: "model_id": (_model_info.get("id", None)) } # returned as x-litellm-model-id response header in proxy self.response_id = None + self.logging_loop = None def __iter__(self): return self @@ -7722,6 +7727,27 @@ class CustomStreamWrapper: } return "" + def handle_sagemaker_stream(self, chunk): + if "data: [DONE]" in chunk: + text = "" + is_finished = True + finish_reason = "stop" + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + elif isinstance(chunk, dict): + if chunk["is_finished"] == True: + finish_reason = "stop" + else: + finish_reason = "" + return { + "text": chunk["text"], + "is_finished": chunk["is_finished"], + "finish_reason": finish_reason, + } + def chunk_creator(self, chunk): model_response = ModelResponse(stream=True, model=self.model) if self.response_id is not None: @@ -7729,6 +7755,7 @@ class CustomStreamWrapper: else: self.response_id = model_response.id model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider + 
model_response._hidden_params["created_at"] = time.time() model_response.choices = [StreamingChoices()] model_response.choices[0].finish_reason = None response_obj = {} @@ -7847,8 +7874,14 @@ class CustomStreamWrapper: ] self.sent_last_chunk = True elif self.custom_llm_provider == "sagemaker": - print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") - completion_obj["content"] = chunk + verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") + response_obj = self.handle_sagemaker_stream(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] + self.sent_last_chunk = True elif self.custom_llm_provider == "petals": if len(self.completion_stream) == 0: if self.sent_last_chunk: @@ -8024,6 +8057,27 @@ class CustomStreamWrapper: original_exception=e, ) + def set_logging_event_loop(self, loop): + self.logging_loop = loop + + async def your_async_function(self): + # Your asynchronous code here + return "Your asynchronous code is running" + + def run_success_logging_in_thread(self, processed_chunk): + # Create an event loop for the new thread + ## ASYNC LOGGING + if self.logging_loop is not None: + future = asyncio.run_coroutine_threadsafe( + self.logging_obj.async_success_handler(processed_chunk), + loop=self.logging_loop, + ) + result = future.result() + else: + asyncio.run(self.logging_obj.async_success_handler(processed_chunk)) + ## SYNC LOGGING + self.logging_obj.success_handler(processed_chunk) + ## needs to handle the empty string case (even starting chunk can be an empty string) def __next__(self): try: @@ -8042,8 +8096,9 @@ class CustomStreamWrapper: continue ## LOGGING threading.Thread( - target=self.logging_obj.success_handler, args=(response,) + target=self.run_success_logging_in_thread, args=(response,) ).start() # log response + # RETURN RESULT return response except StopIteration: @@ -8099,13 +8154,34 @@ class CustomStreamWrapper: raise StopAsyncIteration else: # temporary patch for non-aiohttp async calls # example - boto3 bedrock llms - processed_chunk = next(self) - asyncio.create_task( - self.logging_obj.async_success_handler( - processed_chunk, - ) - ) - return processed_chunk + while True: + if isinstance(self.completion_stream, str) or isinstance( + self.completion_stream, bytes + ): + chunk = self.completion_stream + else: + chunk = next(self.completion_stream) + if chunk is not None and chunk != b"": + print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") + processed_chunk = self.chunk_creator(chunk=chunk) + print_verbose( + f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}" + ) + if processed_chunk is None: + continue + ## LOGGING + threading.Thread( + target=self.logging_obj.success_handler, + args=(processed_chunk,), + ).start() # log processed_chunk + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) + + # RETURN RESULT + return processed_chunk except StopAsyncIteration: raise except StopIteration: diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index dfa8e1151..2c123d156 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -11,6 +11,10 @@ model_list: api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. 
See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + - model_name: sagemaker-completion-model + litellm_params: + model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4 + input_cost_per_second: 0.000420 - model_name: gpt-4 litellm_params: model: azure/gpt-turbo diff --git a/tests/test_keys.py b/tests/test_keys.py index f05204c03..348be63af 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -13,17 +13,21 @@ sys.path.insert( import litellm -async def generate_key(session, i, budget=None, budget_duration=None): +async def generate_key( + session, i, budget=None, budget_duration=None, models=["azure-models", "gpt-4"] +): url = "http://0.0.0.0:4000/key/generate" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = { - "models": ["azure-models", "gpt-4"], + "models": models, "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": None, "max_budget": budget, "budget_duration": budget_duration, } + print(f"data: {data}") + async with session.post(url, headers=headers, json=data) as response: status = response.status response_text = await response.text() @@ -293,7 +297,7 @@ async def test_key_info_spend_values(): rounded_response_cost = round(response_cost, 8) rounded_key_info_spend = round(key_info["info"]["spend"], 8) assert rounded_response_cost == rounded_key_info_spend - ## streaming + ## streaming - azure key_gen = await generate_key(session=session, i=0) new_key = key_gen["key"] prompt_tokens, completion_tokens = await chat_completion_streaming( @@ -318,6 +322,41 @@ async def test_key_info_spend_values(): assert rounded_response_cost == rounded_key_info_spend +@pytest.mark.asyncio +async def test_key_info_spend_values_sagemaker(): + """ + Tests the sync streaming loop to ensure spend is correctly calculated. + - create key + - make completion call + - assert cost is expected value + """ + async with aiohttp.ClientSession() as session: + ## streaming - sagemaker + key_gen = await generate_key(session=session, i=0, models=[]) + new_key = key_gen["key"] + prompt_tokens, completion_tokens = await chat_completion_streaming( + session=session, key=new_key, model="sagemaker-completion-model" + ) + # print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}") + # prompt_cost, completion_cost = litellm.cost_per_token( + # model="azure/gpt-35-turbo", + # prompt_tokens=prompt_tokens, + # completion_tokens=completion_tokens, + # ) + # response_cost = prompt_cost + completion_cost + await asyncio.sleep(5) # allow db log to be updated + key_info = await get_key_info( + session=session, get_key=new_key, call_key=new_key + ) + # print( + # f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}" + # ) + # rounded_response_cost = round(response_cost, 8) + rounded_key_info_spend = round(key_info["info"]["spend"], 8) + assert rounded_key_info_spend > 0 + # assert rounded_response_cost == rounded_key_info_spend + + @pytest.mark.asyncio async def test_key_with_budgets(): """ @@ -337,7 +376,7 @@ async def test_key_with_budgets(): print(f"hashed_token: {hashed_token}") key_info = await get_key_info(session=session, get_key=key, call_key=key) reset_at_init_value = key_info["info"]["budget_reset_at"] - await asyncio.sleep(15) + await asyncio.sleep(30) key_info = await get_key_info(session=session, get_key=key, call_key=key) reset_at_new_value = key_info["info"]["budget_reset_at"] assert reset_at_init_value != reset_at_new_value
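
The reworked `TokenIterator` in `litellm/llms/sagemaker.py` buffers the bytes from each SageMaker `PayloadPart`, parses every complete `data:{...}` line, and marks the final event when the payload carries `generated_text`. The snippet below is a standalone sketch of that buffering/parsing loop, not litellm code; the event shape (`{"PayloadPart": {"Bytes": ...}}`) and the TGI-style token JSON are assumptions taken from the fields the diff itself reads.

```python
import io
import json


def iter_sagemaker_tokens(event_stream):
    """Yield {"text", "is_finished"} dicts from a SageMaker-style byte stream.

    `event_stream` is assumed to yield boto3-style response-stream events:
    {"PayloadPart": {"Bytes": b'data:{...}\n'}}.
    """
    buffer = io.BytesIO()
    read_pos = 0
    for event in event_stream:
        # Append the new bytes, then parse every complete line in the buffer.
        buffer.seek(0, io.SEEK_END)
        buffer.write(event["PayloadPart"]["Bytes"])
        while True:
            buffer.seek(read_pos)
            line = buffer.readline()
            if not line or line[-1] != ord("\n"):
                break  # partial line: wait for more bytes before parsing
            read_pos += len(line)
            payload = json.loads(line[:-1].decode("utf-8").removeprefix("data:"))
            yield {
                "text": payload["token"]["text"],
                # the final event is the only one carrying generated_text
                "is_finished": payload.get("generated_text") is not None,
            }


# Example: two payload parts, the second one ending the stream.
fake_events = [
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": "Hello"}}\n'}},
    {
        "PayloadPart": {
            "Bytes": b'data:{"token": {"text": "!"}, "generated_text": "Hello!"}\n'
        }
    },
]
for token in iter_sagemaker_tokens(fake_events):
    print(token)
```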
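`stream_chunk_builder` now sorts the collected chunks by a `created_at` hidden param before rebuilding the complete response, since chunks logged from a background thread can arrive out of order. A minimal sketch of that sort, using a hypothetical `FakeChunk` stand-in for `ModelResponse`:

```python
import time
from dataclasses import dataclass, field


@dataclass
class FakeChunk:
    # Stand-in for a streaming ModelResponse chunk; _hidden_params is where
    # the wrapper stamps the wall-clock time the chunk was created.
    text: str
    _hidden_params: dict = field(default_factory=dict)


now = time.time()
chunks = [
    FakeChunk("world", {"created_at": now + 0.2}),
    FakeChunk("hello ", {"created_at": now + 0.1}),
]

# Same idea as the diff: only sort when the first chunk carries the marker,
# and let unmarked chunks sink to the end via +inf.
if chunks[0]._hidden_params.get("created_at", None):
    chunks = sorted(
        chunks, key=lambda c: c._hidden_params.get("created_at", float("inf"))
    )

print("".join(c.text for c in chunks))  # -> "hello world"
```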
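For sync streaming on the proxy, `acompletion` now hands its running event loop to the stream wrapper via `set_logging_event_loop`, and `run_success_logging_in_thread` schedules the async success callback back onto that loop with `asyncio.run_coroutine_threadsafe`, falling back to `asyncio.run` when no loop was captured. A self-contained sketch of that handoff pattern; the handler and thread target here are stand-ins, not litellm's classes:

```python
import asyncio
import threading
from typing import Optional


async def async_success_handler(chunk: str) -> None:
    # Stand-in for the async logging callback.
    await asyncio.sleep(0)
    print(f"async-logged: {chunk}")


def run_logging_in_thread(chunk: str, logging_loop: Optional[asyncio.AbstractEventLoop]) -> None:
    # If a running loop was captured, schedule the coroutine on it from this
    # worker thread and wait for it; otherwise run a throwaway loop.
    if logging_loop is not None:
        future = asyncio.run_coroutine_threadsafe(
            async_success_handler(chunk), loop=logging_loop
        )
        future.result()
    else:
        asyncio.run(async_success_handler(chunk))


async def main() -> None:
    loop = asyncio.get_running_loop()  # what set_logging_event_loop() stores
    worker = threading.Thread(target=run_logging_in_thread, args=("chunk-1", loop))
    worker.start()
    # Join from an executor thread so the event loop stays free to run the
    # coroutine the worker scheduled onto it.
    await asyncio.to_thread(worker.join)


if __name__ == "__main__":
    asyncio.run(main())
```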
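Assuming the SageMaker endpoint added to `proxy_server_config.yaml` is reachable and the `AWS_*` variables wired through the CircleCI job are set, a streaming call would look roughly like the sketch below; the prompt text is illustrative.

```python
import litellm

response = litellm.completion(
    model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
    messages=[{"role": "user", "content": "Say hi in one sentence."}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
    if chunk.choices[0].finish_reason is not None:
        break  # the SageMaker wrapper now sets finish_reason on the last chunk
```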