diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 2debce1a7..fec23adc7 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -310,18 +310,49 @@ class CompletionResponseStreamChunk(BaseModel):
     logprobs: Optional[List[TokenLogProbs]] = None
 
 
-# This is an internally used class
+@json_schema_type
+class OverrideSystemMessage(Enum):
+    """Config for how to override the default system prompt.
+
+    :cvar append: Appends the provided system message to the default system prompt:
+        https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
+    :cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
+        '{{function_definitions}}' to indicate where the function definitions should be inserted.
+    """
+
+    append = "append"
+    replace = "replace"
+
+
+@json_schema_type
+class ToolConfig(BaseModel):
+    # zero-shot tool definitions as input to the model
+    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    override_system_message: OverrideSystemMessage = Field(
+        default=OverrideSystemMessage.append
+    )
+
+
+@json_schema_type
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
     sampling_params: Optional[SamplingParams] = SamplingParams()
-    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+
+    tool_config: Optional[ToolConfig] = None
+
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
 
+    # DEPRECATED: use tool_config instead
+    # zero-shot tool definitions as input to the model
+    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+
 
 @json_schema_type
 class ChatCompletionResponseStreamChunk(BaseModel):
@@ -406,6 +437,7 @@ class Inference(Protocol):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:
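
For context, a minimal sketch of how a caller might use the new `tool_config` parameter. The `client`, `weather_tool`, and `user_message` names (and the exact import path) are illustrative assumptions, not part of this diff; the fields come from the `ToolConfig` and `chat_completion` signatures above.

```python
from llama_stack.apis.inference import (  # assumed re-export path
    OverrideSystemMessage,
    ToolChoice,
    ToolConfig,
)

# Group tool-related options in the new ToolConfig instead of passing the
# deprecated top-level tools / tool_choice / tool_prompt_format fields.
tool_config = ToolConfig(
    tools=[weather_tool],  # a ToolDefinition constructed elsewhere
    tool_choice=ToolChoice.auto,
    override_system_message=OverrideSystemMessage.append,
)

# `client` is a hypothetical provider implementing the Inference protocol.
response = await client.chat_completion(
    model="Llama3.2-3B-Instruct",
    messages=[user_message],  # a Message constructed elsewhere
    tool_config=tool_config,
)
```

With `override_system_message=OverrideSystemMessage.replace`, the caller supplies a system message that may contain the `'{{function_definitions}}'` placeholder to control where function definitions are inserted; with the default `append`, the provided system message is appended to the default system prompt. The old top-level `tools`/`tool_choice`/`tool_prompt_format` request fields are still accepted but are marked deprecated in favor of `tool_config`.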