feat: support tool_choice = {required, none, <function>} (#1059)

Summary:

As titled.

Test Plan:

Added tests and ran:

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/
--safety-shield meta-llama/Llama-Guard-3-8B
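
For reference, a minimal client-side sketch of the three new tool_choice values. The client call shape, tool definition, and model id below are illustrative assumptions for this sketch, not taken from this commit:

# Hypothetical usage sketch -- model id and tool definition are placeholders.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

weather_tool = {
    "tool_name": "get_weather",
    "description": "Get the current weather for a city",
    "parameters": {
        "city": {"param_type": "string", "description": "City name", "required": True},
    },
}

# "required": the model must call one of the provided tools.
# "none": tool calling is disabled; the router drops the tools list entirely.
# "get_weather": the named function must be among the provided tools.
for choice in ("required", "none", "get_weather"):
    response = client.inference.chat_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "What is the weather in Paris?"}],
        tools=[weather_tool],
        tool_choice=choice,
    )
    print(choice, response.completion_message.tool_calls)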
ehhuang 2025-02-18 20:25:15 -08:00 committed by GitHub
parent 37cf60b732
commit 8de7cf103b
7 changed files with 164 additions and 41 deletions

@@ -128,7 +128,7 @@ class InferenceRouter(Inference):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -140,20 +140,36 @@ class InferenceRouter(Inference):
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
-            if tool_choice != tool_config.tool_choice:
+            if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format != tool_config.tool_prompt_format:
+            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
                 raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
         else:
-            tool_config = ToolConfig(
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
-            )
+            params = {}
+            if tool_choice:
+                params["tool_choice"] = tool_choice
+            if tool_prompt_format:
+                params["tool_prompt_format"] = tool_prompt_format
+            tool_config = ToolConfig(**params)
+
+        tools = tools or []
+        if tool_config.tool_choice == ToolChoice.none:
+            tools = []
+        elif tool_config.tool_choice == ToolChoice.auto:
+            pass
+        elif tool_config.tool_choice == ToolChoice.required:
+            pass
+        else:
+            # verify tool_choice is one of the tools
+            tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+            if tool_config.tool_choice not in tool_names:
+                raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+
         params = dict(
             model_id=model_id,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools or [],
+            tools=tools,
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             response_format=response_format,
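
Stated plainly, the router now resolves tool_choice before dispatching to the provider: "none" strips the tools list, "auto" and "required" pass it through, and any other value is treated as a function name that must be among the provided tools. A self-contained sketch of that rule, simplified to plain strings instead of the ToolChoice enum (names here are illustrative, not part of the diff):

# Simplified, standalone restatement of the resolution logic above.
from typing import List, Optional


def resolve_tools(tool_choice: Optional[str], tool_names: List[str]) -> List[str]:
    tool_choice = tool_choice or "auto"
    if tool_choice == "none":
        # Tool calling disabled: the provider sees no tools at all.
        return []
    if tool_choice in ("auto", "required"):
        # Pass the tools through; the provider decides (auto) or is forced
        # to call one of them (required).
        return list(tool_names)
    # Any other value names a specific function, which must have been offered.
    if tool_choice not in tool_names:
        raise ValueError(f"Tool choice {tool_choice} is not one of the tools: {tool_names}")
    return list(tool_names)


assert resolve_tools("none", ["get_weather"]) == []
assert resolve_tools("required", ["get_weather"]) == ["get_weather"]
assert resolve_tools("get_weather", ["get_weather"]) == ["get_weather"]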