diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 85ae516f5..b22ce5914 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -3161,6 +3161,43 @@
"job_uuid"
]
},
+ "ToolConfig": {
+ "type": "object",
+ "properties": {
+ "tool_choice": {
+ "type": "string",
+ "enum": [
+ "auto",
+ "required"
+ ],
+ "default": "auto",
+ "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto."
+ },
+ "tool_prompt_format": {
+ "type": "string",
+ "enum": [
+ "json",
+ "function_tag",
+ "python_list"
+ ],
+ "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
+ },
+ "system_message_behavior": {
+ "type": "string",
+ "enum": [
+ "append",
+ "replace"
+ ],
+ "default": "append",
+ "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted."
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "system_message_behavior"
+ ],
+ "title": "Configuration for tool use."
+ },
"ChatCompletionRequest": {
"type": "object",
"properties": {
@@ -3192,7 +3229,7 @@
"auto",
"required"
],
- "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto."
+ "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead."
},
"tool_prompt_format": {
"type": "string",
@@ -3201,7 +3238,7 @@
"function_tag",
"python_list"
],
- "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
+ "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. .. deprecated:: Use tool_config instead."
},
"response_format": {
"$ref": "#/components/schemas/ResponseFormat",
@@ -3222,6 +3259,10 @@
},
"additionalProperties": false,
"description": "(Optional) If specified, log probabilities for each token position will be returned."
+ },
+ "tool_config": {
+ "$ref": "#/components/schemas/ToolConfig",
+ "description": "(Optional) Configuration for tool use."
}
},
"additionalProperties": false,
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 2a95acf38..c49c948de 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -1956,6 +1956,46 @@ components:
additionalProperties: false
required:
- job_uuid
+ ToolConfig:
+ type: object
+ properties:
+ tool_choice:
+ type: string
+ enum:
+ - auto
+ - required
+ default: auto
+ description: >-
+ (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+ tool_prompt_format:
+ type: string
+ enum:
+ - json
+ - function_tag
+ - python_list
+ description: >-
+ (Optional) Instructs the model how to format tool calls. By default, Llama
+ Stack will attempt to use a format that is best adapted to the model.
+ - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
+ - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a
+ tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
+ syntax -- a list of function calls.
+ system_message_behavior:
+ type: string
+ enum:
+ - append
+ - replace
+ default: append
+ description: >-
+ (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
+ Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
+ Replaces the default system prompt with the provided system message. The
+ system message can include the string '{{function_definitions}}' to indicate
+ where the function definitions should be inserted.
+ additionalProperties: false
+ required:
+ - system_message_behavior
+ title: Configuration for tool use.
ChatCompletionRequest:
type: object
properties:
@@ -1986,6 +2026,7 @@ components:
- required
description: >-
(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+ .. deprecated:: Use tool_config instead.
tool_prompt_format:
type: string
enum:
@@ -1998,7 +2039,7 @@ components:
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
- syntax -- a list of function calls.
+ syntax -- a list of function calls. .. deprecated:: Use tool_config instead.
response_format:
$ref: '#/components/schemas/ResponseFormat'
description: >-
@@ -2024,6 +2065,9 @@ components:
description: >-
(Optional) If specified, log probabilities for each token position will
be returned.
+ tool_config:
+ $ref: '#/components/schemas/ToolConfig'
+ description: (Optional) Configuration for tool use.
additionalProperties: false
required:
- model_id
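
For orientation, a request body using the new tool_config block might look like the sketch below. The field names and enum values follow the ToolConfig schema added above; the model identifier and message contents are illustrative assumptions, not part of this change.

# Hypothetical JSON body for the chat-completion endpoint, expressed as a Python dict.
payload = {
    "model_id": "meta-llama/Llama-3.2-3B-Instruct",  # assumed model identifier
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Tokyo?"},
    ],
    "tool_config": {
        "tool_choice": "auto",                 # optional; defaults to "auto"
        "tool_prompt_format": "python_list",   # optional; omit to let the stack pick a format
        "system_message_behavior": "append",   # "append" (default) or "replace"
    },
}
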
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 6398f74e8..d4f13d65c 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -308,14 +308,49 @@ class CompletionResponseStreamChunk(BaseModel):
logprobs: Optional[List[TokenLogProbs]] = None
+class SystemMessageBehavior(Enum):
+ """Config for how to override the default system prompt.
+
+ :cvar append: Appends the provided system message to the default system prompt:
+ https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
+ :cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
+ '{{function_definitions}}' to indicate where the function definitions should be inserted.
+ """
+
+ append = "append"
+ replace = "replace"
+
+
+@json_schema_type
+class ToolConfig(BaseModel):
+ """Configuration for tool use.
+
+ :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+ :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
+ - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
+ - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag.
+ - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+ :param system_message_behavior: (Optional) Config for how to override the default system prompt.
+ - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
+ - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
+ '{{function_definitions}}' to indicate where the function definitions should be inserted.
+ """
+
+ tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+ tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+ system_message_behavior: SystemMessageBehavior = Field(default=SystemMessageBehavior.append)
+
+
# This is an internally used class
+@json_schema_type
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()
+
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
- tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
- tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+ tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
+
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@@ -404,6 +439,7 @@ class Inference(Protocol):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
"""Generate a chat completion for the given messages using the specified model.
@@ -412,15 +448,20 @@ class Inference(Protocol):
:param sampling_params: Parameters to control the sampling strategy
:param tools: (Optional) List of tool definitions available to the model
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+ .. deprecated::
+ Use tool_config instead.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+ .. deprecated::
+ Use tool_config instead.
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+ :param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
"""
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index c5a7e3af6..6cddcf73c 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
+ ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -132,12 +133,23 @@ class InferenceRouter(Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+ if tool_config:
+ if tool_choice != tool_config.tool_choice:
+ raise ValueError("tool_choice and tool_config.tool_choice must match")
+ if tool_prompt_format != tool_config.tool_prompt_format:
+ raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+ else:
+ tool_config = ToolConfig(
+ tool_choice=tool_choice,
+ tool_prompt_format=tool_prompt_format,
+ )
params = dict(
model_id=model_id,
messages=messages,
@@ -148,6 +160,7 @@ class InferenceRouter(Inference):
response_format=response_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
if stream:
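
Legacy callers that still pass tool_choice / tool_prompt_format keep working, because the router folds those arguments into a ToolConfig before dispatching to the provider. Restated as a standalone helper (the function name is illustrative), the resolution above is roughly:

from typing import Optional

from llama_stack.apis.inference import ToolChoice, ToolConfig, ToolPromptFormat

def resolve_tool_config(
    tool_choice: Optional[ToolChoice],
    tool_prompt_format: Optional[ToolPromptFormat],
    tool_config: Optional[ToolConfig],
) -> ToolConfig:
    # Mirrors InferenceRouter.chat_completion above: an explicit ToolConfig wins,
    # but it must agree with any legacy arguments that were also supplied.
    if tool_config:
        if tool_choice != tool_config.tool_choice:
            raise ValueError("tool_choice and tool_config.tool_choice must match")
        if tool_prompt_format != tool_config.tool_prompt_format:
            raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
        return tool_config
    # Otherwise build a ToolConfig from the legacy arguments (either may be None).
    return ToolConfig(tool_choice=tool_choice, tool_prompt_format=tool_prompt_format)
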
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 4048972df..51c10b0a8 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -400,7 +400,7 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
request.messages,
- request.tool_prompt_format,
+ request.tool_config.tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 7e3508148..3caf4e2a5 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
TokenLogProbs,
ToolChoice,
+ ToolConfig,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -252,6 +253,7 @@ class MetaReferenceInferenceImpl(
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
@@ -262,11 +264,10 @@ class MetaReferenceInferenceImpl(
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
self.check_model(request)
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 3920ee1ad..d34befbd9 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
ToolChoice,
ToolDefinition,
ToolPromptFormat,
+ ToolConfig,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -71,5 +72,6 @@ class SentenceTransformersInferenceImpl(
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 6f35d0c59..691737c15 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -30,6 +30,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
+ ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -159,6 +160,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
assert self.engine is not None
@@ -167,10 +169,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
log.info("Sampling params: %s", sampling_params)
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index c1297d022..03a0a40c3 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
+ ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -102,6 +103,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -109,11 +111,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
if stream:
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index eb77741e0..bd12c56c8 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
+ ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -128,6 +129,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -140,6 +142,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
response_format=response_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
if stream:
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 2ed3618c5..37070b4ce 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -89,16 +89,16 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index af3a7fce5..d47c035b8 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -25,6 +25,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SamplingParams,
ToolChoice,
+ ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -204,6 +205,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -211,11 +213,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
if stream:
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index f0220f1c1..461b3ee61 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -99,6 +99,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
model_id = self.get_provider_model_id(model_id)
if model_id == "llama-3.2-3b-preview":
@@ -115,10 +116,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
)
diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index acb359a1c..537043d69 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -79,7 +79,7 @@ def convert_chat_completion_request(
# so we exclude it for now
warnings.warn("repetition_penalty is not supported")
- if request.tool_prompt_format != ToolPromptFormat.json:
+ if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
sampling_options = get_sampling_strategy_options(request.sampling_params)
@@ -93,7 +93,7 @@ def convert_chat_completion_request(
temperature=sampling_options.get("temperature", 1.0),
top_p=sampling_options.get("top_p", 1.0),
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
- tool_choice=request.tool_choice.value if request.tool_choice else None,
+ tool_choice=(request.tool_config.tool_choice.value if request.tool_config.tool_choice else None),
)
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 0bbfe58b4..69533491e 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -171,6 +171,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
@@ -184,10 +185,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
),
n=1,
)
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index 623d36aa0..0a62a2ab4 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -282,9 +282,9 @@ async def convert_chat_completion_request(
if request.tools:
payload.update(tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools])
- if request.tool_choice:
+ if request.tool_config.tool_choice:
payload.update(
- tool_choice=request.tool_choice.value
+ tool_choice=request.tool_config.tool_choice.value
) # we cannot include tool_choice w/o tools, server will complain
if request.logprobs:
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index d6380cd6f..cff8aa742 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
+ ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -224,6 +225,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -231,11 +233,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
response_format=response_format,
+ tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index f6209258d..a62b0c97f 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -83,10 +83,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 6ffbff384..dd697cd62 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -125,10 +125,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
request_sambanova = await self.convert_chat_completion_request(request)
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 1ce7ab5eb..2281319b3 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -26,6 +26,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SamplingParams,
ToolChoice,
+ ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -205,6 +206,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -212,11 +214,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
if stream:
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 0b965c861..cf24daf60 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SamplingParams,
ToolChoice,
+ ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -194,6 +195,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -201,11 +203,10 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
+ tool_config=tool_config,
)
if stream:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 9d2d92279..bd3375baf 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -27,6 +27,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SamplingParams,
ToolChoice,
+ ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -119,6 +120,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -126,11 +128,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
- tool_choice=tool_choice,
- tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
response_format=response_format,
+ tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request, self.client)
diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
index 8a8a63b30..a28dd308e 100644
--- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py
+++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
@@ -179,7 +179,7 @@ class TestConvertChatCompletionRequest:
def test_includes_tool_choice(self):
request = self._dummy_chat_completion_request()
- request.tool_choice = ToolChoice.required
+ request.tool_config.tool_choice = ToolChoice.required
converted = convert_chat_completion_request(request)
diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py
index 4826e89d5..c087c5df2 100644
--- a/llama_stack/providers/tests/inference/test_prompt_adapter.py
+++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py
@@ -13,12 +13,18 @@ from llama_models.llama3.api.datatypes import (
ToolPromptFormat,
)
-from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
+from llama_stack.apis.inference import (
+ ChatCompletionRequest,
+ SystemMessage,
+ ToolConfig,
+ UserMessage,
+)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
MODEL = "Llama3.1-8B-Instruct"
+MODEL3_2 = "Llama3.2-3B-Instruct"
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
@@ -73,7 +79,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
},
)
],
- tool_prompt_format=ToolPromptFormat.json,
+ tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)
@@ -132,3 +138,101 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
self.assertTrue(messages[0].content.endswith(system_prompt))
self.assertEqual(messages[-1].content, content)
+
+    async def test_replace_system_message_behavior_builtin_tools(self):
+ content = "Hello !"
+ system_prompt = "You are a pirate"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ SystemMessage(content=system_prompt),
+ UserMessage(content=content),
+ ],
+ tools=[
+ ToolDefinition(tool_name=BuiltinTool.code_interpreter),
+ ],
+ tool_config=ToolConfig(
+ tool_choice="auto",
+ tool_prompt_format="python_list",
+ system_message_behavior="replace",
+ ),
+ )
+ messages = chat_completion_request_to_messages(request, MODEL3_2)
+ self.assertEqual(len(messages), 2, messages)
+ self.assertTrue(messages[0].content.endswith(system_prompt))
+ self.assertIn("Environment: ipython", messages[0].content)
+ self.assertEqual(messages[-1].content, content)
+
+    async def test_replace_system_message_behavior_custom_tools(self):
+ content = "Hello !"
+ system_prompt = "You are a pirate"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ SystemMessage(content=system_prompt),
+ UserMessage(content=content),
+ ],
+ tools=[
+ ToolDefinition(tool_name=BuiltinTool.code_interpreter),
+ ToolDefinition(
+ tool_name="custom1",
+ description="custom1 tool",
+ parameters={
+ "param1": ToolParamDefinition(
+ param_type="str",
+ description="param1 description",
+ required=True,
+ ),
+ },
+ ),
+ ],
+ tool_config=ToolConfig(
+ tool_choice="auto",
+ tool_prompt_format="python_list",
+ system_message_behavior="replace",
+ ),
+ )
+ messages = chat_completion_request_to_messages(request, MODEL3_2)
+
+ self.assertEqual(len(messages), 2, messages)
+ self.assertTrue(messages[0].content.endswith(system_prompt))
+ self.assertIn("Environment: ipython", messages[0].content)
+ self.assertEqual(messages[-1].content, content)
+
+ async def test_replace_system_message_behavior_custom_tools_with_template(self):
+ content = "Hello !"
+ system_prompt = "You are a pirate {{ function_description }}"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ SystemMessage(content=system_prompt),
+ UserMessage(content=content),
+ ],
+ tools=[
+ ToolDefinition(tool_name=BuiltinTool.code_interpreter),
+ ToolDefinition(
+ tool_name="custom1",
+ description="custom1 tool",
+ parameters={
+ "param1": ToolParamDefinition(
+ param_type="str",
+ description="param1 description",
+ required=True,
+ ),
+ },
+ ),
+ ],
+ tool_config=ToolConfig(
+ tool_choice="auto",
+ tool_prompt_format="python_list",
+ system_message_behavior="replace",
+ ),
+ )
+ messages = chat_completion_request_to_messages(request, MODEL3_2)
+
+ self.assertEqual(len(messages), 2, messages)
+ self.assertIn("Environment: ipython", messages[0].content)
+ self.assertIn("You are a pirate", messages[0].content)
+ # function description is present in the system prompt
+ self.assertIn('"name": "custom1"', messages[0].content)
+ self.assertEqual(messages[-1].content, content)
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 49c6ac7a9..57875e64b 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
SystemMessage,
ToolChoice,
UserMessage,
+ SystemMessageBehavior,
)
from llama_stack.providers.utils.inference import supported_inference_models
@@ -309,7 +310,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> List[Message]:
- assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
+ assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
@@ -354,7 +355,7 @@ def augment_messages_for_tools_llama_3_1(
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
- fmt = request.tool_prompt_format or ToolPromptFormat.json
+ fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif fmt == ToolPromptFormat.function_tag:
@@ -375,7 +376,7 @@ def augment_messages_for_tools_llama_3_1(
def augment_messages_for_tools_llama_3_2(
request: ChatCompletionRequest,
) -> List[Message]:
- assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
+ assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
@@ -403,20 +404,25 @@ def augment_messages_for_tools_llama_3_2(
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
- fmt = request.tool_prompt_format or ToolPromptFormat.python_list
+ fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
- raise ValueError(f"Non supported ToolPromptFormat {request.tool_prompt_format}")
+ raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
- tool_gen = PythonListCustomToolGenerator()
- tool_template = tool_gen.gen(custom_tools)
+ system_prompt = None
+ if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
+ system_prompt = existing_system_message.content
+
+ tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"
- if existing_system_message:
+ if existing_system_message and (
+ request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
+ ):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
- messages.append(SystemMessage(content=sys_content))
+ messages.append(SystemMessage(content=sys_content.strip("\n")))
# Add back existing messages from the request
messages += existing_messages
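
To make the new system_message_behavior concrete for Llama 3.2 models: with the default append, the caller's system message is added after the generated tool preamble; with replace (and custom tools present), it takes the place of the default system prompt and may embed the function definitions via the template placeholder. A minimal sketch in the spirit of the tests above follows; the model name, tool definition, and import locations are assumptions based on those tests.

from llama_models.llama3.api.datatypes import ToolParamDefinition

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    SystemMessage,
    SystemMessageBehavior,
    ToolConfig,
    ToolDefinition,
    UserMessage,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

request = ChatCompletionRequest(
    model="Llama3.2-3B-Instruct",
    messages=[
        SystemMessage(content="You are a pirate"),
        UserMessage(content="What's the weather in Tokyo?"),
    ],
    tools=[
        ToolDefinition(
            tool_name="get_weather",
            description="Get the weather for a city",
            parameters={
                "city": ToolParamDefinition(
                    param_type="str",
                    description="City name",
                    required=True,
                ),
            },
        ),
    ],
    tool_config=ToolConfig(system_message_behavior=SystemMessageBehavior.replace),
)

messages = chat_completion_request_to_messages(request, "Llama3.2-3B-Instruct")
# With `replace`, the resulting system message is built around "You are a pirate"
# rather than appending it after the default function-calling preamble.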