Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 19:04:19 +00:00)
Support sys_prompt behavior in inference (#937)
# What does this PR do?

The current default system prompt for Llama 3.2 tends to over-index on tool calling and does not work well when the prompt does not require tool calling. This PR adds an option to override the default system prompt, and organizes tool-related configs into a new config object.

- [ ] Addresses issue (#issue)

## Test Plan

`python -m unittest llama_stack.providers.tests.inference.test_prompt_adapter`

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.

---

Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/937).

* #938
* __->__ #937
Commit c9ab72fa82 (parent 62cd3c391e)
25 changed files with 308 additions and 48 deletions
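Before the file-by-file changes, here is a minimal sketch of what the new surface looks like from a caller's point of view. This is a hedged example: `inference` stands in for any implementation of the `Inference` protocol, and the model id and messages are placeholders, not part of this diff.

```python
from llama_stack.apis.inference import (
    Inference,
    SystemMessage,
    SystemMessageBehavior,
    ToolConfig,
    UserMessage,
)


async def demo(inference: Inference):
    # New in this PR: tool-related knobs live on one ToolConfig object, and the
    # default system prompt can be appended to (the default) or replaced.
    return await inference.chat_completion(
        model_id="meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
        messages=[
            SystemMessage(content="You are a terse assistant."),
            UserMessage(content="Hello!"),
        ],
        tool_config=ToolConfig(system_message_behavior=SystemMessageBehavior.append),
    )
```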
OpenAPI spec (JSON):

@@ -3161,6 +3161,43 @@
         "job_uuid"
       ]
     },
+    "ToolConfig": {
+      "type": "object",
+      "properties": {
+        "tool_choice": {
+          "type": "string",
+          "enum": [
+            "auto",
+            "required"
+          ],
+          "default": "auto",
+          "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto."
+        },
+        "tool_prompt_format": {
+          "type": "string",
+          "enum": [
+            "json",
+            "function_tag",
+            "python_list"
+          ],
+          "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
+        },
+        "system_message_behavior": {
+          "type": "string",
+          "enum": [
+            "append",
+            "replace"
+          ],
+          "default": "append",
+          "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted."
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "system_message_behavior"
+      ],
+      "title": "Configuration for tool use."
+    },
     "ChatCompletionRequest": {
       "type": "object",
       "properties": {

@@ -3192,7 +3229,7 @@
            "auto",
            "required"
          ],
-         "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto."
+         "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead."
        },
        "tool_prompt_format": {
          "type": "string",

@@ -3201,7 +3238,7 @@
            "function_tag",
            "python_list"
          ],
-         "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
+         "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. .. deprecated:: Use tool_config instead."
        },
        "response_format": {
          "$ref": "#/components/schemas/ResponseFormat",

@@ -3222,6 +3259,10 @@
          },
          "additionalProperties": false,
          "description": "(Optional) If specified, log probabilities for each token position will be returned."
+       },
+       "tool_config": {
+         "$ref": "#/components/schemas/ToolConfig",
+         "description": "(Optional) Configuration for tool use."
        }
      },
      "additionalProperties": false,
OpenAPI spec (YAML):

@@ -1956,6 +1956,46 @@ components:
      additionalProperties: false
      required:
        - job_uuid
+    ToolConfig:
+      type: object
+      properties:
+        tool_choice:
+          type: string
+          enum:
+            - auto
+            - required
+          default: auto
+          description: >-
+            (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+        tool_prompt_format:
+          type: string
+          enum:
+            - json
+            - function_tag
+            - python_list
+          description: >-
+            (Optional) Instructs the model how to format tool calls. By default, Llama
+            Stack will attempt to use a format that is best adapted to the model.
+            - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
+            - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
+            tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
+            syntax -- a list of function calls.
+        system_message_behavior:
+          type: string
+          enum:
+            - append
+            - replace
+          default: append
+          description: >-
+            (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
+            Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
+            Replaces the default system prompt with the provided system message. The
+            system message can include the string '{{function_definitions}}' to indicate
+            where the function definitions should be inserted.
+      additionalProperties: false
+      required:
+        - system_message_behavior
+      title: Configuration for tool use.
     ChatCompletionRequest:
       type: object
       properties:

@@ -1986,6 +2026,7 @@ components:
            - required
          description: >-
            (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+           .. deprecated:: Use tool_config instead.
        tool_prompt_format:
          type: string
          enum:

@@ -1998,7 +2039,7 @@ components:
            - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
            - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
            tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
-           syntax -- a list of function calls.
+           syntax -- a list of function calls. .. deprecated:: Use tool_config instead.
        response_format:
          $ref: '#/components/schemas/ResponseFormat'
          description: >-

@@ -2024,6 +2065,9 @@ components:
          description: >-
            (Optional) If specified, log probabilities for each token position will
            be returned.
+       tool_config:
+         $ref: '#/components/schemas/ToolConfig'
+         description: (Optional) Configuration for tool use.
      additionalProperties: false
      required:
        - model_id
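For reference, a chat-completion request body exercising the new `tool_config` field might look like the following. This is a hedged sketch built from the schema above: the field names follow the spec, while the model id and message contents are placeholders, not taken from this diff.

```python
import json

# Wire-level request body using the new "tool_config" object from the schema above.
request_body = {
    "model_id": "meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
    "messages": [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "What's the weather in Tokyo?"},
    ],
    "tool_config": {
        "tool_choice": "auto",                # enum: auto | required
        "tool_prompt_format": "python_list",  # enum: json | function_tag | python_list
        "system_message_behavior": "append",  # enum: append | replace
    },
}

print(json.dumps(request_body, indent=2))
```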
@@ -308,14 +308,49 @@ class CompletionResponseStreamChunk(BaseModel):
     logprobs: Optional[List[TokenLogProbs]] = None
 
 
+class SystemMessageBehavior(Enum):
+    """Config for how to override the default system prompt.
+
+    :cvar append: Appends the provided system message to the default system prompt:
+        https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
+    :cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
+        '{{function_definitions}}' to indicate where the function definitions should be inserted.
+    """
+
+    append = "append"
+    replace = "replace"
+
+
+@json_schema_type
+class ToolConfig(BaseModel):
+    """Configuration for tool use.
+
+    :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+    :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
+        - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
+        - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
+        - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+    :param system_message_behavior: (Optional) Config for how to override the default system prompt.
+        - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
+        - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
+        '{{function_definitions}}' to indicate where the function definitions should be inserted.
+    """
+
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    system_message_behavior: SystemMessageBehavior = Field(default=SystemMessageBehavior.append)
+
+
 # This is an internally used class
+@json_schema_type
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
     sampling_params: Optional[SamplingParams] = SamplingParams()
 
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
+
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None

@@ -404,6 +439,7 @@ class Inference(Protocol):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         """Generate a chat completion for the given messages using the specified model.
 

@@ -412,15 +448,20 @@ class Inference(Protocol):
         :param sampling_params: Parameters to control the sampling strategy
         :param tools: (Optional) List of tool definitions available to the model
         :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+            .. deprecated::
+               Use tool_config instead.
         :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
             - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
             - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
             - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+            .. deprecated::
+               Use tool_config instead.
         :param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
             - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
             - `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
         :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :param tool_config: (Optional) Configuration for tool use.
         :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
             If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
         """
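The docstrings above define the contract for the new types. A hedged sketch of a `replace`-style configuration follows; the prompt text is illustrative, and the '{{function_definitions}}' placeholder is taken from the ToolConfig docstring above.

```python
from llama_stack.apis.inference import (
    SystemMessage,
    SystemMessageBehavior,
    ToolChoice,
    ToolConfig,
    ToolPromptFormat,
)

# Replace the built-in system prompt entirely with the caller's message.
# Per the ToolConfig docstring, the message may embed '{{function_definitions}}'
# to control where the tool definitions are rendered.
system_message = SystemMessage(
    content="You are a terse pirate. {{function_definitions}}"
)

tool_config = ToolConfig(
    tool_choice=ToolChoice.auto,
    tool_prompt_format=ToolPromptFormat.python_list,
    system_message_behavior=SystemMessageBehavior.replace,
)
```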
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -132,12 +133,23 @@ class InferenceRouter(Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        if tool_config:
+            if tool_choice != tool_config.tool_choice:
+                raise ValueError("tool_choice and tool_config.tool_choice must match")
+            if tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+        else:
+            tool_config = ToolConfig(
+                tool_choice=tool_choice,
+                tool_prompt_format=tool_prompt_format,
+            )
         params = dict(
             model_id=model_id,
             messages=messages,

@@ -148,6 +160,7 @@ class InferenceRouter(Inference):
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
         if stream:
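The router keeps the deprecated arguments working: when only the old `tool_choice` / `tool_prompt_format` are given they are folded into a `ToolConfig`, and when both styles are given they must agree. Below is a hedged restatement of that rule as a standalone function; `normalize_tool_config` is not a real helper in the codebase, just an illustration of the merge logic shown in the hunk above.

```python
from typing import Optional

from llama_stack.apis.inference import ToolChoice, ToolConfig, ToolPromptFormat


def normalize_tool_config(
    tool_choice: Optional[ToolChoice],
    tool_prompt_format: Optional[ToolPromptFormat],
    tool_config: Optional[ToolConfig],
) -> ToolConfig:
    """Mirror of the router's merge rule: an explicit tool_config wins, but the
    deprecated arguments must not contradict it; otherwise build one from them."""
    if tool_config:
        if tool_choice != tool_config.tool_choice:
            raise ValueError("tool_choice and tool_config.tool_choice must match")
        if tool_prompt_format != tool_config.tool_prompt_format:
            raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
        return tool_config
    return ToolConfig(tool_choice=tool_choice, tool_prompt_format=tool_prompt_format)
```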
@@ -400,7 +400,7 @@ class Llama:
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 request.messages,
-                request.tool_prompt_format,
+                request.tool_config.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
             temperature=temperature,
@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     TokenLogProbs,
     ToolChoice,
+    ToolConfig,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate

@@ -252,6 +253,7 @@ class MetaReferenceInferenceImpl(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

@@ -262,11 +264,10 @@ class MetaReferenceInferenceImpl(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
         self.check_model(request)
 
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolDefinition,
     ToolPromptFormat,
+    ToolConfig,
 )
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (

@@ -71,5 +72,6 @@ class SentenceTransformersInferenceImpl(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")
@@ -30,6 +30,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -159,6 +160,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
         assert self.engine is not None
 

@@ -167,10 +169,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         log.info("Sampling params: %s", sampling_params)
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -102,6 +103,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -109,11 +111,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -128,6 +129,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -140,6 +142,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
@@ -89,16 +89,16 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
@@ -25,6 +25,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -204,6 +205,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -211,11 +213,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
@@ -99,6 +99,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         model_id = self.get_provider_model_id(model_id)
         if model_id == "llama-3.2-3b-preview":

@@ -115,10 +116,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
                 sampling_params=sampling_params,
                 response_format=response_format,
                 tools=tools,
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
                 stream=stream,
                 logprobs=logprobs,
+                tool_config=tool_config,
             )
         )
 
@@ -79,7 +79,7 @@ def convert_chat_completion_request(
         # so we exclude it for now
         warnings.warn("repetition_penalty is not supported")
 
-    if request.tool_prompt_format != ToolPromptFormat.json:
+    if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
         warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
 
     sampling_options = get_sampling_strategy_options(request.sampling_params)

@@ -93,7 +93,7 @@ def convert_chat_completion_request(
         temperature=sampling_options.get("temperature", 1.0),
         top_p=sampling_options.get("top_p", 1.0),
         tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
-        tool_choice=request.tool_choice.value if request.tool_choice else None,
+        tool_choice=(request.tool_config.tool_choice.value if request.tool_config.tool_choice else None),
     )
 
@@ -171,6 +171,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         if tool_prompt_format:
             warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")

@@ -184,10 +185,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
                 sampling_params=sampling_params,
                 response_format=response_format,
                 tools=tools,
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
                 stream=stream,
                 logprobs=logprobs,
+                tool_config=tool_config,
             ),
             n=1,
         )
@@ -282,9 +282,9 @@ async def convert_chat_completion_request(
 
     if request.tools:
         payload.update(tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools])
-        if request.tool_choice:
+        if request.tool_config.tool_choice:
             payload.update(
-                tool_choice=request.tool_choice.value
+                tool_choice=request.tool_config.tool_choice.value
             )  # we cannot include tool_choice w/o tools, server will complain
 
     if request.logprobs:
@@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -224,6 +225,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -231,11 +233,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
             response_format=response_format,
+            tool_config=tool_config,
         )
         if stream:
             return self._stream_chat_completion(request)
@@ -83,10 +83,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
@@ -125,10 +125,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
         request_sambanova = await self.convert_chat_completion_request(request)
 
@@ -26,6 +26,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -205,6 +206,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -212,11 +214,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -194,6 +195,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -201,11 +203,10 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
 
         if stream:
@@ -27,6 +27,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )

@@ -119,6 +120,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(

@@ -126,11 +128,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
             response_format=response_format,
+            tool_config=tool_config,
         )
         if stream:
             return self._stream_chat_completion(request, self.client)
@@ -179,7 +179,7 @@ class TestConvertChatCompletionRequest:
 
     def test_includes_tool_choice(self):
         request = self._dummy_chat_completion_request()
-        request.tool_choice = ToolChoice.required
+        request.tool_config.tool_choice = ToolChoice.required
 
         converted = convert_chat_completion_request(request)
 
@@ -13,12 +13,18 @@ from llama_models.llama3.api.datatypes import (
     ToolPromptFormat,
 )
 
-from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    SystemMessage,
+    ToolConfig,
+    UserMessage,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )
 
 MODEL = "Llama3.1-8B-Instruct"
+MODEL3_2 = "Llama3.2-3B-Instruct"
 
 
 class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):

@@ -73,7 +79,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                     },
                 )
             ],
-            tool_prompt_format=ToolPromptFormat.json,
+            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
         )
         messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)

@@ -132,3 +138,101 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         self.assertTrue(messages[0].content.endswith(system_prompt))
 
         self.assertEqual(messages[-1].content, content)
+
+    async def test_repalce_system_message_behavior_builtin_tools(self):
+        content = "Hello !"
+        system_prompt = "You are a pirate"
+        request = ChatCompletionRequest(
+            model=MODEL,
+            messages=[
+                SystemMessage(content=system_prompt),
+                UserMessage(content=content),
+            ],
+            tools=[
+                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
+            ],
+            tool_config=ToolConfig(
+                tool_choice="auto",
+                tool_prompt_format="python_list",
+                system_message_behavior="replace",
+            ),
+        )
+        messages = chat_completion_request_to_messages(request, MODEL3_2)
+        self.assertEqual(len(messages), 2, messages)
+        self.assertTrue(messages[0].content.endswith(system_prompt))
+        self.assertIn("Environment: ipython", messages[0].content)
+        self.assertEqual(messages[-1].content, content)
+
+    async def test_repalce_system_message_behavior_custom_tools(self):
+        content = "Hello !"
+        system_prompt = "You are a pirate"
+        request = ChatCompletionRequest(
+            model=MODEL,
+            messages=[
+                SystemMessage(content=system_prompt),
+                UserMessage(content=content),
+            ],
+            tools=[
+                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
+                ToolDefinition(
+                    tool_name="custom1",
+                    description="custom1 tool",
+                    parameters={
+                        "param1": ToolParamDefinition(
+                            param_type="str",
+                            description="param1 description",
+                            required=True,
+                        ),
+                    },
+                ),
+            ],
+            tool_config=ToolConfig(
+                tool_choice="auto",
+                tool_prompt_format="python_list",
+                system_message_behavior="replace",
+            ),
+        )
+        messages = chat_completion_request_to_messages(request, MODEL3_2)
+
+        self.assertEqual(len(messages), 2, messages)
+        self.assertTrue(messages[0].content.endswith(system_prompt))
+        self.assertIn("Environment: ipython", messages[0].content)
+        self.assertEqual(messages[-1].content, content)
+
+    async def test_replace_system_message_behavior_custom_tools_with_template(self):
+        content = "Hello !"
+        system_prompt = "You are a pirate {{ function_description }}"
+        request = ChatCompletionRequest(
+            model=MODEL,
+            messages=[
+                SystemMessage(content=system_prompt),
+                UserMessage(content=content),
+            ],
+            tools=[
+                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
+                ToolDefinition(
+                    tool_name="custom1",
+                    description="custom1 tool",
+                    parameters={
+                        "param1": ToolParamDefinition(
+                            param_type="str",
+                            description="param1 description",
+                            required=True,
+                        ),
+                    },
+                ),
+            ],
+            tool_config=ToolConfig(
+                tool_choice="auto",
+                tool_prompt_format="python_list",
+                system_message_behavior="replace",
+            ),
+        )
+        messages = chat_completion_request_to_messages(request, MODEL3_2)
+
+        self.assertEqual(len(messages), 2, messages)
+        self.assertIn("Environment: ipython", messages[0].content)
+        self.assertIn("You are a pirate", messages[0].content)
+        # function description is present in the system prompt
+        self.assertIn('"name": "custom1"', messages[0].content)
+        self.assertEqual(messages[-1].content, content)
@@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
     SystemMessage,
     ToolChoice,
     UserMessage,
+    SystemMessageBehavior,
 )
 from llama_stack.providers.utils.inference import supported_inference_models
 

@@ -309,7 +310,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
 def augment_messages_for_tools_llama_3_1(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-    assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
+    assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
 
     existing_messages = request.messages
     existing_system_message = None

@@ -354,7 +355,7 @@ def augment_messages_for_tools_llama_3_1(
 
     has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
     if has_custom_tools:
-        fmt = request.tool_prompt_format or ToolPromptFormat.json
+        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
         if fmt == ToolPromptFormat.json:
             tool_gen = JsonCustomToolGenerator()
         elif fmt == ToolPromptFormat.function_tag:

@@ -375,7 +376,7 @@ def augment_messages_for_tools_llama_3_2(
 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-    assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
+    assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
 
     existing_messages = request.messages
     existing_system_message = None

@@ -403,20 +404,25 @@ def augment_messages_for_tools_llama_3_2(
 
     custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
     if custom_tools:
-        fmt = request.tool_prompt_format or ToolPromptFormat.python_list
+        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(f"Non supported ToolPromptFormat {request.tool_prompt_format}")
+            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
 
-        tool_gen = PythonListCustomToolGenerator()
-        tool_template = tool_gen.gen(custom_tools)
+        system_prompt = None
+        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
+            system_prompt = existing_system_message.content
+
+        tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
+
         sys_content += tool_template.render()
         sys_content += "\n"
 
-    if existing_system_message:
+    if existing_system_message and (
+        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
+    ):
         sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
 
-    messages.append(SystemMessage(content=sys_content))
+    messages.append(SystemMessage(content=sys_content.strip("\n")))
 
     # Add back existing messages from the request
     messages += existing_messages
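The net effect for Llama 3.2: under `append` (the default), or when there are no custom tools, the caller's system message is appended after the generated tool prompt; under `replace` with custom tools, the caller's message becomes the template handed to the tool-prompt generator. A condensed, illustrative restatement of that selection rule follows — `build_sys_content` and `render_tools` are not real helpers in the codebase, and the builtin-tool preamble handled earlier in the function is omitted.

```python
from llama_stack.apis.inference import SystemMessageBehavior, ToolConfig


def build_sys_content(existing_system_message, custom_tools, tool_config: ToolConfig, render_tools):
    """Condensed restatement of the system-prompt assembly in
    augment_messages_for_tools_llama_3_2 (illustration only)."""
    sys_content = ""
    if custom_tools:
        # Under `replace`, the caller's system message is used as the template
        # passed to the custom-tool prompt generator.
        template = (
            existing_system_message.content
            if existing_system_message
            and tool_config.system_message_behavior == SystemMessageBehavior.replace
            else None
        )
        # render_tools stands in for PythonListCustomToolGenerator().gen(...).render()
        sys_content += render_tools(custom_tools, template) + "\n"
    if existing_system_message and (
        tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
    ):
        # Under `append` (or with no custom tools), the caller's system message
        # is appended after whatever tool prompt was generated.
        sys_content += existing_system_message.content
    return sys_content.strip("\n")
```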