mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00

fix(bedrock.py): support ai21 / bedrock streaming

parent 3b89cff65e
commit ab76daa90b

4 changed files with 97 additions and 26 deletions
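
Note on the change (inferred from the hunks below): Bedrock's AI21 Jurassic models do not return a token stream, so when stream=True and the provider is "ai21" the handler now calls invoke_model, reads the full body, and hands the raw bytes to CustomStreamWrapper, which emits the whole completion as a single "fake" streaming chunk with finish_reason "stop". A minimal usage sketch mirroring the test added in this commit (assumes AWS credentials for Bedrock are configured):

import litellm

# Fake streaming for AI21 on Bedrock: the response arrives as one chunk.
response = litellm.completion(
    model="bedrock/ai21.j2-mid-v1",
    messages=[{"role": "user", "content": "How does a court case get to the Supreme Court?"}],
    max_tokens=20,
    stream=True,
)
for chunk in response:
    print(chunk)  # a single chunk carrying the full completion
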
@@ -393,29 +393,55 @@ def completion(
     accept = 'application/json'
     contentType = 'application/json'
 
     if stream == True:
-        ## LOGGING
-        request_str = f"""
-        response = client.invoke_model_with_response_stream(
-            body={data},
-            modelId={model},
-            accept=accept,
-            contentType=contentType
-        )
-        """
-        logging_obj.pre_call(
-            input=prompt,
-            api_key="",
-            additional_args={"complete_input_dict": data, "request_str": request_str},
-        )
-        response = client.invoke_model_with_response_stream(
-            body=data,
-            modelId=model,
-            accept=accept,
-            contentType=contentType
-        )
-        response = response.get('body')
-        return response
+        if provider == "ai21":
+            ## LOGGING
+            request_str = f"""
+            response = client.invoke_model(
+                body={data},
+                modelId={model},
+                accept=accept,
+                contentType=contentType
+            )
+            """
+            logging_obj.pre_call(
+                input=prompt,
+                api_key="",
+                additional_args={"complete_input_dict": data, "request_str": request_str},
+            )
+
+            response = client.invoke_model(
+                body=data,
+                modelId=model,
+                accept=accept,
+                contentType=contentType
+            )
+            response = response.get('body').read()
+            return response
+        else:
+            ## LOGGING
+            request_str = f"""
+            response = client.invoke_model_with_response_stream(
+                body={data},
+                modelId={model},
+                accept=accept,
+                contentType=contentType
+            )
+            """
+            logging_obj.pre_call(
+                input=prompt,
+                api_key="",
+                additional_args={"complete_input_dict": data, "request_str": request_str},
+            )
+
+            response = client.invoke_model_with_response_stream(
+                body=data,
+                modelId=model,
+                accept=accept,
+                contentType=contentType
+            )
+            response = response.get('body')
+            return response
 
     try:
         ## LOGGING
         request_str = f"""
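For reference, a hedged sketch of the two boto3 response shapes the branch above distinguishes. Only the call names and kwargs come from this diff; the client name, model IDs, and request payload are assumptions, not part of this commit:

import json
import boto3

client = boto3.client("bedrock-runtime")  # assumed; the diff only shows `client`
body = json.dumps({"prompt": "hi", "maxTokens": 20})  # hypothetical AI21-style payload

# Non-streaming (new ai21 path): 'body' is a StreamingBody; .read() returns all bytes.
full = client.invoke_model(body=body, modelId="ai21.j2-mid-v1",
                           accept="application/json", contentType="application/json")
raw = full.get('body').read()

# Streaming (other providers): 'body' is an event stream of {'chunk': {'bytes': ...}}.
stream = client.invoke_model_with_response_stream(body=body, modelId="anthropic.claude-v2",
                                                  accept="application/json", contentType="application/json")
for event in stream.get('body'):
    payload = event.get('chunk', {}).get('bytes')
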
@@ -1186,9 +1186,14 @@ def completion(
 
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
-            response = CustomStreamWrapper(
-                iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
-            )
+            if "ai21" in model:
+                response = CustomStreamWrapper(
+                    model_response, model, custom_llm_provider="bedrock", logging_obj=logging
+                )
+            else:
+                response = CustomStreamWrapper(
+                    iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
+                )
             return response
 
         ## RESPONSE OBJECT
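Why this hunk skips iter() for AI21 (a hedged reading): model_response here is the raw bytes body returned by the non-streaming path above, and iterating bytes yields one integer per byte rather than one usable payload:

# Illustration only; `body` stands in for the AI21 bytes model_response.
body = b'{"completions": [{"data": {"text": "hi"}}]}'
print(next(iter(body)))  # 123 -- the code point of '{', not a JSON chunk
# Passing the bytes through unchanged lets CustomStreamWrapper.__next__
# treat them like the existing str passthrough (see the last hunk below).
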
@@ -600,6 +600,39 @@ def test_completion_bedrock_claude_stream():
 
 # test_completion_bedrock_claude_stream()
 
+def test_completion_bedrock_ai21_stream():
+    try:
+        litellm.set_verbose=False
+        response = completion(
+            model="bedrock/ai21.j2-mid-v1",
+            messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}],
+            temperature=1,
+            max_tokens=20,
+            stream=True,
+        )
+        print(response)
+        complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            # print
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            complete_response += chunk
+            if finished:
+                break
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_completion_bedrock_ai21_stream()
+
 # def test_completion_sagemaker_stream():
 #     try:
 #         response = completion(
@@ -5065,14 +5065,22 @@ class CustomStreamWrapper:
             return ""
 
     def handle_bedrock_stream(self, chunk):
-        chunk = chunk.get('chunk')
-        if chunk:
+        if hasattr(chunk, "get"):
+            chunk = chunk.get('chunk')
             chunk_data = json.loads(chunk.get('bytes').decode())
+        else:
+            chunk_data = json.loads(chunk.decode())
+        if chunk_data:
             text = ""
             is_finished = False
             finish_reason = ""
             if "outputText" in chunk_data:
                 text = chunk_data['outputText']
+            # ai21 mapping
+            if "ai21" in self.model: # fake ai21 streaming
+                text = chunk_data.get('completions')[0].get('data').get('text')
+                is_finished = True
+                finish_reason = "stop"
             # anthropic mapping
             elif "completion" in chunk_data:
                 text = chunk_data['completion'] # bedrock.anthropic
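A minimal sketch of the new bytes branch, using a hypothetical AI21-shaped body that matches the mapping above:

import json

raw = b'{"completions": [{"data": {"text": "A case reaches the Court..."}}]}'

# bytes have no .get(), so the hasattr check routes to json.loads(chunk.decode()).
chunk_data = json.loads(raw.decode())
text = chunk_data.get('completions')[0].get('data').get('text')
assert text.startswith("A case")  # emitted as one finished chunk, finish_reason "stop"
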
@@ -5295,11 +5303,10 @@ class CustomStreamWrapper:
     def __next__(self):
         try:
             while True:
-                if isinstance(self.completion_stream, str):
+                if isinstance(self.completion_stream, str) or isinstance(self.completion_stream, bytes):
                     chunk = self.completion_stream
                 else:
                     chunk = next(self.completion_stream)
-
                 if chunk is not None and chunk != b'':
                     response = self.chunk_creator(chunk=chunk)
                     if response is not None: