From 70bf8bd4f44e65e29cc11fe5da8fd141cd026410 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 22 Aug 2024 11:03:33 -0700
Subject: [PATCH] feat(factory.py): enable 'user_continue_message' for interweaving user/assistant messages when provider requires it

Allows Bedrock to be used with AutoGen.
---
 litellm/llms/bedrock_httpx.py            | 16 ++++++++-----
 litellm/llms/prompt_templates/factory.py | 29 ++++++++++++++++++++++++
 litellm/main.py                          |  3 ++-
 litellm/tests/test_bedrock_completion.py |  3 ++-
 litellm/types/utils.py                   |  1 +
 litellm/utils.py                         | 10 ++++++++
 6 files changed, 54 insertions(+), 8 deletions(-)

diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index e45559752..23e7fdc3e 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
     "meta.llama3-1-8b-instruct-v1:0",
     "meta.llama3-1-70b-instruct-v1:0",
     "meta.llama3-1-405b-instruct-v1:0",
+    "meta.llama3-70b-instruct-v1:0",
     "mistral.mistral-large-2407-v1:0",
 ]
 
@@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         optional_params: dict,
         acompletion: bool,
         timeout: Optional[Union[float, httpx.Timeout]],
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
         ## TRANSFORMATION ##
+
+        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
+            messages=messages,
+            model=model,
+            llm_provider="bedrock_converse",
+            user_continue_message=litellm_params.pop("user_continue_message", None),
+        )
+
         # send all model-specific params in 'additional_request_params'
         for k, v in inference_params.items():
             if (
@@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         for key in additional_request_keys:
             inference_params.pop(key, None)
 
-        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
-            messages=messages,
-            model=model,
-            llm_provider="bedrock_converse",
-        )
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
             inference_params.pop("tools", [])
         )
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index c9e691c00..2b9a7fc24 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():
 
 BAD_MESSAGE_ERROR_STR = "Invalid Message "
 
+# used to interweave user messages, to ensure user/assistant alternating
+DEFAULT_USER_CONTINUE_MESSAGE = {
+    "role": "user",
+    "content": "Please continue.",
+}  # similar to autogen. Only used if `litellm.modify_params=True`.
+
+# used to interweave assistant messages, to ensure user/assistant alternating
+DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
+    "role": "assistant",
+    "content": "Please continue.",
+}  # similar to autogen. Only used if `litellm.modify_params=True`.
+
 
 def map_system_message_pt(messages: list) -> list:
     """
@@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
     messages: List,
     model: str,
     llm_provider: str,
+    user_continue_message: Optional[dict] = None,
 ) -> List[BedrockMessageBlock]:
     """
     Converts given messages from OpenAI format to Bedrock format
@@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(
     contents: List[BedrockMessageBlock] = []
     msg_i = 0
 
+
+    # if initial message is assistant message
+    if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
+        if user_continue_message is not None:
+            messages.insert(0, user_continue_message)
+        elif litellm.modify_params:
+            messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
+
+    # if final message is assistant message
+    if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
+        if user_continue_message is not None:
+            messages.append(user_continue_message)
+        elif litellm.modify_params:
+            messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
+
     while msg_i < len(messages):
         user_content: List[BedrockContentBlock] = []
         init_msg_i = msg_i
@@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
             model=model,
             llm_provider=llm_provider,
         )
+
     return contents
 
 
diff --git a/litellm/main.py b/litellm/main.py
index 1beca0113..28054537c 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -943,6 +943,7 @@ def completion(
             output_cost_per_token=output_cost_per_token,
             cooldown_time=cooldown_time,
             text_completion=kwargs.get("text_completion"),
+            user_continue_message=kwargs.get("user_continue_message"),
         )
         logging.update_environment_variables(
             model=model,
@@ -2304,7 +2305,7 @@ def completion(
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
-                litellm_params=litellm_params,
+                litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index 4892601b1..90592b499 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
         "temperature": 0.3,
         "messages": [
             {"role": "system", "content": system},
-            {"role": "user", "content": "hey, how's it going?"},
+            {"role": "assistant", "content": "hey, how's it going?"},
         ],
+        "user_continue_message": {"role": "user", "content": "Be a good bot!"},
     }
     response: ModelResponse = completion(
         model="bedrock/{}".format(model),
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 8efbe5a11..6b278efa1 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1116,6 +1116,7 @@ all_litellm_params = [
     "cooldown_time",
     "cache_key",
     "max_retries",
+    "user_continue_message",
 ]
 
 
diff --git a/litellm/utils.py b/litellm/utils.py
index f3bb944a8..9c6f0b849 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2323,6 +2323,7 @@ def get_litellm_params(
     output_cost_per_second=None,
     cooldown_time=None,
     text_completion=None,
+    user_continue_message=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2347,6 +2348,7 @@ def get_litellm_params(
         "output_cost_per_second": output_cost_per_second,
         "cooldown_time": cooldown_time,
         "text_completion": text_completion,
+        "user_continue_message": user_continue_message,
     }
 
     return litellm_params
@@ -7123,6 +7125,14 @@ def exception_type(
                     llm_provider="bedrock",
                     response=original_exception.response,
                 )
+            elif "A conversation must start with a user message." in error_str:
+                exception_mapping_worked = True
+                raise BadRequestError(
+                    message=f"BedrockException - {error_str}. Pass in a default user message via `completion(..., user_continue_message=...)` or enable `litellm.modify_params=True`.",
+                    model=model,
+                    llm_provider="bedrock",
+                    response=original_exception.response,
+                )
             elif (
                 "Unable to locate credentials" in error_str
                 or "The security token included in the request is invalid"
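
A minimal usage sketch of the new parameter, for reviewers. This is not part of the diff above; the Bedrock model id is illustrative, and it assumes `litellm.modify_params` defaults to False:

    import litellm
    from litellm import completion

    # Bedrock's Converse API rejects a conversation that begins with an
    # assistant message. Path 1: pass an explicit continue message per
    # call; it is inserted before the assistant turn.
    response = completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "assistant", "content": "hey, how's it going?"}],
        user_continue_message={"role": "user", "content": "Please continue."},
    )

    # Path 2: opt into litellm's defaults, which insert
    # DEFAULT_USER_CONTINUE_MESSAGE ("Please continue.") automatically.
    litellm.modify_params = True
    response = completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "assistant", "content": "hey, how's it going?"}],
    )

Path 1 is what an AutoGen-style caller would use to control the injected text; Path 2 matches the behavior the new test and exception message point users toward.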