Add streaming for Amazon Titan on Bedrock

ishaan-jaff 2023-09-16 09:57:16 -07:00
parent 93fbe4a733
commit c714372b9d
4 changed files with 82 additions and 44 deletions

View file

@@ -59,6 +59,7 @@ def completion(
     encoding,
     logging_obj,
     optional_params=None,
+    stream=False,
     litellm_params=None,
     logger_fn=None,
 ):
@@ -106,6 +107,15 @@ def completion(
     ## COMPLETION CALL
     accept = 'application/json'
     contentType = 'application/json'
+    if stream == True:
+        response = client.invoke_model_with_response_stream(
+            body=data,
+            modelId=model,
+            accept=accept,
+            contentType=contentType
+        )
+        response = response.get('body')
+        return response
     response = client.invoke_model(
         body=data,
         modelId=model,
@@ -114,9 +124,7 @@ def completion(
         contentType=contentType
     )
     response_body = json.loads(response.get('body').read())
-    if "stream" in optional_params and optional_params["stream"] == True:
-        return response.iter_lines()
-    else:
-        ## LOGGING
-        logging_obj.post_call(
-            input=prompt,
+    ## LOGGING
+    logging_obj.post_call(
+        input=prompt,
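
Aside: a minimal standalone sketch of the streaming call added above, consumed directly with boto3. The client setup, region, and request body here are assumptions; the event shape ('chunk' wrapping JSON 'bytes' with an 'outputText' field) is the one this commit parses.

    import json
    import boto3

    # Hypothetical client; service name and region are assumptions.
    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    response = client.invoke_model_with_response_stream(
        # Titan-style request body (an assumption; litellm builds this elsewhere).
        body=json.dumps({"inputText": "Tell me about the Supreme Court."}),
        modelId="amazon.titan-tg1-large",
        accept="application/json",
        contentType="application/json",
    )

    # The 'body' is an event stream; each event carries one JSON chunk.
    for event in response.get("body"):
        chunk = event.get("chunk")
        if chunk:
            chunk_data = json.loads(chunk.get("bytes").decode())
            print(chunk_data["outputText"], end="")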

View file

@@ -781,10 +781,12 @@ def completion(
             litellm_params=litellm_params,
             logger_fn=logger_fn,
             encoding=encoding,
-            logging_obj=logging
+            logging_obj=logging,
+            stream=stream,
         )
-        if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
+        if stream == True:
             # don't try to access stream object,
             response = CustomStreamWrapper(
                 iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
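
For context, a caller-side sketch of the path this enables (it mirrors the test added below; the model id is taken from that test):

    import litellm

    response = litellm.completion(
        model="bedrock/amazon.titan-tg1-large",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
    )
    # With stream=True, completion() returns a CustomStreamWrapper iterator.
    for chunk in response:
        print(chunk)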

View file

@@ -676,7 +676,24 @@ def test_completion_bedrock_ai21():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 # test_completion_bedrock_ai21()
+def test_completion_bedrock_ai21_stream():
+    try:
+        litellm.set_verbose = False
+        response = completion(
+            model="bedrock/amazon.titan-tg1-large",
+            messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}],
+            temperature=1,
+            max_tokens=4096,
+            stream=True,
+        )
+        # Add any assertions here to check the response
+        print(response)
+        for chunk in response:
+            print(chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_completion_bedrock_ai21_stream()
 # test_completion_sagemaker()

View file

@@ -2475,6 +2475,15 @@ class CustomStreamWrapper:
             traceback.print_exc()
             return ""
+    def handle_bedrock_stream(self):
+        if self.completion_stream:
+            event = next(self.completion_stream)
+            chunk = event.get('chunk')
+            if chunk:
+                chunk_data = json.loads(chunk.get('bytes').decode())
+                return chunk_data['outputText']
+        return ""
+
     def __next__(self):
         try:
             # return this for all models
@@ -2520,6 +2529,8 @@ class CustomStreamWrapper:
             elif self.model in litellm.cohere_models or self.custom_llm_provider == "cohere":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_cohere_chunk(chunk)
+            elif self.custom_llm_provider == "bedrock":
+                completion_obj["content"] = self.handle_bedrock_stream()
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
                 model_response = chunk
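
A self-contained check of the chunk parsing added in handle_bedrock_stream, with the method's logic lifted into a free function and fed a stubbed event stream. The stub's event shape follows the code above; everything else here is illustrative.

    import json

    def handle_bedrock_stream(completion_stream):
        # Same logic as the method added above, for illustration only.
        if completion_stream:
            event = next(completion_stream)
            chunk = event.get('chunk')
            if chunk:
                chunk_data = json.loads(chunk.get('bytes').decode())
                return chunk_data['outputText']
        return ""

    # Stubbed event stream with two Titan-style chunks (hypothetical data).
    events = iter([
        {'chunk': {'bytes': json.dumps({'outputText': 'Hello, '}).encode()}},
        {'chunk': {'bytes': json.dumps({'outputText': 'world.'}).encode()}},
    ])
    assert handle_bedrock_stream(events) == 'Hello, '
    assert handle_bedrock_stream(events) == 'world.'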