feat: support tool_choice = {required, none, <function>} (#1059)
Summary: as titled.

Test Plan: added tests and ran
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
Parent: 37cf60b732 · Commit: 8de7cf103b
7 changed files with 164 additions and 41 deletions
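For context on the new surface: tool_choice now accepts "required", "none", or a specific function name, passed through tool_config. A minimal caller-side sketch, assuming an existing client and a registered weather tool (the client, model id, and tool definition are illustrative assumptions; the tool_config shape matches the tests in this diff):

# Sketch, not from this diff: "client", the model id, and
# "get_weather_tool_definition" are assumed to exist already.
for tool_choice in ("required", "none", "get_weather"):
    response = client.inference.chat_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
        tools=[get_weather_tool_definition],
        tool_config={"tool_choice": tool_choice},
    )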
@@ -98,7 +98,6 @@ def agent_config(llama_stack_client, text_model_id):
             },
         },
         toolgroups=[],
-        tool_choice="auto",
         input_shields=available_shields,
         output_shields=available_shields,
         enable_session_persistence=False,
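Note the hunk above only deletes the fixture's top-level tool_choice="auto"; nothing replaces it there. Per the new test below, tool choice is now set per case through tool_config, e.g. (the values come from this diff; that an unset tool_config falls back to "auto" is an assumption):

# Each test case overrides the shared fixture like this; when tool_config
# is left unset, the server default presumably remains "auto".
agent_config["tool_config"] = {"tool_choice": "required"}  # or "none", or a tool name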
@@ -322,6 +321,38 @@ def test_custom_tool(llama_stack_client, agent_config):
     assert "get_boiling_point" in logs_str


+def test_tool_choice(llama_stack_client, agent_config):
+    data = [
+        ("required", '{"type": "function"'),
+        ("none", None),
+        ("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'),
+    ]
+    client_tool = TestClientTool()
+    for tool_choice, expected_tool in data:
+        agent_config["tool_config"] = {"tool_choice": tool_choice}
+        agent_config["client_tools"] = [client_tool.get_tool_definition()]
+
+        agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+        session_id = agent.create_session(f"test-session-{uuid4()}")
+
+        response = agent.create_turn(
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What is the boiling point of polyjuice?",
+                },
+            ],
+            session_id=session_id,
+        )
+
+        logs = [str(log) for log in EventLogger().log(response) if log is not None]
+        logs_str = "".join(logs)
+        if expected_tool:
+            assert expected_tool in logs_str
+        else:
+            assert '{"type": "function"' not in logs_str
+
+
 # TODO: fix this flaky test
 def xtest_override_system_message_behavior(llama_stack_client, agent_config):
     client_tool = TestClientTool()
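To run just this new agent test, the Test Plan's command can be narrowed with pytest's standard -k filter (the tests/client-sdk/ path comes from the Test Plan above; the filter itself is a suggestion, not part of this commit):

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ -k test_tool_choice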
@@ -247,6 +247,42 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
     assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"


+def test_text_chat_completion_with_tool_choice_required(
+    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format, inference_provider_type
+):
+    if inference_provider_type == "remote::vllm":
+        pytest.xfail("vllm-project/vllm#13002")
+    response = llama_stack_client.inference.chat_completion(
+        model_id=text_model_id,
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What's the weather like in San Francisco?"},
+        ],
+        tools=[get_weather_tool_definition],
+        tool_config={"tool_choice": "required", "tool_prompt_format": provider_tool_format},
+        stream=True,
+    )
+    tool_invocation_content = extract_tool_invocation_content(response)
+    assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
+
+
+def test_text_chat_completion_with_tool_choice_none(
+    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+):
+    response = llama_stack_client.inference.chat_completion(
+        model_id=text_model_id,
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What's the weather like in San Francisco?"},
+        ],
+        tools=[get_weather_tool_definition],
+        tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
+        stream=True,
+    )
+    tool_invocation_content = extract_tool_invocation_content(response)
+    assert tool_invocation_content == ""
+
+
 def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type):
     class AnswerFormat(BaseModel):
         first_name: str
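These two inference-level tests cover "required" and "none"; the third mode from the commit title, a specific function name, is exercised only by the agent-level test_tool_choice above. A hedged sketch of the equivalent inference call (mirroring the test shape exactly; whether every provider honors a bare function name at this layer is an assumption, not verified by this diff):

# Sketch only: same arguments as the tests above, with tool_choice set to a
# concrete tool name rather than "required"/"none".
response = llama_stack_client.inference.chat_completion(
    model_id=text_model_id,
    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
    tools=[get_weather_tool_definition],
    tool_config={"tool_choice": "get_weather", "tool_prompt_format": provider_tool_format},
    stream=True,
)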