sys_prompt

This commit is contained in:
Eric Huang 2025-01-31 11:03:58 -08:00
parent 15dcc4ea5e
commit 6d035b3152
26 changed files with 147 additions and 48 deletions

View file

@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
ToolResponse, ToolResponse,
ToolResponseMessage, ToolResponseMessage,
UserMessage, UserMessage,
ToolConfig,
) )
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef from llama_stack.apis.tools import ToolDef
@ -155,8 +156,13 @@ class AgentConfigCommon(BaseModel):
output_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list) toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list) client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_choice: Optional[ToolChoice] = Field(
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None) default=ToolChoice.auto, deprecated="use tool_config instead"
)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=None, deprecated="use tool_config instead"
)
tool_config: Optional[ToolConfig] = Field(default=None)
max_infer_iters: int = 10 max_infer_iters: int = 10
@ -280,7 +286,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
toolgroups: Optional[List[AgentToolGroup]] = None toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
@json_schema_type @json_schema_type
class AgentTurnResponseStreamChunk(BaseModel): class AgentTurnResponseStreamChunk(BaseModel):
@ -327,6 +333,7 @@ class Agents(Protocol):
stream: Optional[bool] = False, stream: Optional[bool] = False,
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None, toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod( @webmethod(

View file

@ -310,14 +310,48 @@ class CompletionResponseStreamChunk(BaseModel):
logprobs: Optional[List[TokenLogProbs]] = None logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class SystemMessageBehavior(Enum):
"""Config for how to override the default system prompt.
:cvar append: Appends the provided system message to the default system prompt:
https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
:cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""
append = "append"
replace = "replace"
@json_schema_type
class ToolConfig(BaseModel):
"""Configuration for tool use.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
"""
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
system_message_behavior: SystemMessageBehavior = Field(
default=SystemMessageBehavior.append
)
# This is an internally used class # This is an internally used class
@json_schema_type
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: List[Message] messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams() sampling_params: Optional[SamplingParams] = SamplingParams()
tools: Optional[List[ToolDefinition]] = Field(default_factory=list) tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_config: Optional[ToolConfig] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
response_format: Optional[ResponseFormat] = None response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None logprobs: Optional[LogProbConfig] = None
@ -406,6 +440,7 @@ class Inference(Protocol):
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ ) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]: ]:
@ -416,15 +451,20 @@ class Inference(Protocol):
:param sampling_params: Parameters to control the sampling strategy :param sampling_params: Parameters to control the sampling strategy
:param tools: (Optional) List of tool definitions available to the model :param tools: (Optional) List of tool definitions available to the model
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated::
Use tool_config instead.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
.. deprecated::
Use tool_config instead.
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options: :param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format. - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it. - `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False. :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned. :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion. :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
""" """

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -138,6 +139,7 @@ class InferenceRouter(Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id) model = await self.routing_table.get_model(model_id)
if model is None: if model is None:
@ -146,6 +148,20 @@ class InferenceRouter(Inference):
raise ValueError( raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions" f"Model '{model_id}' is an embedding model and does not support chat completions"
) )
if tool_config:
if tool_choice != tool_config.tool_choice:
raise ValueError(
"tool_choice and tool_config.tool_choice must match"
)
if tool_prompt_format != tool_config.tool_prompt_format:
raise ValueError(
"tool_prompt_format and tool_config.tool_prompt_format must match"
)
else:
tool_config = ToolConfig(
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
messages=messages, messages=messages,
@ -156,6 +172,7 @@ class InferenceRouter(Inference):
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
provider = self.routing_table.get_provider_impl(model_id) provider = self.routing_table.get_provider_impl(model_id)
if stream: if stream:

View file

@ -515,10 +515,11 @@ class ChatAgent(ShieldRunnerMixin):
for tool in tool_defs.values() for tool in tool_defs.values()
if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
], ],
tool_prompt_format=self.agent_config.tool_prompt_format, tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format, response_format=self.agent_config.response_format,
stream=True, stream=True,
sampling_params=sampling_params, sampling_params=sampling_params,
tool_config=self.agent_config.tool_config,
): ):
event = chunk.event event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start: if event.event_type == ChatCompletionResponseEventType.start:

View file

@ -25,7 +25,12 @@ from llama_stack.apis.agents import (
Session, Session,
Turn, Turn,
) )
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
@ -146,6 +151,7 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups: Optional[List[AgentToolGroup]] = None, toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
request = AgentTurnCreateRequest( request = AgentTurnCreateRequest(
agent_id=agent_id, agent_id=agent_id,
@ -154,6 +160,7 @@ class MetaReferenceAgentsImpl(Agents):
stream=True, stream=True,
toolgroups=toolgroups, toolgroups=toolgroups,
documents=documents, documents=documents,
tool_config=tool_config,
) )
if stream: if stream:
return self._create_agent_turn_streaming(request) return self._create_agent_turn_streaming(request)

View file

@ -400,7 +400,7 @@ class Llama:
yield from self.generate( yield from self.generate(
model_input=self.formatter.encode_dialog_prompt( model_input=self.formatter.encode_dialog_prompt(
request.messages, request.messages,
request.tool_prompt_format, request.tool_config.tool_prompt_format,
), ),
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,

View file

@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
ResponseFormat, ResponseFormat,
TokenLogProbs, TokenLogProbs,
ToolChoice, ToolChoice,
ToolConfig,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
@ -270,6 +271,7 @@ class MetaReferenceInferenceImpl(
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
if logprobs: if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
@ -280,11 +282,10 @@ class MetaReferenceInferenceImpl(
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
self.check_model(request) self.check_model(request)

View file

@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
ToolChoice, ToolChoice,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
ToolConfig,
) )
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import ( from llama_stack.providers.utils.inference.embedding_mixin import (
@ -71,5 +72,6 @@ class SentenceTransformersInferenceImpl(
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion") raise ValueError("Sentence transformers don't support chat completion")

View file

@ -30,6 +30,7 @@ from llama_stack.apis.inference import (
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -159,6 +160,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk: ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
assert self.engine is not None assert self.engine is not None
@ -167,10 +169,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
log.info("Sampling params: %s", sampling_params) log.info("Sampling params: %s", sampling_params)

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -102,6 +103,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ ) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]: ]:
@ -111,11 +113,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
if stream: if stream:

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -130,6 +131,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest( request = ChatCompletionRequest(
@ -142,6 +144,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
if stream: if stream:

View file

@ -89,16 +89,16 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
request = ChatCompletionRequest( request = ChatCompletionRequest(
model=model, model=model,
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

View file

@ -25,6 +25,7 @@ from llama_stack.apis.inference import (
ResponseFormatType, ResponseFormatType,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -208,6 +209,7 @@ class FireworksInferenceAdapter(
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest( request = ChatCompletionRequest(
@ -215,11 +217,10 @@ class FireworksInferenceAdapter(
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
if stream: if stream:

View file

@ -99,6 +99,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ ) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]: ]:
@ -117,10 +118,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
sampling_params=sampling_params, sampling_params=sampling_params,
response_format=response_format, response_format=response_format,
tools=tools, tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
) )

View file

@ -79,7 +79,7 @@ def convert_chat_completion_request(
# so we exclude it for now # so we exclude it for now
warnings.warn("repetition_penalty is not supported") warnings.warn("repetition_penalty is not supported")
if request.tool_prompt_format != ToolPromptFormat.json: if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.") warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
sampling_options = get_sampling_strategy_options(request.sampling_params) sampling_options = get_sampling_strategy_options(request.sampling_params)
@ -93,7 +93,11 @@ def convert_chat_completion_request(
temperature=sampling_options.get("temperature", 1.0), temperature=sampling_options.get("temperature", 1.0),
top_p=sampling_options.get("top_p", 1.0), top_p=sampling_options.get("top_p", 1.0),
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []], tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
tool_choice=request.tool_choice.value if request.tool_choice else None, tool_choice=(
request.tool_config.tool_choice.value
if request.tool_config.tool_choice
else None
),
) )

View file

@ -178,6 +178,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ ) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]: ]:
@ -193,10 +194,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
sampling_params=sampling_params, sampling_params=sampling_params,
response_format=response_format, response_format=response_format,
tools=tools, tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
), ),
n=1, n=1,
) )

View file

@ -253,9 +253,9 @@ def convert_chat_completion_request(
payload.update( payload.update(
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools] tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
) )
if request.tool_choice: if request.tool_config.tool_choice:
payload.update( payload.update(
tool_choice=request.tool_choice.value tool_choice=request.tool_config.tool_choice.value
) # we cannot include tool_choice w/o tools, server will complain ) # we cannot include tool_choice w/o tools, server will complain
if request.logprobs: if request.logprobs:

View file

@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -224,6 +225,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest( request = ChatCompletionRequest(
@ -231,11 +233,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
response_format=response_format, response_format=response_format,
tool_config=tool_config,
) )
if stream: if stream:
return self._stream_chat_completion(request) return self._stream_chat_completion(request)
@ -322,6 +323,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self, request: ChatCompletionRequest self, request: ChatCompletionRequest
) -> AsyncGenerator: ) -> AsyncGenerator:
params = await self._get_params(request) params = await self._get_params(request)
print(params)
async def _generate_and_convert_to_openai_compat(): async def _generate_and_convert_to_openai_compat():
if "messages" in params: if "messages" in params:

View file

@ -85,10 +85,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

View file

@ -125,10 +125,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
request_sambanova = await self.convert_chat_completion_request(request) request_sambanova = await self.convert_chat_completion_request(request)

View file

@ -26,6 +26,7 @@ from llama_stack.apis.inference import (
ResponseFormatType, ResponseFormatType,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -213,6 +214,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest( request = ChatCompletionRequest(
@ -220,11 +222,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
if stream: if stream:

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
ResponseFormatType, ResponseFormatType,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -198,6 +199,7 @@ class TogetherInferenceAdapter(
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest( request = ChatCompletionRequest(
@ -205,11 +207,10 @@ class TogetherInferenceAdapter(
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config,
) )
if stream: if stream:

View file

@ -27,6 +27,7 @@ from llama_stack.apis.inference import (
ResponseFormatType, ResponseFormatType,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -119,6 +120,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest( request = ChatCompletionRequest(
@ -126,11 +128,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
response_format=response_format, response_format=response_format,
tool_config=tool_config,
) )
if stream: if stream:
return self._stream_chat_completion(request, self.client) return self._stream_chat_completion(request, self.client)

View file

@ -181,7 +181,7 @@ class TestConvertChatCompletionRequest:
def test_includes_tool_choice(self): def test_includes_tool_choice(self):
request = self._dummy_chat_completion_request() request = self._dummy_chat_completion_request()
request.tool_choice = ToolChoice.required request.tool_config.tool_choice = ToolChoice.required
converted = convert_chat_completion_request(request) converted = convert_chat_completion_request(request)

View file

@ -13,7 +13,12 @@ from llama_models.llama3.api.datatypes import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage from llama_stack.apis.inference import (
ChatCompletionRequest,
SystemMessage,
ToolConfig,
UserMessage,
)
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages, chat_completion_request_to_messages,
) )
@ -73,7 +78,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
}, },
) )
], ],
tool_prompt_format=ToolPromptFormat.json, tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
) )
messages = chat_completion_request_to_messages(request, MODEL) messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3) self.assertEqual(len(messages), 3)

View file

@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
SystemMessage, SystemMessage,
ToolChoice, ToolChoice,
UserMessage, UserMessage,
SystemMessageBehavior,
) )
from llama_stack.providers.utils.inference import supported_inference_models from llama_stack.providers.utils.inference import supported_inference_models
@ -319,7 +320,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
def augment_messages_for_tools_llama_3_1( def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> List[Message]: ) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages existing_messages = request.messages
existing_system_message = None existing_system_message = None
@ -368,7 +369,7 @@ def augment_messages_for_tools_llama_3_1(
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools: if has_custom_tools:
fmt = request.tool_prompt_format or ToolPromptFormat.json fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json: if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator() tool_gen = JsonCustomToolGenerator()
elif fmt == ToolPromptFormat.function_tag: elif fmt == ToolPromptFormat.function_tag:
@ -389,7 +390,7 @@ def augment_messages_for_tools_llama_3_1(
def augment_messages_for_tools_llama_3_2( def augment_messages_for_tools_llama_3_2(
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> List[Message]: ) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages existing_messages = request.messages
existing_system_message = None existing_system_message = None
@ -419,19 +420,24 @@ def augment_messages_for_tools_llama_3_2(
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools: if custom_tools:
fmt = request.tool_prompt_format or ToolPromptFormat.python_list fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list: if fmt != ToolPromptFormat.python_list:
raise ValueError( raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}" f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
) )
tool_gen = PythonListCustomToolGenerator() tool_gen = PythonListCustomToolGenerator()
tool_template = tool_gen.gen(custom_tools)
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = tool_gen.gen(custom_tools, system_prompt)
sys_content += tool_template.render() sys_content += tool_template.render()
sys_content += "\n" sys_content += "\n"
if existing_system_message: if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.append:
sys_content += interleaved_content_as_str( sys_content += interleaved_content_as_str(
existing_system_message.content, sep="\n" existing_system_message.content, sep="\n"
) )