feat(factory.py): enable 'user_continue_message' for interweaving user/assistant messages when provider requires it

Allows Bedrock to be used with AutoGen.
Krrish Dholakia 2024-08-22 11:03:33 -07:00
parent 11bfc1dca7
commit 70bf8bd4f4
6 changed files with 54 additions and 8 deletions
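
For context, a minimal usage sketch of the new kwarg (mirroring the test added below; the model ID is taken from the `BEDROCK_CONVERSE_MODELS` list in this diff):

import litellm

# Bedrock's Converse API requires the conversation to start with a user
# turn ("A conversation must start with a user message."). Passing
# `user_continue_message` tells litellm which user message to interweave
# when the provided messages would violate that.
response = litellm.completion(
    model="bedrock/meta.llama3-70b-instruct-v1:0",
    messages=[
        {"role": "assistant", "content": "hey, how's it going?"},
    ],
    user_continue_message={"role": "user", "content": "Be a good bot!"},
)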


@@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"mistral.mistral-large-2407-v1:0",
]
@@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
litellm_params: dict,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"]
## TRANSFORMATION ##
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages,
model=model,
llm_provider="bedrock_converse",
user_continue_message=litellm_params.pop("user_continue_message", None),
)
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if (
@@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
for key in additional_request_keys:
inference_params.pop(key, None)
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages,
model=model,
llm_provider="bedrock_converse",
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
)
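
For context, the Converse API expects `messages` to alternate user/assistant roles, with content given as a list of blocks. A sketch of the shapes involved (illustrative only; the exact blocks the handler emits may differ):

# OpenAI-style input that Converse would reject (assistant speaks first):
openai_messages = [
    {"role": "assistant", "content": "hey, how's it going?"},
    {"role": "user", "content": "good, you?"},
]

# After _bedrock_converse_messages_pt with a user_continue_message, the
# conversation starts with a user turn, in Converse's content-block shape:
bedrock_messages = [
    {"role": "user", "content": [{"text": "Be a good bot!"}]},
    {"role": "assistant", "content": [{"text": "hey, how's it going?"}]},
    {"role": "user", "content": [{"text": "good, you?"}]},
]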


@@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():
BAD_MESSAGE_ERROR_STR = "Invalid Message "
# used to interweave user messages, to ensure user/assistant alternating
DEFAULT_USER_CONTINUE_MESSAGE = {
"role": "user",
"content": "Please continue.",
} # similar to autogen. Only used if `litellm.modify_params=True`.
# used to interweave assistant messages, to ensure user/assistant alternating
DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
"role": "assistant",
"content": "Please continue.",
} # similar to autogen. Only used if `litellm.modify_params=True`.
def map_system_message_pt(messages: list) -> list:
"""
@@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
messages: List,
model: str,
llm_provider: str,
user_continue_message: Optional[dict] = None,
) -> List[BedrockMessageBlock]:
"""
Converts given messages from OpenAI format to Bedrock format
@@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(
contents: List[BedrockMessageBlock] = []
msg_i = 0
# if initial message is assistant message
if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
if user_continue_message is not None:
messages.insert(0, user_continue_message)
elif litellm.modify_params:
messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
# if final message is assistant message
if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
if user_continue_message is not None:
messages.append(user_continue_message)
elif litellm.modify_params:
messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
init_msg_i = msg_i
@@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
model=model,
llm_provider=llm_provider,
)
return contents
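
In plain terms, the new guards give an explicit `user_continue_message` precedence over the `litellm.modify_params` default. A standalone sketch of the same logic (hypothetical helper name; the real function also converts content into Bedrock blocks):

def _ensure_user_bookends(messages, user_continue_message=None, modify_params=False):
    # Explicit kwarg wins; fall back to the library default only when
    # modify_params is enabled; otherwise leave messages untouched.
    continue_msg = user_continue_message or (
        DEFAULT_USER_CONTINUE_MESSAGE if modify_params else None
    )
    if continue_msg is None:
        return messages
    if messages and messages[0]["role"] == "assistant":
        messages.insert(0, continue_msg)  # conversation must start with a user turn
    if messages and messages[-1]["role"] == "assistant":
        messages.append(continue_msg)  # keep the final turn a user message too
    return messages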


@@ -943,6 +943,7 @@ def completion(
output_cost_per_token=output_cost_per_token,
cooldown_time=cooldown_time,
text_completion=kwargs.get("text_completion"),
user_continue_message=kwargs.get("user_continue_message"),
)
logging.update_environment_variables(
model=model,
@@ -2304,7 +2305,7 @@ def completion(
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,


@@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
"temperature": 0.3,
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": "hey, how's it going?"},
{"role": "assistant", "content": "hey, how's it going?"},
],
"user_continue_message": {"role": "user", "content": "Be a good bot!"},
}
response: ModelResponse = completion(
model="bedrock/{}".format(model),


@@ -1116,6 +1116,7 @@ all_litellm_params = [
"cooldown_time",
"cache_key",
"max_retries",
"user_continue_message",
]


@@ -2323,6 +2323,7 @@ def get_litellm_params(
output_cost_per_second=None,
cooldown_time=None,
text_completion=None,
user_continue_message=None,
):
litellm_params = {
"acompletion": acompletion,
@@ -2347,6 +2348,7 @@ def get_litellm_params(
"output_cost_per_second": output_cost_per_second,
"cooldown_time": cooldown_time,
"text_completion": text_completion,
"user_continue_message": user_continue_message,
}
return litellm_params
@@ -7123,6 +7125,14 @@ def exception_type(
llm_provider="bedrock",
response=original_exception.response,
)
elif "A conversation must start with a user message." in error_str:
exception_mapping_worked = True
raise BadRequestError(
message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.",
model=model,
llm_provider="bedrock",
response=original_exception.response,
)
elif (
"Unable to locate credentials" in error_str
or "The security token included in the request is invalid"