feat: support tool_choice = {required, none, <function>} (#1059)

Summary:

titled


Test Plan:

added tests and

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/
--safety-shield meta-llama/Llama-Guard-3-8B
This commit is contained in:
ehhuang 2025-02-18 20:25:15 -08:00 committed by GitHub
parent 37cf60b732
commit 8de7cf103b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 164 additions and 41 deletions

View file

@ -98,7 +98,6 @@ def agent_config(llama_stack_client, text_model_id):
},
},
toolgroups=[],
tool_choice="auto",
input_shields=available_shields,
output_shields=available_shields,
enable_session_persistence=False,
@ -322,6 +321,38 @@ def test_custom_tool(llama_stack_client, agent_config):
assert "get_boiling_point" in logs_str
def test_tool_choice(llama_stack_client, agent_config):
data = [
("required", '{"type": "function"'),
("none", None),
("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'),
]
client_tool = TestClientTool()
for tool_choice, expected_tool in data:
agent_config["tool_config"] = {"tool_choice": tool_choice}
agent_config["client_tools"] = [client_tool.get_tool_definition()]
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[
{
"role": "user",
"content": "What is the boiling point of polyjuice?",
},
],
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
if expected_tool:
assert expected_tool in logs_str
else:
assert '{"type": "function"' not in logs_str
# TODO: fix this flaky test
def xtest_override_system_message_behavior(llama_stack_client, agent_config):
client_tool = TestClientTool()

View file

@ -247,6 +247,42 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
def test_text_chat_completion_with_tool_choice_required(
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format, inference_provider_type
):
if inference_provider_type == "remote::vllm":
pytest.xfail("vllm-project/vllm#13002")
response = llama_stack_client.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
tool_config={"tool_choice": "required", "tool_prompt_format": provider_tool_format},
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
def test_text_chat_completion_with_tool_choice_none(
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
):
response = llama_stack_client.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
assert tool_invocation_content == ""
def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type):
class AnswerFormat(BaseModel):
first_name: str