diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 85ae516f5..b22ce5914 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3161,6 +3161,43 @@ "job_uuid" ] }, + "ToolConfig": { + "type": "object", + "properties": { + "tool_choice": { + "type": "string", + "enum": [ + "auto", + "required" + ], + "default": "auto", + "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto." + }, + "tool_prompt_format": { + "type": "string", + "enum": [ + "json", + "function_tag", + "python_list" + ], + "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls." + }, + "system_message_behavior": { + "type": "string", + "enum": [ + "append", + "replace" + ], + "default": "append", + "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted." + } + }, + "additionalProperties": false, + "required": [ + "system_message_behavior" + ], + "title": "Configuration for tool use." + }, "ChatCompletionRequest": { "type": "object", "properties": { @@ -3192,7 +3229,7 @@ "auto", "required" ], - "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto." + "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead." }, "tool_prompt_format": { "type": "string", @@ -3201,7 +3238,7 @@ "function_tag", "python_list" ], - "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls." + "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. .. deprecated:: Use tool_config instead." }, "response_format": { "$ref": "#/components/schemas/ResponseFormat", @@ -3222,6 +3259,10 @@ }, "additionalProperties": false, "description": "(Optional) If specified, log probabilities for each token position will be returned." + }, + "tool_config": { + "$ref": "#/components/schemas/ToolConfig", + "description": "(Optional) Configuration for tool use." 
} }, "additionalProperties": false, diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 2a95acf38..c49c948de 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -1956,6 +1956,46 @@ components: additionalProperties: false required: - job_uuid + ToolConfig: + type: object + properties: + tool_choice: + type: string + enum: + - auto + - required + default: auto + description: >- + (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. + tool_prompt_format: + type: string + enum: + - json + - function_tag + - python_list + description: >- + (Optional) Instructs the model how to format tool calls. By default, Llama + Stack will attempt to use a format that is best adapted to the model. + - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. + - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a + tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python + syntax -- a list of function calls. + system_message_behavior: + type: string + enum: + - append + - replace + default: append + description: >- + (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: + Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: + Replaces the default system prompt with the provided system message. The + system message can include the string '{{function_definitions}}' to indicate + where the function definitions should be inserted. + additionalProperties: false + required: + - system_message_behavior + title: Configuration for tool use. ChatCompletionRequest: type: object properties: @@ -1986,6 +2026,7 @@ components: - required description: >- (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. + .. deprecated:: Use tool_config instead. tool_prompt_format: type: string enum: @@ -1998,7 +2039,7 @@ components: - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python - syntax -- a list of function calls. + syntax -- a list of function calls. .. deprecated:: Use tool_config instead. response_format: $ref: '#/components/schemas/ResponseFormat' description: >- @@ -2024,6 +2065,9 @@ components: description: >- (Optional) If specified, log probabilities for each token position will be returned. + tool_config: + $ref: '#/components/schemas/ToolConfig' + description: (Optional) Configuration for tool use. additionalProperties: false required: - model_id diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 6398f74e8..d4f13d65c 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -308,14 +308,49 @@ class CompletionResponseStreamChunk(BaseModel): logprobs: Optional[List[TokenLogProbs]] = None +class SystemMessageBehavior(Enum): + """Config for how to override the default system prompt. + + :cvar append: Appends the provided system message to the default system prompt: + https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt- + :cvar replace: Replaces the default system prompt with the provided system message. 
The system message can include the string + '{{function_definitions}}' to indicate where the function definitions should be inserted. + """ + + append = "append" + replace = "replace" + + +@json_schema_type +class ToolConfig(BaseModel): + """Configuration for tool use. + + :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. + :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. + - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. + - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. + - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. + :param system_message_behavior: (Optional) Config for how to override the default system prompt. + - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. + - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string + '{{function_definitions}}' to indicate where the function definitions should be inserted. + """ + + tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) + tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None) + system_message_behavior: SystemMessageBehavior = Field(default=SystemMessageBehavior.append) + + # This is an internally used class +@json_schema_type class ChatCompletionRequest(BaseModel): model: str messages: List[Message] sampling_params: Optional[SamplingParams] = SamplingParams() + tools: Optional[List[ToolDefinition]] = Field(default_factory=list) - tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) - tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None) + tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig) + response_format: Optional[ResponseFormat] = None stream: Optional[bool] = False logprobs: Optional[LogProbConfig] = None @@ -404,6 +439,7 @@ class Inference(Protocol): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: """Generate a chat completion for the given messages using the specified model. @@ -412,15 +448,20 @@ class Inference(Protocol): :param sampling_params: Parameters to control the sampling strategy :param tools: (Optional) List of tool definitions available to the model :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. + .. deprecated:: + Use tool_config instead. :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. + .. deprecated:: + Use tool_config instead. :param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options: - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format. - `ResponseFormat.grammar`: The grammar is a BNF grammar. 
This format is more flexible, but not all providers support it. :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False. :param logprobs: (Optional) If specified, log probabilities for each token position will be returned. + :param tool_config: (Optional) Configuration for tool use. :returns: If stream=False, returns a ChatCompletionResponse with the full completion. If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk """ diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index c5a7e3af6..6cddcf73c 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import ( ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -132,12 +133,23 @@ class InferenceRouter(Inference): tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: model = await self.routing_table.get_model(model_id) if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + if tool_config: + if tool_choice != tool_config.tool_choice: + raise ValueError("tool_choice and tool_config.tool_choice must match") + if tool_prompt_format != tool_config.tool_prompt_format: + raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") + else: + tool_config = ToolConfig( + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + ) params = dict( model_id=model_id, messages=messages, @@ -148,6 +160,7 @@ class InferenceRouter(Inference): response_format=response_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) if stream: diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 4048972df..51c10b0a8 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -400,7 +400,7 @@ class Llama: yield from self.generate( model_input=self.formatter.encode_dialog_prompt( request.messages, - request.tool_prompt_format, + request.tool_config.tool_prompt_format, ), max_gen_len=max_gen_len, temperature=temperature, diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 7e3508148..3caf4e2a5 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -38,6 +38,7 @@ from llama_stack.apis.inference import ( ResponseFormat, TokenLogProbs, ToolChoice, + ToolConfig, ) from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate @@ -252,6 +253,7 @@ class MetaReferenceInferenceImpl( tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: if logprobs: assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" @@ -262,11 
+264,10 @@ class MetaReferenceInferenceImpl( messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, response_format=response_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) self.check_model(request) diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 3920ee1ad..d34befbd9 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -17,6 +17,7 @@ from llama_stack.apis.inference import ( ToolChoice, ToolDefinition, ToolPromptFormat, + ToolConfig, ) from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( @@ -71,5 +72,6 @@ class SentenceTransformersInferenceImpl( tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 6f35d0c59..691737c15 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -30,6 +30,7 @@ from llama_stack.apis.inference import ( ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -159,6 +160,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk: assert self.engine is not None @@ -167,10 +169,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) log.info("Sampling params: %s", sampling_params) diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index c1297d022..03a0a40c3 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import ( ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -102,6 +103,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( @@ -109,11 +111,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, response_format=response_format, stream=stream, 
logprobs=logprobs, + tool_config=tool_config, ) if stream: diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index eb77741e0..bd12c56c8 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import ( ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -128,6 +129,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( @@ -140,6 +142,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): response_format=response_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) if stream: diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 2ed3618c5..37070b4ce 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -89,16 +89,16 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: request = ChatCompletionRequest( model=model, messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index af3a7fce5..d47c035b8 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -25,6 +25,7 @@ from llama_stack.apis.inference import ( ResponseFormatType, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -204,6 +205,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( @@ -211,11 +213,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, response_format=response_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) if stream: diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index f0220f1c1..461b3ee61 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -99,6 +99,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD 
tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: model_id = self.get_provider_model_id(model_id) if model_id == "llama-3.2-3b-preview": @@ -115,10 +116,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD sampling_params=sampling_params, response_format=response_format, tools=tools, - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) ) diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index acb359a1c..537043d69 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -79,7 +79,7 @@ def convert_chat_completion_request( # so we exclude it for now warnings.warn("repetition_penalty is not supported") - if request.tool_prompt_format != ToolPromptFormat.json: + if request.tool_config.tool_prompt_format != ToolPromptFormat.json: warnings.warn("tool_prompt_format is not used by Groq. Ignoring.") sampling_options = get_sampling_strategy_options(request.sampling_params) @@ -93,7 +93,7 @@ def convert_chat_completion_request( temperature=sampling_options.get("temperature", 1.0), top_p=sampling_options.get("top_p", 1.0), tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []], - tool_choice=request.tool_choice.value if request.tool_choice else None, + tool_choice=(request.tool_config.tool_choice.value if request.tool_config.tool_choice else None), ) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 0bbfe58b4..69533491e 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -171,6 +171,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: if tool_prompt_format: warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring") @@ -184,10 +185,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): sampling_params=sampling_params, response_format=response_format, tools=tools, - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ), n=1, ) diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index 623d36aa0..0a62a2ab4 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -282,9 +282,9 @@ async def convert_chat_completion_request( if request.tools: payload.update(tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]) - if request.tool_choice: + if request.tool_config.tool_choice: payload.update( - tool_choice=request.tool_choice.value + tool_choice=request.tool_config.tool_choice.value ) # we cannot include tool_choice w/o tools, server will complain if request.logprobs: diff --git 
a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index d6380cd6f..cff8aa742 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -29,6 +29,7 @@ from llama_stack.apis.inference import ( ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -224,6 +225,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( @@ -231,11 +233,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, response_format=response_format, + tool_config=tool_config, ) if stream: return self._stream_chat_completion(request) diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index f6209258d..a62b0c97f 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -83,10 +83,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 6ffbff384..dd697cd62 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -125,10 +125,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) request_sambanova = await self.convert_chat_completion_request(request) diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 1ce7ab5eb..2281319b3 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -26,6 +26,7 @@ from llama_stack.apis.inference import ( ResponseFormatType, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -205,6 +206,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( @@ -212,11 +214,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, response_format=response_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) if 
stream: diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 0b965c861..cf24daf60 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import ( ResponseFormatType, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -194,6 +195,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( @@ -201,11 +203,10 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, response_format=response_format, stream=stream, logprobs=logprobs, + tool_config=tool_config, ) if stream: diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 9d2d92279..bd3375baf 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -27,6 +27,7 @@ from llama_stack.apis.inference import ( ResponseFormatType, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, ) @@ -119,6 +120,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( @@ -126,11 +128,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): messages=messages, sampling_params=sampling_params, tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, response_format=response_format, + tool_config=tool_config, ) if stream: return self._stream_chat_completion(request, self.client) diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py index 8a8a63b30..a28dd308e 100644 --- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py +++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py @@ -179,7 +179,7 @@ class TestConvertChatCompletionRequest: def test_includes_tool_choice(self): request = self._dummy_chat_completion_request() - request.tool_choice = ToolChoice.required + request.tool_config.tool_choice = ToolChoice.required converted = convert_chat_completion_request(request) diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py index 4826e89d5..c087c5df2 100644 --- a/llama_stack/providers/tests/inference/test_prompt_adapter.py +++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py @@ -13,12 +13,18 @@ from llama_models.llama3.api.datatypes import ( ToolPromptFormat, ) -from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage +from llama_stack.apis.inference import ( + ChatCompletionRequest, + SystemMessage, + ToolConfig, + 
UserMessage, +) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, ) MODEL = "Llama3.1-8B-Instruct" +MODEL3_2 = "Llama3.2-3B-Instruct" class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): @@ -73,7 +79,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): }, ) ], - tool_prompt_format=ToolPromptFormat.json, + tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), ) messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 3) @@ -132,3 +138,101 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(messages[0].content.endswith(system_prompt)) self.assertEqual(messages[-1].content, content) + + async def test_replace_system_message_behavior_builtin_tools(self): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="python_list", + system_message_behavior="replace", + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + self.assertEqual(len(messages), 2, messages) + self.assertTrue(messages[0].content.endswith(system_prompt)) + self.assertIn("Environment: ipython", messages[0].content) + self.assertEqual(messages[-1].content, content) + + async def test_replace_system_message_behavior_custom_tools(self): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="python_list", + system_message_behavior="replace", + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + + self.assertEqual(len(messages), 2, messages) + self.assertTrue(messages[0].content.endswith(system_prompt)) + self.assertIn("Environment: ipython", messages[0].content) + self.assertEqual(messages[-1].content, content) + + async def test_replace_system_message_behavior_custom_tools_with_template(self): + content = "Hello !"
+ system_prompt = "You are a pirate {{ function_description }}" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="python_list", + system_message_behavior="replace", + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + + self.assertEqual(len(messages), 2, messages) + self.assertIn("Environment: ipython", messages[0].content) + self.assertIn("You are a pirate", messages[0].content) + # function description is present in the system prompt + self.assertIn('"name": "custom1"', messages[0].content) + self.assertEqual(messages[-1].content, content) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 49c6ac7a9..57875e64b 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -49,6 +49,7 @@ from llama_stack.apis.inference import ( SystemMessage, ToolChoice, UserMessage, + SystemMessageBehavior, ) from llama_stack.providers.utils.inference import supported_inference_models @@ -309,7 +310,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]): def augment_messages_for_tools_llama_3_1( request: ChatCompletionRequest, ) -> List[Message]: - assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" + assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" existing_messages = request.messages existing_system_message = None @@ -354,7 +355,7 @@ def augment_messages_for_tools_llama_3_1( has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) if has_custom_tools: - fmt = request.tool_prompt_format or ToolPromptFormat.json + fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json if fmt == ToolPromptFormat.json: tool_gen = JsonCustomToolGenerator() elif fmt == ToolPromptFormat.function_tag: @@ -375,7 +376,7 @@ def augment_messages_for_tools_llama_3_1( def augment_messages_for_tools_llama_3_2( request: ChatCompletionRequest, ) -> List[Message]: - assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" + assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" existing_messages = request.messages existing_system_message = None @@ -403,20 +404,25 @@ def augment_messages_for_tools_llama_3_2( custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] if custom_tools: - fmt = request.tool_prompt_format or ToolPromptFormat.python_list + fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list if fmt != ToolPromptFormat.python_list: - raise ValueError(f"Non supported ToolPromptFormat {request.tool_prompt_format}") + raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}") - tool_gen = PythonListCustomToolGenerator() - tool_template = tool_gen.gen(custom_tools) + system_prompt = None + if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: + system_prompt = existing_system_message.content + + tool_template = 
PythonListCustomToolGenerator().gen(custom_tools, system_prompt) sys_content += tool_template.render() sys_content += "\n" - if existing_system_message: + if existing_system_message and ( + request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools + ): sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n") - messages.append(SystemMessage(content=sys_content)) + messages.append(SystemMessage(content=sys_content.strip("\n"))) # Add back existing messages from the request messages += existing_messages
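
A minimal usage sketch of the new tool_config parameter (illustrative, not part of the patch): it assumes an `inference` handle that implements the Inference protocol above, and the model id and the `get_weather` tool are placeholder names. Callers supply a ToolConfig instead of the deprecated tool_choice / tool_prompt_format keyword arguments.

    from llama_models.llama3.api.datatypes import (
        ToolDefinition,
        ToolParamDefinition,
        ToolPromptFormat,
    )

    from llama_stack.apis.inference import (
        SystemMessage,
        SystemMessageBehavior,
        ToolChoice,
        ToolConfig,
        UserMessage,
    )


    async def ask_weather(inference):
        # Bundle the tool-use knobs into the new ToolConfig object.
        tool_config = ToolConfig(
            tool_choice=ToolChoice.auto,
            tool_prompt_format=ToolPromptFormat.python_list,
            # "replace" swaps the default system prompt for the SystemMessage below;
            # per the docstring above, that message may embed '{{function_definitions}}'
            # to control where the tool definitions are inserted.
            system_message_behavior=SystemMessageBehavior.replace,
        )
        return await inference.chat_completion(
            model_id="Llama3.2-3B-Instruct",  # placeholder model id
            messages=[
                SystemMessage(content="You are a terse weather assistant."),
                UserMessage(content="What is the weather in Tokyo?"),
            ],
            tools=[
                ToolDefinition(
                    tool_name="get_weather",  # hypothetical custom tool
                    description="Look up the current weather for a city",
                    parameters={
                        "city": ToolParamDefinition(
                            param_type="str",
                            description="Name of the city",
                            required=True,
                        ),
                    },
                ),
            ],
            tool_config=tool_config,
        )

Callers that still pass the deprecated keyword arguments keep working: the router either checks that they agree with tool_config or, when no tool_config is given, builds one from them before forwarding the request to the provider.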