forked from phoenix-oss/llama-stack-mirror
feat: support tool_choice = {required, none, <function>} (#1059)
Summary: titled Test Plan: added tests and LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
This commit is contained in:
parent
37cf60b732
commit
8de7cf103b
7 changed files with 164 additions and 41 deletions
|
@ -128,7 +128,7 @@ class InferenceRouter(Inference):
|
|||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_choice: Optional[ToolChoice] = None,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
@ -140,20 +140,36 @@ class InferenceRouter(Inference):
|
|||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
if tool_config:
|
||||
if tool_choice != tool_config.tool_choice:
|
||||
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||
if tool_prompt_format != tool_config.tool_prompt_format:
|
||||
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
|
||||
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
|
||||
else:
|
||||
tool_config = ToolConfig(
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
params = {}
|
||||
if tool_choice:
|
||||
params["tool_choice"] = tool_choice
|
||||
if tool_prompt_format:
|
||||
params["tool_prompt_format"] = tool_prompt_format
|
||||
tool_config = ToolConfig(**params)
|
||||
|
||||
tools = tools or []
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
elif tool_config.tool_choice == ToolChoice.auto:
|
||||
pass
|
||||
elif tool_config.tool_choice == ToolChoice.required:
|
||||
pass
|
||||
else:
|
||||
# verify tool_choice is one of the tools
|
||||
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
|
||||
if tool_config.tool_choice not in tool_names:
|
||||
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
|
||||
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue