Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-07 02:58:21 +00:00)

Commit 6d035b3152 ("sys_prompt"), parent 15dcc4ea5e
26 changed files with 147 additions and 48 deletions

@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
+    ToolConfig,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef

@@ -155,8 +156,13 @@ class AgentConfigCommon(BaseModel):
     output_shields: Optional[List[str]] = Field(default_factory=list)
     toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
     client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    tool_choice: Optional[ToolChoice] = Field(
+        default=ToolChoice.auto, deprecated="use tool_config instead"
+    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=None, deprecated="use tool_config instead"
+    )
+    tool_config: Optional[ToolConfig] = Field(default=None)

     max_infer_iters: int = 10

@@ -280,7 +286,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     toolgroups: Optional[List[AgentToolGroup]] = None

     stream: Optional[bool] = False
+    tool_config: Optional[ToolConfig] = None

 @json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):

@@ -327,6 +333,7 @@ class Agents(Protocol):
         stream: Optional[bool] = False,
         documents: Optional[List[Document]] = None,
         toolgroups: Optional[List[AgentToolGroup]] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

     @webmethod(
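
Illustrative only: under this change an agent configuration carries a single tool_config instead of the two deprecated fields. A minimal sketch, assuming the surrounding AgentConfig fields (model, instructions) keep their existing names:

from llama_stack.apis.inference import ToolChoice, ToolConfig, ToolPromptFormat

agent_config = dict(
    model="Llama3.2-3B-Instruct",  # hypothetical model identifier
    instructions="You are a helpful assistant.",
    # Deprecated style (still accepted, flagged via deprecated="use tool_config instead"):
    #   tool_choice=ToolChoice.auto, tool_prompt_format=ToolPromptFormat.python_list,
    tool_config=ToolConfig(
        tool_choice=ToolChoice.auto,
        tool_prompt_format=ToolPromptFormat.python_list,
    ),
)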

@@ -310,14 +310,48 @@ class CompletionResponseStreamChunk(BaseModel):
     logprobs: Optional[List[TokenLogProbs]] = None


+@json_schema_type
+class SystemMessageBehavior(Enum):
+    """Config for how to override the default system prompt.
+
+    :cvar append: Appends the provided system message to the default system prompt:
+        https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
+    :cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
+        '{{function_definitions}}' to indicate where the function definitions should be inserted.
+    """
+
+    append = "append"
+    replace = "replace"
+
+
+@json_schema_type
+class ToolConfig(BaseModel):
+    """Configuration for tool use.
+
+    :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+    :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
+        - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
+        - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
+        - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+    """
+
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    system_message_behavior: SystemMessageBehavior = Field(
+        default=SystemMessageBehavior.append
+    )
+
+
 # This is an internally used class
+@json_schema_type
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
     sampling_params: Optional[SamplingParams] = SamplingParams()

     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    tool_config: Optional[ToolConfig] = None
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None

@@ -406,6 +440,7 @@ class Inference(Protocol):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:

@@ -416,15 +451,20 @@ class Inference(Protocol):
         :param sampling_params: Parameters to control the sampling strategy
         :param tools: (Optional) List of tool definitions available to the model
         :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+            .. deprecated::
+               Use tool_config instead.
         :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
             - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
             - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
             - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+            .. deprecated::
+               Use tool_config instead.
         :param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
             - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
             - `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
         :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :param tool_config: (Optional) Configuration for tool use.
         :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
             If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
         """
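
As a rough usage sketch of the new protocol surface (the client object, model id and message/tool variables are placeholders, not part of this diff), a caller can now pass tool behaviour as a single object instead of separate arguments:

from llama_stack.apis.inference import SystemMessageBehavior, ToolChoice, ToolConfig

# Assuming `inference` is an object implementing the Inference protocol above.
response = await inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",  # hypothetical model id
    messages=messages,
    tools=tool_definitions,
    tool_config=ToolConfig(
        tool_choice=ToolChoice.auto,
        # Keep the default system prompt and append the user-supplied one:
        system_message_behavior=SystemMessageBehavior.append,
    ),
)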

@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -138,6 +139,7 @@ class InferenceRouter(Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.routing_table.get_model(model_id)
         if model is None:

@@ -146,6 +148,20 @@ class InferenceRouter(Inference):
             raise ValueError(
                 f"Model '{model_id}' is an embedding model and does not support chat completions"
             )
+        if tool_config:
+            if tool_choice != tool_config.tool_choice:
+                raise ValueError(
+                    "tool_choice and tool_config.tool_choice must match"
+                )
+            if tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError(
+                    "tool_prompt_format and tool_config.tool_prompt_format must match"
+                )
+        else:
+            tool_config = ToolConfig(
+                tool_choice=tool_choice,
+                tool_prompt_format=tool_prompt_format,
+            )
         params = dict(
             model_id=model_id,
             messages=messages,

@@ -156,6 +172,7 @@ class InferenceRouter(Inference):
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
         if stream:
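
The router keeps the legacy arguments working: when no tool_config is passed it builds one from tool_choice and tool_prompt_format, and it rejects conflicting combinations before forwarding the request. A rough sketch of the resulting behaviour (router, model_id and messages are placeholders):

# Legacy-style call: still accepted; the router internally builds
# ToolConfig(tool_choice=ToolChoice.required, tool_prompt_format=None).
await router.chat_completion(
    model_id=model_id,
    messages=messages,
    tool_choice=ToolChoice.required,
)

# Conflicting values are rejected before the request reaches a provider:
await router.chat_completion(
    model_id=model_id,
    messages=messages,
    tool_choice=ToolChoice.auto,
    tool_config=ToolConfig(tool_choice=ToolChoice.required),
)  # raises ValueError("tool_choice and tool_config.tool_choice must match")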

@@ -515,10 +515,11 @@ class ChatAgent(ShieldRunnerMixin):
                 for tool in tool_defs.values()
                 if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
             ],
-            tool_prompt_format=self.agent_config.tool_prompt_format,
+            tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
             response_format=self.agent_config.response_format,
             stream=True,
             sampling_params=sampling_params,
+            tool_config=self.agent_config.tool_config,
         ):
             event = chunk.event
             if event.event_type == ChatCompletionResponseEventType.start:

@@ -25,7 +25,12 @@ from llama_stack.apis.agents import (
     Session,
     Turn,
 )
-from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
+from llama_stack.apis.inference import (
+    Inference,
+    ToolConfig,
+    ToolResponseMessage,
+    UserMessage,
+)
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO

@@ -146,6 +151,7 @@ class MetaReferenceAgentsImpl(Agents):
         toolgroups: Optional[List[AgentToolGroup]] = None,
         documents: Optional[List[Document]] = None,
         stream: Optional[bool] = False,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,

@@ -154,6 +160,7 @@ class MetaReferenceAgentsImpl(Agents):
             stream=True,
             toolgroups=toolgroups,
             documents=documents,
+            tool_config=tool_config,
         )
         if stream:
             return self._create_agent_turn_streaming(request)
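
A turn-level override is now possible as well. A minimal sketch, assuming `agents` implements the Agents protocol above; parameter names other than tool_config, stream, documents and toolgroups (for example session_id and messages) follow the surrounding API and are shown here as assumptions:

turn = await agents.create_agent_turn(
    agent_id=agent_id,
    session_id=session_id,
    messages=[UserMessage(content="What's the weather in Tokyo?")],
    # Force tool use for this turn only, independent of the agent-level default:
    tool_config=ToolConfig(tool_choice=ToolChoice.required),
    stream=False,
)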

@@ -400,7 +400,7 @@ class Llama:
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 request.messages,
-                request.tool_prompt_format,
+                request.tool_config.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
             temperature=temperature,

@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     TokenLogProbs,
     ToolChoice,
+    ToolConfig,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate

@@ -270,6 +271,7 @@ class MetaReferenceInferenceImpl(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

@@ -280,11 +282,10 @@ class MetaReferenceInferenceImpl(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
         self.check_model(request)

@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolDefinition,
     ToolPromptFormat,
+    ToolConfig,
 )
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (

@@ -71,5 +72,6 @@ class SentenceTransformersInferenceImpl(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")

@@ -30,6 +30,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -159,6 +160,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
         assert self.engine is not None

@@ -167,10 +169,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         log.info("Sampling params: %s", sampling_params)

@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -102,6 +103,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:

@@ -111,11 +113,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         if stream:

@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -130,6 +131,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -142,6 +144,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         if stream:

@@ -89,16 +89,16 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

@@ -25,6 +25,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -208,6 +209,7 @@ class FireworksInferenceAdapter(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -215,11 +217,10 @@ class FireworksInferenceAdapter(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         if stream:

@@ -99,6 +99,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:

@@ -117,10 +118,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
                 sampling_params=sampling_params,
                 response_format=response_format,
                 tools=tools,
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
                 stream=stream,
                 logprobs=logprobs,
+                tool_config=tool_config,
             )
         )

@@ -79,7 +79,7 @@ def convert_chat_completion_request(
         # so we exclude it for now
         warnings.warn("repetition_penalty is not supported")

-    if request.tool_prompt_format != ToolPromptFormat.json:
+    if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
         warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")

     sampling_options = get_sampling_strategy_options(request.sampling_params)

@@ -93,7 +93,11 @@ def convert_chat_completion_request(
         temperature=sampling_options.get("temperature", 1.0),
         top_p=sampling_options.get("top_p", 1.0),
         tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
-        tool_choice=request.tool_choice.value if request.tool_choice else None,
+        tool_choice=(
+            request.tool_config.tool_choice.value
+            if request.tool_config.tool_choice
+            else None
+        ),
     )

@@ -178,6 +178,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:

@@ -193,10 +194,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
                 sampling_params=sampling_params,
                 response_format=response_format,
                 tools=tools,
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
                 stream=stream,
                 logprobs=logprobs,
+                tool_config=tool_config,
             ),
             n=1,
         )

@@ -253,9 +253,9 @@ def convert_chat_completion_request(
         payload.update(
             tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
         )
-        if request.tool_choice:
+        if request.tool_config.tool_choice:
             payload.update(
-                tool_choice=request.tool_choice.value
+                tool_choice=request.tool_config.tool_choice.value
             )  # we cannot include tool_choice w/o tools, server will complain

     if request.logprobs:

@@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -224,6 +225,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -231,11 +233,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
             response_format=response_format,
+            tool_config=tool_config,
         )
         if stream:
             return self._stream_chat_completion(request)

@@ -322,6 +323,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = await self._get_params(request)
+        print(params)

         async def _generate_and_convert_to_openai_compat():
             if "messages" in params:

@@ -85,10 +85,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

@@ -125,10 +125,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
         request_sambanova = await self.convert_chat_completion_request(request)

@@ -26,6 +26,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -213,6 +214,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -220,11 +222,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         if stream:

@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -198,6 +199,7 @@ class TogetherInferenceAdapter(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -205,11 +207,10 @@ class TogetherInferenceAdapter(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         if stream:

@@ -27,6 +27,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -119,6 +120,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -126,11 +128,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
             response_format=response_format,
+            tool_config=tool_config,
         )
         if stream:
             return self._stream_chat_completion(request, self.client)

@@ -181,7 +181,7 @@ class TestConvertChatCompletionRequest:

     def test_includes_tool_choice(self):
         request = self._dummy_chat_completion_request()
-        request.tool_choice = ToolChoice.required
+        request.tool_config.tool_choice = ToolChoice.required

         converted = convert_chat_completion_request(request)

@@ -13,7 +13,12 @@ from llama_models.llama3.api.datatypes import (
     ToolPromptFormat,
 )

-from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    SystemMessage,
+    ToolConfig,
+    UserMessage,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )

@@ -73,7 +78,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                     },
                 )
             ],
-            tool_prompt_format=ToolPromptFormat.json,
+            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 3)
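
In short, request-level code that previously set tool_prompt_format directly now wraps it in a ToolConfig. A minimal sketch mirroring the updated test (MODEL, the message content and tool_definitions are placeholders):

request = ChatCompletionRequest(
    model=MODEL,  # placeholder model identifier
    messages=[UserMessage(content="Use the tool to answer.")],
    tools=tool_definitions,
    tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)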

@@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
     SystemMessage,
     ToolChoice,
     UserMessage,
+    SystemMessageBehavior,
 )
 from llama_stack.providers.utils.inference import supported_inference_models

@@ -319,7 +320,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
 def augment_messages_for_tools_llama_3_1(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-    assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
+    assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"

     existing_messages = request.messages
     existing_system_message = None

@@ -368,7 +369,7 @@ def augment_messages_for_tools_llama_3_1(

     has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
     if has_custom_tools:
-        fmt = request.tool_prompt_format or ToolPromptFormat.json
+        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
         if fmt == ToolPromptFormat.json:
             tool_gen = JsonCustomToolGenerator()
         elif fmt == ToolPromptFormat.function_tag:

@@ -389,7 +390,7 @@ def augment_messages_for_tools_llama_3_1(
 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-    assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
+    assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"

     existing_messages = request.messages
     existing_system_message = None

@@ -419,19 +420,24 @@ def augment_messages_for_tools_llama_3_2(

     custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
     if custom_tools:
-        fmt = request.tool_prompt_format or ToolPromptFormat.python_list
+        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
             raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_prompt_format}"
+                f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
             )

         tool_gen = PythonListCustomToolGenerator()
-        tool_template = tool_gen.gen(custom_tools)
+
+        system_prompt = None
+        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
+            system_prompt = existing_system_message.content
+
+        tool_template = tool_gen.gen(custom_tools, system_prompt)

         sys_content += tool_template.render()
         sys_content += "\n"

-    if existing_system_message:
+    if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.append:
         sys_content += interleaved_content_as_str(
             existing_system_message.content, sep="\n"
         )
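
The practical effect of SystemMessageBehavior for the Llama 3.2 prompt path: with the default append, a user-supplied system message is added after the generated tool definitions; with replace, the user's system message drives the system prompt and, per the ToolConfig docstring above, may embed '{{function_definitions}}' to mark where the tool definitions are inserted. A rough sketch of a replace-style request (MODEL, the message text and custom_tool_definitions are illustrative placeholders):

request = ChatCompletionRequest(
    model=MODEL,  # placeholder model identifier
    messages=[
        SystemMessage(
            content="You are a terse assistant. Available tools:\n{{function_definitions}}"
        ),
        UserMessage(content="What time is it in Berlin?"),
    ],
    tools=custom_tool_definitions,  # assumed list of ToolDefinition with string tool names
    tool_config=ToolConfig(
        tool_prompt_format=ToolPromptFormat.python_list,
        system_message_behavior=SystemMessageBehavior.replace,
    ),
)
messages = chat_completion_request_to_messages(request, MODEL)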