sys_prompt

Eric Huang 2025-01-31 11:03:58 -08:00
parent 15dcc4ea5e
commit 6d035b3152
26 changed files with 147 additions and 48 deletions


@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
ToolResponse,
ToolResponseMessage,
UserMessage,
ToolConfig,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
@@ -155,8 +156,13 @@ class AgentConfigCommon(BaseModel):
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
tool_choice: Optional[ToolChoice] = Field(
default=ToolChoice.auto, deprecated="use tool_config instead"
)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=None, deprecated="use tool_config instead"
)
tool_config: Optional[ToolConfig] = Field(default=None)
max_infer_iters: int = 10
@@ -280,7 +286,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
@@ -327,6 +333,7 @@ class Agents(Protocol):
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(
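
For agent authors, the practical change in this file is that tool options move under a single tool_config field, while the old top-level tool_choice and tool_prompt_format fields are kept only for backward compatibility. A rough sketch of the preferred shape, assuming AgentConfig's usual model / instructions / enable_session_persistence fields (all values are illustrative):

    from llama_stack.apis.agents import AgentConfig
    from llama_stack.apis.inference import ToolChoice, ToolConfig, ToolPromptFormat

    agent_config = AgentConfig(
        model="meta-llama/Llama-3.2-3B-Instruct",     # assumed field and value
        instructions="You are a helpful assistant.",  # assumed field and value
        enable_session_persistence=False,             # assumed field
        toolgroups=["builtin::websearch"],            # toolgroup name is illustrative
        # Preferred: a single ToolConfig instead of the deprecated top-level
        # tool_choice / tool_prompt_format fields.
        tool_config=ToolConfig(
            tool_choice=ToolChoice.auto,
            tool_prompt_format=ToolPromptFormat.python_list,
        ),
    )

The same ToolConfig type can also be supplied per turn through the new tool_config argument on create_agent_turn shown above.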


@@ -310,14 +310,48 @@ class CompletionResponseStreamChunk(BaseModel):
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class SystemMessageBehavior(Enum):
"""Config for how to override the default system prompt.
:cvar append: Appends the provided system message to the default system prompt:
https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
:cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""
append = "append"
replace = "replace"
@json_schema_type
class ToolConfig(BaseModel):
"""Configuration for tool use.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
"""
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
system_message_behavior: SystemMessageBehavior = Field(
default=SystemMessageBehavior.append
)
# This is an internally used class
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
tool_config: Optional[ToolConfig] = None
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
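
To make the two SystemMessageBehavior modes concrete, here is a minimal sketch of constructing a ToolConfig either way, using only the fields defined above:

    from llama_stack.apis.inference import (
        SystemMessageBehavior,
        ToolChoice,
        ToolConfig,
        ToolPromptFormat,
    )

    # Default: Llama Stack renders its own system prompt (including the function
    # definitions) and appends the caller's system message to it.
    append_cfg = ToolConfig(
        tool_choice=ToolChoice.auto,
        tool_prompt_format=ToolPromptFormat.json,
    )  # system_message_behavior defaults to SystemMessageBehavior.append

    # Replace: the caller's system message becomes the entire system prompt;
    # '{{function_definitions}}' marks where the tool definitions are inserted.
    replace_cfg = ToolConfig(
        system_message_behavior=SystemMessageBehavior.replace,
    )
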
@@ -406,6 +440,7 @@ class Inference(Protocol):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
@@ -416,15 +451,20 @@ class Inference(Protocol):
:param sampling_params: Parameters to control the sampling strategy
:param tools: (Optional) List of tool definitions available to the model
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated::
Use tool_config instead.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
.. deprecated::
Use tool_config instead.
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
"""

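In caller terms, the docstring above boils down to passing one tool_config object instead of the deprecated keyword arguments. A sketch against the protocol; the inference handle, model id and tools list are placeholders:

    from llama_stack.apis.inference import ToolChoice, ToolConfig, UserMessage

    async def ask(inference, tools):
        # `inference` is any Inference implementation (router or provider adapter);
        # `tools` is a List[ToolDefinition] prepared by the caller.
        return await inference.chat_completion(
            model_id="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
            messages=[UserMessage(content="What's the weather in Paris?")],
            tools=tools,
            tool_config=ToolConfig(tool_choice=ToolChoice.required),
            # tool_choice=... / tool_prompt_format=... still work but are deprecated.
        )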

@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -138,6 +139,7 @@ class InferenceRouter(Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
@@ -146,6 +148,20 @@ class InferenceRouter(Inference):
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
if tool_config:
if tool_choice != tool_config.tool_choice:
raise ValueError(
"tool_choice and tool_config.tool_choice must match"
)
if tool_prompt_format != tool_config.tool_prompt_format:
raise ValueError(
"tool_prompt_format and tool_config.tool_prompt_format must match"
)
else:
tool_config = ToolConfig(
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
)
params = dict(
model_id=model_id,
messages=messages,
@@ -156,6 +172,7 @@ class InferenceRouter(Inference):
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
if stream:
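
The net effect of this router change is that existing callers keep working: when no tool_config is supplied, the legacy arguments are wrapped into one before the request reaches the provider. A sketch, assuming `router` is an InferenceRouter instance and tool_choice keeps its ToolChoice.auto default:

    from llama_stack.apis.inference import ToolPromptFormat, UserMessage

    # Inside an async context:
    response = await router.chat_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # illustrative
        messages=[UserMessage(content="What's 2 + 2?")],
        tool_prompt_format=ToolPromptFormat.json,     # legacy argument, no tool_config
    )
    # The provider receives tool_config=ToolConfig(tool_choice=ToolChoice.auto,
    # tool_prompt_format=ToolPromptFormat.json). An explicit tool_config whose values
    # disagree with the legacy arguments raises ValueError instead.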


@@ -515,10 +515,11 @@ class ChatAgent(ShieldRunnerMixin):
for tool in tool_defs.values()
if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tool_prompt_format=self.agent_config.tool_prompt_format,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
sampling_params=sampling_params,
tool_config=self.agent_config.tool_config,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:


@@ -25,7 +25,12 @@ from llama_stack.apis.agents import (
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
@@ -146,6 +151,7 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@@ -154,6 +160,7 @@ class MetaReferenceAgentsImpl(Agents):
stream=True,
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
)
if stream:
return self._create_agent_turn_streaming(request)


@@ -400,7 +400,7 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
request.messages,
request.tool_prompt_format,
request.tool_config.tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,


@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
TokenLogProbs,
ToolChoice,
ToolConfig,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -270,6 +271,7 @@ class MetaReferenceInferenceImpl(
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
@@ -280,11 +282,10 @@ class MetaReferenceInferenceImpl(
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
self.check_model(request)


@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
ToolChoice,
ToolDefinition,
ToolPromptFormat,
ToolConfig,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -71,5 +72,6 @@ class SentenceTransformersInferenceImpl(
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")


@@ -30,6 +30,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -159,6 +160,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
assert self.engine is not None
@@ -167,10 +169,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
log.info("Sampling params: %s", sampling_params)


@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -102,6 +103,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
@@ -111,11 +113,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:


@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -130,6 +131,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -142,6 +144,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:


@@ -89,16 +89,16 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)


@@ -25,6 +25,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -208,6 +209,7 @@ class FireworksInferenceAdapter(
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -215,11 +217,10 @@ class FireworksInferenceAdapter(
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:


@@ -99,6 +99,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
@@ -117,10 +118,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
)


@@ -79,7 +79,7 @@ def convert_chat_completion_request(
# so we exclude it for now
warnings.warn("repetition_penalty is not supported")
if request.tool_prompt_format != ToolPromptFormat.json:
if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
sampling_options = get_sampling_strategy_options(request.sampling_params)
@@ -93,7 +93,11 @@ def convert_chat_completion_request(
temperature=sampling_options.get("temperature", 1.0),
top_p=sampling_options.get("top_p", 1.0),
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
tool_choice=request.tool_choice.value if request.tool_choice else None,
tool_choice=(
request.tool_config.tool_choice.value
if request.tool_config.tool_choice
else None
),
)
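
What lands in the converted Groq request is the enum's string value; a tiny sanity check:

    from llama_stack.apis.inference import ToolChoice

    assert ToolChoice.required.value == "required"  # forwarded as tool_choice="required"
    assert ToolChoice.auto.value == "auto"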


@@ -178,6 +178,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
@@ -193,10 +194,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
),
n=1,
)


@@ -253,9 +253,9 @@ def convert_chat_completion_request(
payload.update(
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
)
if request.tool_choice:
if request.tool_config.tool_choice:
payload.update(
tool_choice=request.tool_choice.value
tool_choice=request.tool_config.tool_choice.value
) # we cannot include tool_choice w/o tools, server will complain
if request.logprobs:


@@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -224,6 +225,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -231,11 +233,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
response_format=response_format,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
@@ -322,6 +323,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = await self._get_params(request)
print(params)
async def _generate_and_convert_to_openai_compat():
if "messages" in params:


@@ -85,10 +85,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)


@@ -125,10 +125,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
request_sambanova = await self.convert_chat_completion_request(request)


@@ -26,6 +26,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -213,6 +214,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -220,11 +222,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:


@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -198,6 +199,7 @@ class TogetherInferenceAdapter(
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -205,11 +207,10 @@ class TogetherInferenceAdapter(
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:


@@ -27,6 +27,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
@@ -119,6 +120,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
@@ -126,11 +128,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
response_format=response_format,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request, self.client)


@@ -181,7 +181,7 @@ class TestConvertChatCompletionRequest:
def test_includes_tool_choice(self):
request = self._dummy_chat_completion_request()
request.tool_choice = ToolChoice.required
request.tool_config.tool_choice = ToolChoice.required
converted = convert_chat_completion_request(request)


@@ -13,7 +13,12 @@ from llama_models.llama3.api.datatypes import (
ToolPromptFormat,
)
from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
from llama_stack.apis.inference import (
ChatCompletionRequest,
SystemMessage,
ToolConfig,
UserMessage,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
@@ -73,7 +78,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
},
)
],
tool_prompt_format=ToolPromptFormat.json,
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)


@@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
SystemMessage,
ToolChoice,
UserMessage,
SystemMessageBehavior,
)
from llama_stack.providers.utils.inference import supported_inference_models
@@ -319,7 +320,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
@@ -368,7 +369,7 @@ def augment_messages_for_tools_llama_3_1(
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_prompt_format or ToolPromptFormat.json
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif fmt == ToolPromptFormat.function_tag:
@@ -389,7 +390,7 @@ def augment_messages_for_tools_llama_3_2(
def augment_messages_for_tools_llama_3_2(
request: ChatCompletionRequest,
) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
@@ -419,19 +420,24 @@ def augment_messages_for_tools_llama_3_2(
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
fmt = request.tool_prompt_format or ToolPromptFormat.python_list
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
)
tool_gen = PythonListCustomToolGenerator()
tool_template = tool_gen.gen(custom_tools)
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = tool_gen.gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message:
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.append:
sys_content += interleaved_content_as_str(
existing_system_message.content, sep="\n"
)
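
Putting the pieces together for the Llama 3.2 path above: with system_message_behavior set to replace, the caller's system message becomes the system prompt and the generated function definitions are spliced in at the placeholder. A hedged sketch, assuming ToolDefinition / ToolParamDefinition from llama_models as in the surrounding tests; the model name and tool definition are illustrative, and the exact rendered prompt depends on PythonListCustomToolGenerator:

    from llama_models.llama3.api.datatypes import ToolDefinition, ToolParamDefinition
    from llama_stack.apis.inference import (
        ChatCompletionRequest,
        SystemMessage,
        SystemMessageBehavior,
        ToolConfig,
        UserMessage,
    )
    from llama_stack.providers.utils.inference.prompt_adapter import (
        chat_completion_request_to_messages,
    )

    MODEL = "Llama3.2-3B-Instruct"  # illustrative model identifier

    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(
                content="You are a terse assistant.\n{{function_definitions}}"
            ),
            UserMessage(content="What's the weather like in Paris?"),
        ],
        tools=[
            ToolDefinition(
                tool_name="get_weather",
                description="Look up the current weather for a city",
                parameters={
                    "city": ToolParamDefinition(
                        param_type="str",
                        description="City name",
                        required=True,
                    ),
                },
            )
        ],
        tool_config=ToolConfig(system_message_behavior=SystemMessageBehavior.replace),
    )

    # tool_choice defaults to ToolChoice.auto, so the assert above passes; the
    # resulting system message is built from the caller's content rather than the
    # default Llama 3.2 tool prompt.
    messages = chat_completion_request_to_messages(request, MODEL)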