fix(bedrock.py): support ai21 / bedrock streaming

Krrish Dholakia 2023-11-29 16:34:48 -08:00
parent 3b89cff65e
commit ab76daa90b
4 changed files with 97 additions and 26 deletions

litellm/llms/bedrock.py

@@ -393,29 +393,55 @@ def completion(
     accept = 'application/json'
     contentType = 'application/json'
     if stream == True:
-        ## LOGGING
-        request_str = f"""
-        response = client.invoke_model_with_response_stream(
-            body={data},
-            modelId={model},
-            accept=accept,
-            contentType=contentType
-        )
-        """
-        logging_obj.pre_call(
-            input=prompt,
-            api_key="",
-            additional_args={"complete_input_dict": data, "request_str": request_str},
-        )
-        response = client.invoke_model_with_response_stream(
-            body=data,
-            modelId=model,
-            accept=accept,
-            contentType=contentType
-        )
-        response = response.get('body')
-        return response
+        if provider == "ai21":
+            ## LOGGING
+            request_str = f"""
+            response = client.invoke_model(
+                body={data},
+                modelId={model},
+                accept=accept,
+                contentType=contentType
+            )
+            """
+            logging_obj.pre_call(
+                input=prompt,
+                api_key="",
+                additional_args={"complete_input_dict": data, "request_str": request_str},
+            )
+            response = client.invoke_model(
+                body=data,
+                modelId=model,
+                accept=accept,
+                contentType=contentType
+            )
+            response = response.get('body').read()
+            return response
+        else:
+            ## LOGGING
+            request_str = f"""
+            response = client.invoke_model_with_response_stream(
+                body={data},
+                modelId={model},
+                accept=accept,
+                contentType=contentType
+            )
+            """
+            logging_obj.pre_call(
+                input=prompt,
+                api_key="",
+                additional_args={"complete_input_dict": data, "request_str": request_str},
+            )
+            response = client.invoke_model_with_response_stream(
+                body=data,
+                modelId=model,
+                accept=accept,
+                contentType=contentType
+            )
+            response = response.get('body')
+            return response
     try:
         ## LOGGING
         request_str = f"""
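Note on the branch above: at the time of this commit, Bedrock's AI21 (Jurassic-2) models did not support invoke_model_with_response_stream, so the ai21 path fetches the whole completion with invoke_model and returns the raw bytes, faking the stream downstream. A minimal sketch of the two boto3 call shapes, assuming a configured "bedrock-runtime" client (the model IDs are illustrative):

    import json
    import boto3

    client = boto3.client("bedrock-runtime")
    data = json.dumps({"prompt": "Hello", "maxTokens": 20})

    # AI21 path: no server-side streaming, so read the complete JSON body at once.
    resp = client.invoke_model(
        body=data, modelId="ai21.j2-mid-v1",
        accept="application/json", contentType="application/json",
    )
    payload = resp.get("body").read()  # bytes holding the full response

    # Streaming path (e.g. Anthropic): an event stream of 'chunk' payloads.
    stream = client.invoke_model_with_response_stream(
        body=data, modelId="anthropic.claude-v2",
        accept="application/json", contentType="application/json",
    )
    for event in stream.get("body"):
        chunk_bytes = event.get("chunk", {}).get("bytes")  # incremental JSON bytes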

litellm/main.py

@@ -1186,9 +1186,14 @@ def completion(
     if "stream" in optional_params and optional_params["stream"] == True:
         # don't try to access stream object,
-        response = CustomStreamWrapper(
-            iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
-        )
+        if "ai21" in model:
+            response = CustomStreamWrapper(
+                model_response, model, custom_llm_provider="bedrock", logging_obj=logging
+            )
+        else:
+            response = CustomStreamWrapper(
+                iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
+            )
         return response
     ## RESPONSE OBJECT
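Note: the ai21 branch hands CustomStreamWrapper the raw bytes payload directly, while the real-streaming branch wraps the event stream in iter(). CustomStreamWrapper.__next__ (see the utils.py hunk below) treats a plain str/bytes completion_stream as one already-complete chunk. A minimal illustration of that dispatch, using a hypothetical next_chunk helper:

    def next_chunk(completion_stream):
        # Fake streaming: a plain bytes/str payload is itself the only chunk.
        if isinstance(completion_stream, (str, bytes)):
            return completion_stream
        # Real streaming: pull the next event off the iterator.
        return next(completion_stream)

    assert next_chunk(b'{"completions": []}') == b'{"completions": []}'
    assert next_chunk(iter([b"a", b"b"])) == b"a"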

litellm/tests/test_streaming.py

@@ -600,6 +600,39 @@ def test_completion_bedrock_claude_stream():
 # test_completion_bedrock_claude_stream()
 
+def test_completion_bedrock_ai21_stream():
+    try:
+        litellm.set_verbose=False
+        response = completion(
+            model="bedrock/ai21.j2-mid-v1",
+            messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}],
+            temperature=1,
+            max_tokens=20,
+            stream=True,
+        )
+        print(response)
+        complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            # print
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            complete_response += chunk
+            if finished:
+                break
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_completion_bedrock_ai21_stream()
+
 # def test_completion_sagemaker_stream():
 #     try:
 #         response = completion(
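Note: a usage sketch of what this test exercises (assumes valid AWS Bedrock credentials); with fake streaming, the wrapper yields the entire completion as a single chunk whose finish_reason is "stop":

    from litellm import completion

    response = completion(
        model="bedrock/ai21.j2-mid-v1",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=20,
        stream=True,
    )
    for chunk in response:  # ai21 on bedrock currently yields a single chunk
        print(chunk)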

litellm/utils.py

@@ -5065,14 +5065,22 @@ class CustomStreamWrapper:
             return ""
     def handle_bedrock_stream(self, chunk):
-        chunk = chunk.get('chunk')
-        if chunk:
-            chunk_data = json.loads(chunk.get('bytes').decode())
+        if hasattr(chunk, "get"):
+            chunk = chunk.get('chunk')
+            chunk_data = json.loads(chunk.get('bytes').decode())
+        else:
+            chunk_data = json.loads(chunk.decode())
+        if chunk_data:
             text = ""
             is_finished = False
             finish_reason = ""
             if "outputText" in chunk_data:
                 text = chunk_data['outputText']
+            # ai21 mapping
+            if "ai21" in self.model: # fake ai21 streaming
+                text = chunk_data.get('completions')[0].get('data').get('text')
+                is_finished = True
+                finish_reason = "stop"
             # anthropic mapping
             elif "completion" in chunk_data:
                 text = chunk_data['completion'] # bedrock.anthropic
@@ -5295,11 +5303,10 @@ class CustomStreamWrapper:
     def __next__(self):
         try:
             while True:
-                if isinstance(self.completion_stream, str):
+                if isinstance(self.completion_stream, str) or isinstance(self.completion_stream, bytes):
                     chunk = self.completion_stream
                 else:
                     chunk = next(self.completion_stream)
                 if chunk is not None and chunk != b'':
                     response = self.chunk_creator(chunk=chunk)
                     if response is not None:
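Note: the ai21 mapping above assumes the Jurassic-2 response shape, where the generated text lives under completions[0].data.text. A minimal parse of a sample payload (the text value is illustrative):

    import json

    raw = b'{"completions": [{"data": {"text": " A case usually reaches the Supreme Court on appeal."}}]}'
    chunk_data = json.loads(raw.decode())
    # Same access path as handle_bedrock_stream's ai21 branch:
    text = chunk_data.get('completions')[0].get('data').get('text')
    assert text.startswith(" A case")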