RFC: config to override system prompt

Summary: Adds a ToolConfig container for tool-related request options and an
OverrideSystemMessage enum that controls whether a caller-supplied system
message is appended to or replaces the model's default system prompt. The
top-level tools, tool_choice, and tool_prompt_format fields on
ChatCompletionRequest are deprecated in favor of tool_config.

Eric Huang (AI Platform) 2025-01-30 20:31:37 -08:00
parent 15dcc4ea5e
commit 02c3ffcdbe


@@ -310,18 +310,49 @@ class CompletionResponseStreamChunk(BaseModel):
    logprobs: Optional[List[TokenLogProbs]] = None


# This is an internally used class
@json_schema_type
class OverrideSystemMessage(Enum):
    """Config for how to override the default system prompt.

    :cvar append: Appends the provided system message to the default system prompt:
        https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
    :cvar replace: Replaces the default system prompt with the provided system message.
        The system message can include the string '{{function_definitions}}' to indicate
        where the function definitions should be inserted.
    """

    append = "append"
    replace = "replace"
@json_schema_type
class ToolConfig(BaseModel):
    # zero-shot tool definitions as input to the model
    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
    override_system_message: OverrideSystemMessage = Field(
        default=OverrideSystemMessage.append
    )
@json_schema_type
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    sampling_params: Optional[SamplingParams] = SamplingParams()
    tool_config: Optional[ToolConfig] = None
    response_format: Optional[ResponseFormat] = None
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None

    # DEPRECATED: use tool_config instead
    # zero-shot tool definitions as input to the model
    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
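
For context, a minimal sketch of how a provider might act on override_system_message
when assembling the final prompt. The compose_system_prompt helper and its arguments
are hypothetical; only OverrideSystemMessage, ToolConfig, and the
'{{function_definitions}}' placeholder come from this diff:

from typing import Optional

def compose_system_prompt(
    default_prompt: str,        # provider's default system prompt (assumed input)
    function_definitions: str,  # rendered tool/function definitions (assumed input)
    config: ToolConfig,
    user_system_message: Optional[str],
) -> str:
    # No caller-supplied system message: keep the provider default as-is.
    if user_system_message is None:
        return default_prompt
    if config.override_system_message == OverrideSystemMessage.append:
        # 'append': default prompt (with its function definitions) first,
        # the caller's system message after it.
        return f"{default_prompt}\n\n{user_system_message}"
    # 'replace': the caller's message wins; splice the function definitions
    # in wherever the placeholder appears.
    return user_system_message.replace(
        "{{function_definitions}}", function_definitions
    )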
@@ -406,6 +437,7 @@ class Inference(Protocol):
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
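
A hedged usage sketch of the new parameter against the updated protocol. The
`client` handle, model identifier, message constructors, and `weather_tool` are
illustrative stand-ins rather than part of this diff; tool_config and its fields
are the additions above:

# `client` is assumed to implement the Inference protocol.
response = await client.chat_completion(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[
        SystemMessage(
            content="You are a weather bot.\n{{function_definitions}}"
        ),
        UserMessage(content="Will it rain in Tokyo tomorrow?"),
    ],
    tool_config=ToolConfig(
        tools=[weather_tool],  # a ToolDefinition constructed elsewhere
        tool_choice=ToolChoice.auto,
        override_system_message=OverrideSystemMessage.replace,
    ),
)

Callers that still pass the deprecated top-level tools, tool_choice, and
tool_prompt_format fields continue to work, but new code should route these
options through tool_config.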