fix(bedrock.py): support ai21 / bedrock streaming

This commit is contained in:
Krrish Dholakia 2023-11-29 16:34:48 -08:00
parent 3b89cff65e
commit ab76daa90b
4 changed files with 97 additions and 26 deletions

View file

@ -393,29 +393,55 @@ def completion(
accept = 'application/json'
contentType = 'application/json'
if stream == True:
## LOGGING
request_str = f"""
response = client.invoke_model_with_response_stream(
body={data},
modelId={model},
accept=accept,
contentType=contentType
)
"""
logging_obj.pre_call(
if provider == "ai21":
## LOGGING
request_str = f"""
response = client.invoke_model(
body={data},
modelId={model},
accept=accept,
contentType=contentType
)
"""
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
)
response = client.invoke_model_with_response_stream(
body=data,
modelId=model,
accept=accept,
contentType=contentType
)
response = response.get('body')
return response
)
response = client.invoke_model(
body=data,
modelId=model,
accept=accept,
contentType=contentType
)
response = response.get('body').read()
return response
else:
## LOGGING
request_str = f"""
response = client.invoke_model_with_response_stream(
body={data},
modelId={model},
accept=accept,
contentType=contentType
)
"""
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
)
response = client.invoke_model_with_response_stream(
body=data,
modelId=model,
accept=accept,
contentType=contentType
)
response = response.get('body')
return response
try:
## LOGGING
request_str = f"""

View file

@ -1186,9 +1186,14 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(
iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
)
if "ai21" in model:
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="bedrock", logging_obj=logging
)
else:
response = CustomStreamWrapper(
iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
)
return response
## RESPONSE OBJECT

View file

@ -600,6 +600,39 @@ def test_completion_bedrock_claude_stream():
# test_completion_bedrock_claude_stream()
def test_completion_bedrock_ai21_stream():
    """Smoke-test streaming completions against the Bedrock ai21.j2-mid-v1 model.

    Calls ``completion`` with ``stream=True`` and validates each emitted chunk
    via ``streaming_format_tests``, asserting that the stream produces
    non-empty text and that the final chunk carries a finish reason.
    Rate-limit errors are tolerated (treated as a pass); any other exception
    fails the test.
    """
    try:
        # Keep log output quiet for this test run.
        litellm.set_verbose=False
        response = completion(
            model="bedrock/ai21.j2-mid-v1",
            messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}],
            temperature=1,
            max_tokens=20,
            stream=True,
        )
        print(response)
        complete_response = ""
        has_finish_reason = False
        # Add any assertions here to check the response
        for idx, chunk in enumerate(response):
            # print
            # streaming_format_tests returns (text, finished) for each chunk;
            # `finished` is True once the chunk carries a finish reason.
            chunk, finished = streaming_format_tests(idx, chunk)
            has_finish_reason = finished
            complete_response += chunk
            if finished:
                break
        # The last chunk consumed must have set a finish reason.
        if has_finish_reason is False:
            raise Exception("finish reason not set for last chunk")
        # The accumulated stream must contain some non-whitespace text.
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
    except RateLimitError:
        # Rate limiting is an environmental condition, not a test failure.
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# NOTE(review): module-level invocation runs the test on import as well as
# under pytest collection — matches the file's existing convention for these tests.
test_completion_bedrock_ai21_stream()
# def test_completion_sagemaker_stream():
# try:
# response = completion(

View file

@ -5065,14 +5065,22 @@ class CustomStreamWrapper:
return ""
def handle_bedrock_stream(self, chunk):
chunk = chunk.get('chunk')
if chunk:
if hasattr(chunk, "get"):
chunk = chunk.get('chunk')
chunk_data = json.loads(chunk.get('bytes').decode())
else:
chunk_data = json.loads(chunk.decode())
if chunk_data:
text = ""
is_finished = False
finish_reason = ""
if "outputText" in chunk_data:
text = chunk_data['outputText']
# ai21 mapping
if "ai21" in self.model: # fake ai21 streaming
text = chunk_data.get('completions')[0].get('data').get('text')
is_finished = True
finish_reason = "stop"
# anthropic mapping
elif "completion" in chunk_data:
text = chunk_data['completion'] # bedrock.anthropic
@ -5295,11 +5303,10 @@ class CustomStreamWrapper:
def __next__(self):
try:
while True:
if isinstance(self.completion_stream, str):
if isinstance(self.completion_stream, str) or isinstance(self.completion_stream, bytes):
chunk = self.completion_stream
else:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b'':
response = self.chunk_creator(chunk=chunk)
if response is not None: