From c714372b9d064fd3f07e1df2868e7ebfda9a406e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Sat, 16 Sep 2023 09:57:16 -0700 Subject: [PATCH] streaming for amazon titan bedrock --- litellm/llms/bedrock.py | 90 +++++++++++++++++--------------- litellm/main.py | 6 ++- litellm/tests/test_completion.py | 19 ++++++- litellm/utils.py | 11 ++++ 4 files changed, 82 insertions(+), 44 deletions(-) diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index d435d224d5..7c885a1872 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -59,6 +59,7 @@ def completion( encoding, logging_obj, optional_params=None, + stream=False, litellm_params=None, logger_fn=None, ): @@ -106,6 +107,15 @@ def completion( ## COMPLETION CALL accept = 'application/json' contentType = 'application/json' + if stream == True: + response = client.invoke_model_with_response_stream( + body=data, + modelId=model, + accept=accept, + contentType=contentType + ) + response = response.get('body') + return response response = client.invoke_model( body=data, @@ -114,50 +124,48 @@ def completion( contentType=contentType ) response_body = json.loads(response.get('body').read()) - if "stream" in optional_params and optional_params["stream"] == True: - return response.iter_lines() - else: - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key="", - original_response=response, - additional_args={"complete_input_dict": data}, - ) - print_verbose(f"raw model_response: {response}") - ## RESPONSE OBJECT - outputText = "default" - if provider == "ai21": - outputText = response_body.get('completions')[0].get('data').get('text') - else: # amazon titan - outputText = response_body.get('results')[0].get('outputText') - if "error" in outputText: - raise BedrockError( - message=outputText, - status_code=response.status_code, - ) - else: - try: - model_response["choices"][0]["message"]["content"] = outputText - except: - raise BedrockError(message=json.dumps(outputText), status_code=response.status_code) - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len( - encoding.encode(prompt) - ) - completion_tokens = len( - encoding.encode(model_response["choices"][0]["message"]["content"]) + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key="", + original_response=response, + additional_args={"complete_input_dict": data}, ) + print_verbose(f"raw model_response: {response}") + ## RESPONSE OBJECT + outputText = "default" + if provider == "ai21": + outputText = response_body.get('completions')[0].get('data').get('text') + else: # amazon titan + outputText = response_body.get('results')[0].get('outputText') + if "error" in outputText: + raise BedrockError( + message=outputText, + status_code=response.status_code, + ) + else: + try: + model_response["choices"][0]["message"]["content"] = outputText + except: + raise BedrockError(message=json.dumps(outputText), status_code=response.status_code) - model_response["created"] = time.time() - model_response["model"] = model - model_response["usage"] = { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - } - return model_response + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. 
+ prompt_tokens = len( + encoding.encode(prompt) + ) + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"]["content"]) + ) + + model_response["created"] = time.time() + model_response["model"] = model + model_response["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + return model_response def embedding(): # logic for parsing in - calling - parsing out model embedding calls diff --git a/litellm/main.py b/litellm/main.py index d993e83d79..765c6ff834 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -781,10 +781,12 @@ def completion( litellm_params=litellm_params, logger_fn=logger_fn, encoding=encoding, - logging_obj=logging + logging_obj=logging, + stream=stream, ) - if "stream" in optional_params and optional_params["stream"] == True: ## [BETA] + + if stream == True: # don't try to access stream object, response = CustomStreamWrapper( iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index ed3872cf89..5a9e545b0d 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -676,7 +676,24 @@ def test_completion_bedrock_ai21(): print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_bedrock_ai21() + +def test_completion_bedrock_ai21_stream(): + try: + litellm.set_verbose = False + response = completion( + model="bedrock/amazon.titan-tg1-large", + messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}], + temperature=1, + max_tokens=4096, + stream=True, + ) + # Add any assertions here to check the response + print(response) + for chunk in response: + print(chunk) + except Exception as e: + pytest.fail(f"Error occurred: {e}") +# test_completion_bedrock_ai21_stream() # test_completion_sagemaker() diff --git a/litellm/utils.py b/litellm/utils.py index a5b2196a9c..dcf41cbce4 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2475,6 +2475,15 @@ class CustomStreamWrapper: traceback.print_exc() return "" + def handle_bedrock_stream(self): + if self.completion_stream: + event = next(self.completion_stream) + chunk = event.get('chunk') + if chunk: + chunk_data = json.loads(chunk.get('bytes').decode()) + return chunk_data['outputText'] + return "" + def __next__(self): try: # return this for all models @@ -2520,6 +2529,8 @@ class CustomStreamWrapper: elif self.model in litellm.cohere_models or self.custom_llm_provider == "cohere": chunk = next(self.completion_stream) completion_obj["content"] = self.handle_cohere_chunk(chunk) + elif self.custom_llm_provider == "bedrock": + completion_obj["content"] = self.handle_bedrock_stream() else: # openai chat/azure models chunk = next(self.completion_stream) model_response = chunk
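
Note (not part of the patch): below is a minimal standalone sketch of the event stream that the new handle_bedrock_stream method consumes. The service name, region, model id, and the Titan request body shape (inputText / textGenerationConfig) are assumptions taken from the public Bedrock documentation rather than from this diff; the invoke_model_with_response_stream arguments and the chunk/bytes/outputText unwrapping mirror what the hunks above actually do.

    # Sketch only: consuming the Titan streaming response directly with boto3.
    # Service name, region, model id, and request body are assumptions from the
    # Bedrock docs; the parsing below mirrors handle_bedrock_stream in this patch.
    import json
    import boto3

    client = boto3.client("bedrock-runtime", region_name="us-west-2")  # assumed

    body = json.dumps({
        "inputText": "How does a court case get to the Supreme Court?",
        "textGenerationConfig": {"maxTokenCount": 512, "temperature": 1},
    })

    response = client.invoke_model_with_response_stream(
        body=body,
        modelId="amazon.titan-tg1-large",
        accept="application/json",
        contentType="application/json",
    )

    # response["body"] is a botocore EventStream; each event wraps a 'chunk'
    # whose 'bytes' payload is a JSON document carrying a partial 'outputText'.
    for event in response.get("body"):
        chunk = event.get("chunk")
        if chunk:
            chunk_data = json.loads(chunk.get("bytes").decode())
            print(chunk_data.get("outputText"), end="")

In the wrapped path added by this patch, main.py hands the same EventStream to CustomStreamWrapper, and handle_bedrock_stream performs this event -> chunk -> bytes unwrapping one event per __next__ call.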