From ab76daa90bc2b2ee1ef39f41df1e4d693cb0accd Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 29 Nov 2023 16:34:48 -0800
Subject: [PATCH] fix(bedrock.py): support ai21 / bedrock streaming

---
 litellm/llms/bedrock.py         | 64 +++++++++++++++++++++++----------
 litellm/main.py                 | 11 ++++--
 litellm/tests/test_streaming.py | 33 +++++++++++++++++
 litellm/utils.py                | 15 +++++---
 4 files changed, 97 insertions(+), 26 deletions(-)

diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index 86fba8366..30aa3e6ce 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -393,29 +393,55 @@ def completion(
     accept = 'application/json'
     contentType = 'application/json'
     if stream == True:
-        ## LOGGING
-        request_str = f"""
-        response = client.invoke_model_with_response_stream(
-            body={data},
-            modelId={model},
-            accept=accept,
-            contentType=contentType
-        )
-        """
-        logging_obj.pre_call(
+        if provider == "ai21":
+            ## LOGGING
+            request_str = f"""
+            response = client.invoke_model(
+                body={data},
+                modelId={model},
+                accept=accept,
+                contentType=contentType
+            )
+            """
+            logging_obj.pre_call(
             input=prompt,
             api_key="",
             additional_args={"complete_input_dict": data, "request_str": request_str},
-        )
+            )
+            response = client.invoke_model(
+                body=data,
+                modelId=model,
+                accept=accept,
+                contentType=contentType
+            )
+
+            response = response.get('body').read()
+            return response
+        else:
+            ## LOGGING
+            request_str = f"""
+            response = client.invoke_model_with_response_stream(
+                body={data},
+                modelId={model},
+                accept=accept,
+                contentType=contentType
+            )
+            """
+            logging_obj.pre_call(
+            input=prompt,
+            api_key="",
+            additional_args={"complete_input_dict": data, "request_str": request_str},
+            )
+
+            response = client.invoke_model_with_response_stream(
+                body=data,
+                modelId=model,
+                accept=accept,
+                contentType=contentType
+            )
+            response = response.get('body')
+            return response
     try:
         ## LOGGING
         request_str = f"""
diff --git a/litellm/main.py b/litellm/main.py
index ce807efc5..baec460d5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1186,9 +1186,14 @@ def completion(
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
-            response = CustomStreamWrapper(
-                iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
-            )
+            if "ai21" in model:
+                response = CustomStreamWrapper(
+                    model_response, model, custom_llm_provider="bedrock", logging_obj=logging
+                )
+            else:
+                response = CustomStreamWrapper(
+                    iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
+                )
             return response

         ## RESPONSE OBJECT
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 417bd64d9..597020942 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -600,6 +600,39 @@ def test_completion_bedrock_claude_stream():
 # test_completion_bedrock_claude_stream()

+def test_completion_bedrock_ai21_stream():
+    try:
+        litellm.set_verbose=False
+        response = completion(
+            model="bedrock/ai21.j2-mid-v1",
+            messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}],
+            temperature=1,
+            max_tokens=20,
+            stream=True,
+        )
+        print(response)
+        complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            # print
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            complete_response += chunk
+            if finished:
+                break
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_completion_bedrock_ai21_stream()
+
 # def test_completion_sagemaker_stream():
 #     try:
 #         response = completion(
diff --git a/litellm/utils.py b/litellm/utils.py
index 1fe9c12c8..dd90a9764 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5065,14 +5065,22 @@ class CustomStreamWrapper:
             return ""

     def handle_bedrock_stream(self, chunk):
-        chunk = chunk.get('chunk')
-        if chunk:
+        if hasattr(chunk, "get"):
+            chunk = chunk.get('chunk')
             chunk_data = json.loads(chunk.get('bytes').decode())
+        else:
+            chunk_data = json.loads(chunk.decode())
+        if chunk_data:
             text = ""
             is_finished = False
             finish_reason = ""
             if "outputText" in chunk_data:
                 text = chunk_data['outputText']
+            # ai21 mapping
+            if "ai21" in self.model: # fake ai21 streaming
+                text = chunk_data.get('completions')[0].get('data').get('text')
+                is_finished = True
+                finish_reason = "stop"
             # anthropic mapping
             elif "completion" in chunk_data:
                 text = chunk_data['completion'] # bedrock.anthropic
@@ -5295,11 +5303,10 @@
     def __next__(self):
         try:
             while True:
-                if isinstance(self.completion_stream, str):
+                if isinstance(self.completion_stream, str) or isinstance(self.completion_stream, bytes):
                     chunk = self.completion_stream
                 else:
                     chunk = next(self.completion_stream)
-
                 if chunk is not None and chunk != b'':
                     response = self.chunk_creator(chunk=chunk)
                     if response is not None: