fix(utils.py): raise StopIteration exception on bedrock stream close

This commit is contained in:
Krrish Dholakia 2023-11-29 16:42:55 -08:00
parent 286ce586be
commit 5411d5a6fd

View file

@@ -2176,7 +2176,7 @@ def get_optional_params( # use the openai defaults
_check_valid_arg(supported_params=supported_params)
elif custom_llm_provider == "bedrock":
if "ai21" in model:
supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"]
supported_params = ["max_tokens", "temperature", "top_p", "stream"]
_check_valid_arg(supported_params=supported_params)
# params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
# https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
@@ -2184,8 +2184,6 @@ def get_optional_params( # use the openai defaults
optional_params["maxTokens"] = max_tokens
if temperature is not None:
optional_params["temperature"] = temperature
if stop is not None:
optional_params["stop_sequences"] = stop
if top_p is not None:
optional_params["topP"] = top_p
if stream:
@@ -5184,10 +5182,13 @@ class CustomStreamWrapper:
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
if self.sent_last_chunk:
raise StopIteration
response_obj = self.handle_bedrock_stream(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
self.sent_last_chunk = True
elif self.custom_llm_provider == "sagemaker":
if len(self.completion_stream)==0:
if self.sent_last_chunk: