mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
feat(factory.py): enable 'user_continue_message' for interweaving user/assistant messages when provider requires it
allows bedrock to be used with autogen
parent 11bfc1dca7
commit 70bf8bd4f4

6 changed files with 54 additions and 8 deletions
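In short: Bedrock's Converse API requires strictly alternating user/assistant turns, so litellm can now interweave a caller-supplied (or default) user message wherever a conversation would otherwise start or end on an assistant turn. A minimal usage sketch, modeled on the test updated below; the model id is illustrative:

import litellm

# Conversation ends on an assistant turn, which Bedrock Converse rejects.
# Passing `user_continue_message` tells litellm which user turn to insert
# so the roles alternate correctly.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # illustrative
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "assistant", "content": "hey, how's it going?"},
    ],
    user_continue_message={"role": "user", "content": "Be a good bot!"},
)
print(response.choices[0].message.content)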
@@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
     "meta.llama3-1-8b-instruct-v1:0",
     "meta.llama3-1-70b-instruct-v1:0",
     "meta.llama3-1-405b-instruct-v1:0",
+    "meta.llama3-70b-instruct-v1:0",
     "mistral.mistral-large-2407-v1:0",
 ]
 
@@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         optional_params: dict,
         acompletion: bool,
         timeout: Optional[Union[float, httpx.Timeout]],
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
         ## TRANSFORMATION ##
+
+        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
+            messages=messages,
+            model=model,
+            llm_provider="bedrock_converse",
+            user_continue_message=litellm_params.pop("user_continue_message", None),
+        )
+
         # send all model-specific params in 'additional_request_params'
         for k, v in inference_params.items():
             if (
@@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         for key in additional_request_keys:
             inference_params.pop(key, None)
 
-        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
-            messages=messages,
-            model=model,
-            llm_provider="bedrock_converse",
-        )
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
             inference_params.pop("tools", [])
         )
@@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():
 
 BAD_MESSAGE_ERROR_STR = "Invalid Message "
 
+# used to interweave user messages, to ensure user/assistant alternating
+DEFAULT_USER_CONTINUE_MESSAGE = {
+    "role": "user",
+    "content": "Please continue.",
+}  # similar to autogen. Only used if `litellm.modify_params=True`.
+
+# used to interweave assistant messages, to ensure user/assistant alternating
+DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
+    "role": "assistant",
+    "content": "Please continue.",
+}  # similar to autogen. Only used if `litellm.modify_params=True`.
+
 
 def map_system_message_pt(messages: list) -> list:
     """
@@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
     messages: List,
     model: str,
     llm_provider: str,
+    user_continue_message: Optional[dict] = None,
 ) -> List[BedrockMessageBlock]:
     """
     Converts given messages from OpenAI format to Bedrock format
@@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(
     contents: List[BedrockMessageBlock] = []
     msg_i = 0
 
+    # if initial message is assistant message
+    if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
+        if user_continue_message is not None:
+            messages.insert(0, user_continue_message)
+        elif litellm.modify_params:
+            messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
+
+    # if final message is assistant message
+    if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
+        if user_continue_message is not None:
+            messages.append(user_continue_message)
+        elif litellm.modify_params:
+            messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
+
     while msg_i < len(messages):
         user_content: List[BedrockContentBlock] = []
         init_msg_i = msg_i
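The two new guards above only repair the ends of the conversation: an explicit user_continue_message always wins, and DEFAULT_USER_CONTINUE_MESSAGE applies only when litellm.modify_params is enabled. A standalone sketch of that logic on plain dicts (not the litellm function itself):

DEFAULT_USER_CONTINUE_MESSAGE = {"role": "user", "content": "Please continue."}

def pad_conversation(messages, user_continue_message=None, modify_params=False):
    # Mirrors the guards above: ensure the conversation starts and ends
    # with a user turn by inserting/appending a continue message.
    filler = user_continue_message or (
        DEFAULT_USER_CONTINUE_MESSAGE if modify_params else None
    )
    if filler is None or not messages:
        return messages
    if messages[0].get("role") == "assistant":
        messages.insert(0, filler)
    if messages[-1].get("role") == "assistant":
        messages.append(filler)
    return messages

# An assistant-only conversation gets padded on both ends:
msgs = [{"role": "assistant", "content": "hey, how's it going?"}]
print(pad_conversation(msgs, modify_params=True))
# -> [{'role': 'user', ...}, {'role': 'assistant', ...}, {'role': 'user', ...}]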
@@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
             model=model,
             llm_provider=llm_provider,
         )
 
     return contents
 
+
@@ -943,6 +943,7 @@ def completion(
         output_cost_per_token=output_cost_per_token,
         cooldown_time=cooldown_time,
         text_completion=kwargs.get("text_completion"),
+        user_continue_message=kwargs.get("user_continue_message"),
     )
     logging.update_environment_variables(
         model=model,
@@ -2304,7 +2305,7 @@ def completion(
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
-                litellm_params=litellm_params,
+                litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
@@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
         "temperature": 0.3,
         "messages": [
             {"role": "system", "content": system},
-            {"role": "user", "content": "hey, how's it going?"},
+            {"role": "assistant", "content": "hey, how's it going?"},
         ],
+        "user_continue_message": {"role": "user", "content": "Be a good bot!"},
     }
     response: ModelResponse = completion(
         model="bedrock/{}".format(model),
@@ -1116,6 +1116,7 @@ all_litellm_params = [
     "cooldown_time",
     "cache_key",
     "max_retries",
+    "user_continue_message",
 ]
 
 
@@ -2323,6 +2323,7 @@ def get_litellm_params(
     output_cost_per_second=None,
     cooldown_time=None,
     text_completion=None,
+    user_continue_message=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2347,6 +2348,7 @@ def get_litellm_params(
         "output_cost_per_second": output_cost_per_second,
         "cooldown_time": cooldown_time,
         "text_completion": text_completion,
+        "user_continue_message": user_continue_message,
     }
 
     return litellm_params
@@ -7123,6 +7125,14 @@ def exception_type(
                     llm_provider="bedrock",
                     response=original_exception.response,
                 )
+            elif "A conversation must start with a user message." in error_str:
+                exception_mapping_worked = True
+                raise BadRequestError(
+                    message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.",
+                    model=model,
+                    llm_provider="bedrock",
+                    response=original_exception.response,
+                )
             elif (
                 "Unable to locate credentials" in error_str
                 or "The security token included in the request is invalid"
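Without either remedy, Bedrock rejects the request server-side; the new mapping in exception_type turns that validation error into a BadRequestError whose message names both fixes. A hedged sketch of what a caller would see (model id illustrative, default settings assumed):

import litellm

try:
    litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # illustrative
        messages=[{"role": "assistant", "content": "hello"}],  # assistant-first
    )
except litellm.BadRequestError as err:
    # Mapped message suggests: pass `user_continue_message=...` to completion()
    # or set `litellm.modify_params = True`.
    print(err)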