Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
fix(bedrock.py): support ai21 / bedrock streaming
parent 3b89cff65e
commit ab76daa90b
4 changed files with 97 additions and 26 deletions
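In short: at the time of this commit, AI21's Jurassic (j2) models on Bedrock did not support the response-stream API, so this change "fakes" streaming for them. The non-streaming invoke_model call is made, the raw response body is handed to CustomStreamWrapper, and the full completion is emitted as a single chunk with finish_reason "stop". A minimal sketch of that general pattern (illustration only, not litellm code):

# Sketch of "fake" streaming: when a provider cannot stream, return the full
# completion as a one-chunk stream. Names here are illustrative, not litellm's.
from typing import Iterator

def fake_stream(full_text: str) -> Iterator[dict]:
    # Emit the entire completion as a single chunk and mark it as the last one.
    yield {"text": full_text, "is_finished": True, "finish_reason": "stop"}

for chunk in fake_stream("A court case reaches the Supreme Court on appeal..."):
    print(chunk["text"], chunk["finish_reason"])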
@@ -393,6 +393,32 @@ def completion(
    accept = 'application/json'
    contentType = 'application/json'
    if stream == True:
        if provider == "ai21":
            ## LOGGING
            request_str = f"""
            response = client.invoke_model(
                body={data},
                modelId={model},
                accept=accept,
                contentType=contentType
            )
            """
            logging_obj.pre_call(
                input=prompt,
                api_key="",
                additional_args={"complete_input_dict": data, "request_str": request_str},
            )

            response = client.invoke_model(
                body=data,
                modelId=model,
                accept=accept,
                contentType=contentType
            )

            response = response.get('body').read()
            return response
        else:
            ## LOGGING
            request_str = f"""
            response = client.invoke_model_with_response_stream(
@@ -407,6 +433,7 @@ def completion(
                api_key="",
                additional_args={"complete_input_dict": data, "request_str": request_str},
            )

            response = client.invoke_model_with_response_stream(
                body=data,
                modelId=model,
@@ -415,7 +442,6 @@ def completion(
            )
            response = response.get('body')
            return response

    try:
        ## LOGGING
        request_str = f"""
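For context, the two Bedrock runtime calls switched between above behave quite differently. A rough boto3 sketch (model ids, region, and request payloads are placeholders; each provider expects its own request schema):

# Rough sketch (not litellm code) of the two Bedrock runtime calls.
import json
import boto3

client = boto3.client("bedrock-runtime", region_name="us-west-2")

# Non-streaming: invoke_model returns the complete body at once.
resp = client.invoke_model(
    body=json.dumps({"prompt": "Hello", "maxTokens": 20}),
    modelId="ai21.j2-mid-v1",
    accept="application/json",
    contentType="application/json",
)
print(json.loads(resp["body"].read()))

# Streaming: invoke_model_with_response_stream yields chunk events as they
# arrive, but only for models that support it (ai21 j2 did not at the time).
resp = client.invoke_model_with_response_stream(
    body=json.dumps({"inputText": "Hello"}),
    modelId="amazon.titan-text-express-v1",
    accept="application/json",
    contentType="application/json",
)
for event in resp["body"]:
    print(json.loads(event["chunk"]["bytes"]))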
@@ -1186,6 +1186,11 @@ def completion(
        if "stream" in optional_params and optional_params["stream"] == True:
            # don't try to access stream object,
            if "ai21" in model:
                response = CustomStreamWrapper(
                    model_response, model, custom_llm_provider="bedrock", logging_obj=logging
                )
            else:
                response = CustomStreamWrapper(
                    iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
                )
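Note the asymmetry above: for ai21 the raw response body (bytes) is passed straight to CustomStreamWrapper, while the event streams from other Bedrock providers are wrapped in iter() so they can be consumed with next(). A tiny illustration (not litellm code) of why the bytes must not be wrapped the same way:

# Iterating a list containing the body yields the whole body as one chunk;
# iterating the bytes object itself yields one integer per byte.
body = b'{"completions": [{"data": {"text": "hello"}}]}'

print(next(iter([body])))  # the whole JSON body, as a single chunk
print(next(iter(body)))    # 123 -> the integer value of the first byte, "{"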
@@ -600,6 +600,39 @@ def test_completion_bedrock_claude_stream():

# test_completion_bedrock_claude_stream()

def test_completion_bedrock_ai21_stream():
    try:
        litellm.set_verbose=False
        response = completion(
            model="bedrock/ai21.j2-mid-v1",
            messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}],
            temperature=1,
            max_tokens=20,
            stream=True,
        )
        print(response)
        complete_response = ""
        has_finish_reason = False
        # Add any assertions here to check the response
        for idx, chunk in enumerate(response):
            # print
            chunk, finished = streaming_format_tests(idx, chunk)
            has_finish_reason = finished
            complete_response += chunk
            if finished:
                break
        if has_finish_reason is False:
            raise Exception("finish reason not set for last chunk")
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
    except RateLimitError:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

test_completion_bedrock_ai21_stream()

# def test_completion_sagemaker_stream():
#     try:
#         response = completion(
@@ -5065,14 +5065,22 @@ class CustomStreamWrapper:
            return ""

    def handle_bedrock_stream(self, chunk):
        if hasattr(chunk, "get"):
            chunk = chunk.get('chunk')
            if chunk:
                chunk_data = json.loads(chunk.get('bytes').decode())
        else:
            chunk_data = json.loads(chunk.decode())
        if chunk_data:
            text = ""
            is_finished = False
            finish_reason = ""
            if "outputText" in chunk_data:
                text = chunk_data['outputText']
            # ai21 mapping
            if "ai21" in self.model: # fake ai21 streaming
                text = chunk_data.get('completions')[0].get('data').get('text')
                is_finished = True
                finish_reason = "stop"
            # anthropic mapping
            elif "completion" in chunk_data:
                text = chunk_data['completion'] # bedrock.anthropic
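The ai21 branch above receives the entire (non-streamed) invoke_model body rather than an event chunk. A small illustration of the assumed j2 response shape and the mapping handle_bedrock_stream applies to it (simplified, illustration only):

# Assumed/simplified shape of the raw body returned by invoke_model for an
# AI21 j2 model, and the extraction applied to it above.
import json

raw_body = b'{"completions": [{"data": {"text": " A court case reaches the Supreme Court..."}}]}'

chunk_data = json.loads(raw_body.decode())
text = chunk_data.get("completions")[0].get("data").get("text")
print(text)  # the whole completion, emitted as a single "stream" chunk
# is_finished=True and finish_reason="stop" are set immediately, since this
# single chunk is the entire response.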
@@ -5295,11 +5303,10 @@ class CustomStreamWrapper:
    def __next__(self):
        try:
            while True:
-               if isinstance(self.completion_stream, str):
+               if isinstance(self.completion_stream, str) or isinstance(self.completion_stream, bytes):
                    chunk = self.completion_stream
                else:
                    chunk = next(self.completion_stream)

                if chunk is not None and chunk != b'':
                    response = self.chunk_creator(chunk=chunk)
                    if response is not None: