Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
feat(factory.py): enable 'user_continue_message' for interweaving user/assistant messages when provider requires it
allows bedrock to be used with autogen
Parent: 11bfc1dca7
Commit: 70bf8bd4f4
6 changed files with 54 additions and 8 deletions
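
For context, a minimal usage sketch of the new option (model name and message contents are illustrative; `user_continue_message` and `litellm.modify_params` are the two knobs this commit wires up):

import litellm
from litellm import completion

# Option 1: pass an explicit continue message so the conversation starts and ends with a user turn
response = completion(
    model="bedrock/meta.llama3-1-8b-instruct-v1:0",
    messages=[{"role": "assistant", "content": "hey, how's it going?"}],  # assistant-first is rejected by Bedrock
    user_continue_message={"role": "user", "content": "Please continue."},
)

# Option 2: opt in to automatic interweaving with the default continue messages
litellm.modify_params = True
response = completion(
    model="bedrock/meta.llama3-1-8b-instruct-v1:0",
    messages=[{"role": "assistant", "content": "hey, how's it going?"}],
)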
@@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
     "meta.llama3-1-8b-instruct-v1:0",
     "meta.llama3-1-70b-instruct-v1:0",
     "meta.llama3-1-405b-instruct-v1:0",
     "meta.llama3-70b-instruct-v1:0",
     "mistral.mistral-large-2407-v1:0",
 ]
@@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         optional_params: dict,
         acompletion: bool,
         timeout: Optional[Union[float, httpx.Timeout]],
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
         ## TRANSFORMATION ##
+
+        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
+            messages=messages,
+            model=model,
+            llm_provider="bedrock_converse",
+            user_continue_message=litellm_params.pop("user_continue_message", None),
+        )
+
         # send all model-specific params in 'additional_request_params'
         for k, v in inference_params.items():
             if (
@@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         for key in additional_request_keys:
             inference_params.pop(key, None)
 
-        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
-            messages=messages,
-            model=model,
-            llm_provider="bedrock_converse",
-        )
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
             inference_params.pop("tools", [])
         )
@@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():
 
 BAD_MESSAGE_ERROR_STR = "Invalid Message "
 
+# used to interweave user messages, to ensure user/assistant alternating
+DEFAULT_USER_CONTINUE_MESSAGE = {
+    "role": "user",
+    "content": "Please continue.",
+}  # similar to autogen. Only used if `litellm.modify_params=True`.
+
+# used to interweave assistant messages, to ensure user/assistant alternating
+DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
+    "role": "assistant",
+    "content": "Please continue.",
+}  # similar to autogen. Only used if `litellm.modify_params=True`.
+
 
 def map_system_message_pt(messages: list) -> list:
     """
@@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
     messages: List,
     model: str,
     llm_provider: str,
+    user_continue_message: Optional[dict] = None,
 ) -> List[BedrockMessageBlock]:
     """
     Converts given messages from OpenAI format to Bedrock format
@@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(
 
     contents: List[BedrockMessageBlock] = []
     msg_i = 0
+
+    # if initial message is assistant message
+    if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
+        if user_continue_message is not None:
+            messages.insert(0, user_continue_message)
+        elif litellm.modify_params:
+            messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
+
+    # if final message is assistant message
+    if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
+        if user_continue_message is not None:
+            messages.append(user_continue_message)
+        elif litellm.modify_params:
+            messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
+
     while msg_i < len(messages):
         user_content: List[BedrockContentBlock] = []
         init_msg_i = msg_i
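
A sketch of what the two new guards do to an out-of-order conversation (assuming `litellm.modify_params=True` and no explicit `user_continue_message`):

# Input: assistant-first and assistant-last, which Bedrock rejects
messages = [
    {"role": "assistant", "content": "Draft answer."},
    {"role": "user", "content": "Refine it."},
    {"role": "assistant", "content": "Refined answer."},
]

# After the guards run, the list is bracketed with DEFAULT_USER_CONTINUE_MESSAGE:
# [
#     {"role": "user", "content": "Please continue."},      # inserted at index 0
#     {"role": "assistant", "content": "Draft answer."},
#     {"role": "user", "content": "Refine it."},
#     {"role": "assistant", "content": "Refined answer."},
#     {"role": "user", "content": "Please continue."},      # appended at the end
# ]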
@@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
                 model=model,
                 llm_provider=llm_provider,
             )
 
     return contents
@@ -943,6 +943,7 @@ def completion(
         output_cost_per_token=output_cost_per_token,
         cooldown_time=cooldown_time,
         text_completion=kwargs.get("text_completion"),
+        user_continue_message=kwargs.get("user_continue_message"),
     )
     logging.update_environment_variables(
         model=model,
@@ -2304,7 +2305,7 @@ def completion(
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
-                litellm_params=litellm_params,
+                litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
@@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
         "temperature": 0.3,
         "messages": [
             {"role": "system", "content": system},
             {"role": "user", "content": "hey, how's it going?"},
+            {"role": "assistant", "content": "hey, how's it going?"},
         ],
+        "user_continue_message": {"role": "user", "content": "Be a good bot!"},
     }
     response: ModelResponse = completion(
         model="bedrock/{}".format(model),
@@ -1116,6 +1116,7 @@ all_litellm_params = [
     "cooldown_time",
     "cache_key",
     "max_retries",
+    "user_continue_message",
 ]
@@ -2323,6 +2323,7 @@ def get_litellm_params(
     output_cost_per_second=None,
     cooldown_time=None,
     text_completion=None,
+    user_continue_message=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2347,6 +2348,7 @@ def get_litellm_params(
         "output_cost_per_second": output_cost_per_second,
         "cooldown_time": cooldown_time,
         "text_completion": text_completion,
+        "user_continue_message": user_continue_message,
     }
 
     return litellm_params
@@ -7123,6 +7125,14 @@ def exception_type(
                     llm_provider="bedrock",
                     response=original_exception.response,
                 )
+            elif "A conversation must start with a user message." in error_str:
+                exception_mapping_worked = True
+                raise BadRequestError(
+                    message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.",
+                    model=model,
+                    llm_provider="bedrock",
+                    response=original_exception.response,
+                )
             elif (
                 "Unable to locate credentials" in error_str
                 or "The security token included in the request is invalid"
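
A small sketch of the error path this new mapping covers (the exception text comes from Bedrock; the remedies are the ones named in the mapped message above):

import litellm
from litellm import completion

try:
    completion(
        model="bedrock/meta.llama3-1-8b-instruct-v1:0",
        messages=[{"role": "assistant", "content": "hello"}],  # no leading user message
    )
except litellm.BadRequestError as err:
    # With this commit the raw Bedrock error is mapped to BadRequestError and the message
    # points at `user_continue_message=` / `litellm.modify_params=True` as fixes.
    print(err)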