mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
feat: support tool_choice = {required, none, <function>} (#1059)
Summary: titled Test Plan: added tests and LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
This commit is contained in:
parent
37cf60b732
commit
8de7cf103b
7 changed files with 164 additions and 41 deletions
|
@ -247,6 +247,42 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
|
|||
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
|
||||
|
||||
|
||||
def test_text_chat_completion_with_tool_choice_required(
|
||||
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format, inference_provider_type
|
||||
):
|
||||
if inference_provider_type == "remote::vllm":
|
||||
pytest.xfail("vllm-project/vllm#13002")
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather like in San Francisco?"},
|
||||
],
|
||||
tools=[get_weather_tool_definition],
|
||||
tool_config={"tool_choice": "required", "tool_prompt_format": provider_tool_format},
|
||||
stream=True,
|
||||
)
|
||||
tool_invocation_content = extract_tool_invocation_content(response)
|
||||
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
|
||||
|
||||
|
||||
def test_text_chat_completion_with_tool_choice_none(
|
||||
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather like in San Francisco?"},
|
||||
],
|
||||
tools=[get_weather_tool_definition],
|
||||
tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
|
||||
stream=True,
|
||||
)
|
||||
tool_invocation_content = extract_tool_invocation_content(response)
|
||||
assert tool_invocation_content == ""
|
||||
|
||||
|
||||
def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type):
|
||||
class AnswerFormat(BaseModel):
|
||||
first_name: str
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue