feat: support tool_choice = {required, none, <function>} (#1059)

Summary:

As titled: add support for tool_choice values "required", "none", and a specific function name in chat completion requests.

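For illustration, a minimal sketch of how a client might exercise each of the new tool_choice values. It is modeled on the tests in the diff below; the names (llama_stack_client, text_model_id, get_weather_tool_definition) are fixtures from that test file, and passing a bare tool name as tool_choice is an assumption about the new API surface, not something confirmed by this excerpt:

# Sketch only: the client and tool names are borrowed from the tests below,
# and the tool-name-as-tool_choice shape is an assumption.
for tool_choice in ("required", "none", "get_weather"):
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
        tools=[get_weather_tool_definition],
        # "required" forces a tool call, "none" forbids tool calls, and a
        # specific name pins the model to that one function.
        tool_config={"tool_choice": tool_choice},
    )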

Test Plan:

Added tests, then ran:

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ \
    --safety-shield meta-llama/Llama-Guard-3-8B
ehhuang 2025-02-18 20:25:15 -08:00 committed by GitHub
parent 37cf60b732
commit 8de7cf103b
7 changed files with 164 additions and 41 deletions


@@ -247,6 +247,42 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
    assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"


def test_text_chat_completion_with_tool_choice_required(
    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format, inference_provider_type
):
    if inference_provider_type == "remote::vllm":
        # vLLM does not yet support tool_choice="required".
        pytest.xfail("vllm-project/vllm#13002")
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather like in San Francisco?"},
        ],
        tools=[get_weather_tool_definition],
        tool_config={"tool_choice": "required", "tool_prompt_format": provider_tool_format},
        stream=True,
    )
    tool_invocation_content = extract_tool_invocation_content(response)
    assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"


def test_text_chat_completion_with_tool_choice_none(
    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
):
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather like in San Francisco?"},
        ],
        tools=[get_weather_tool_definition],
        tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
        stream=True,
    )
    tool_invocation_content = extract_tool_invocation_content(response)
    # With tool_choice "none" the model must not emit any tool call.
    assert tool_invocation_content == ""
def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type):
    class AnswerFormat(BaseModel):
        first_name: str
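The function-name variant from the title is not covered by the tests shown in this excerpt. A hypothetical test in the same style might look like the following; the tool-name-as-tool_choice value is an assumed shape, mirrored from the two tests above, not code from this commit:

# Hypothetical test, modeled on the tool_choice tests above; passing the tool
# name directly as tool_choice is an assumption about the new API.
def test_text_chat_completion_with_tool_choice_function(
    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
):
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather like in San Francisco?"},
        ],
        tools=[get_weather_tool_definition],
        # Pin the model to get_weather instead of "auto"/"required"/"none".
        tool_config={"tool_choice": "get_weather", "tool_prompt_format": provider_tool_format},
        stream=True,
    )
    tool_invocation_content = extract_tool_invocation_content(response)
    assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"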