mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
sys_prompt
This commit is contained in:
parent
15dcc4ea5e
commit
6d035b3152
26 changed files with 147 additions and 48 deletions
|
@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
|
|||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
|
@ -155,8 +156,13 @@ class AgentConfigCommon(BaseModel):
|
|||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
|
||||
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
||||
tool_choice: Optional[ToolChoice] = Field(
|
||||
default=ToolChoice.auto, deprecated="use tool_config instead"
|
||||
)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=None, deprecated="use tool_config instead"
|
||||
)
|
||||
tool_config: Optional[ToolConfig] = Field(default=None)
|
||||
|
||||
max_infer_iters: int = 10
|
||||
|
||||
|
@ -280,7 +286,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
toolgroups: Optional[List[AgentToolGroup]] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
|
||||
tool_config: Optional[ToolConfig] = None
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStreamChunk(BaseModel):
|
||||
|
@ -327,6 +333,7 @@ class Agents(Protocol):
|
|||
stream: Optional[bool] = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
|
||||
@webmethod(
|
||||
|
|
|
@ -310,14 +310,48 @@ class CompletionResponseStreamChunk(BaseModel):
|
|||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SystemMessageBehavior(Enum):
|
||||
"""Config for how to override the default system prompt.
|
||||
|
||||
:cvar append: Appends the provided system message to the default system prompt:
|
||||
https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
|
||||
:cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
|
||||
'{{function_definitions}}' to indicate where the function definitions should be inserted.
|
||||
"""
|
||||
|
||||
append = "append"
|
||||
replace = "replace"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolConfig(BaseModel):
|
||||
"""Configuration for tool use.
|
||||
|
||||
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
|
||||
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
|
||||
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
||||
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
|
||||
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
|
||||
"""
|
||||
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
||||
system_message_behavior: SystemMessageBehavior = Field(
|
||||
default=SystemMessageBehavior.append
|
||||
)
|
||||
|
||||
|
||||
# This is an internally used class
|
||||
@json_schema_type
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Message]
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
|
||||
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
||||
tool_config: Optional[ToolConfig] = None
|
||||
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
@ -406,6 +440,7 @@ class Inference(Protocol):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
|
@ -416,15 +451,20 @@ class Inference(Protocol):
|
|||
:param sampling_params: Parameters to control the sampling strategy
|
||||
:param tools: (Optional) List of tool definitions available to the model
|
||||
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
|
||||
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
||||
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
|
||||
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
|
||||
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
|
||||
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
|
||||
"""
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -138,6 +139,7 @@ class InferenceRouter(Inference):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
|
@ -146,6 +148,20 @@ class InferenceRouter(Inference):
|
|||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
if tool_config:
|
||||
if tool_choice != tool_config.tool_choice:
|
||||
raise ValueError(
|
||||
"tool_choice and tool_config.tool_choice must match"
|
||||
)
|
||||
if tool_prompt_format != tool_config.tool_prompt_format:
|
||||
raise ValueError(
|
||||
"tool_prompt_format and tool_config.tool_prompt_format must match"
|
||||
)
|
||||
else:
|
||||
tool_config = ToolConfig(
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
messages=messages,
|
||||
|
@ -156,6 +172,7 @@ class InferenceRouter(Inference):
|
|||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
if stream:
|
||||
|
|
|
@ -515,10 +515,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for tool in tool_defs.values()
|
||||
if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
|
||||
],
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
response_format=self.agent_config.response_format,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
tool_config=self.agent_config.tool_config,
|
||||
):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
|
|
|
@ -25,7 +25,12 @@ from llama_stack.apis.agents import (
|
|||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
|
@ -146,6 +151,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
@ -154,6 +160,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
stream=True,
|
||||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
|
|
@ -400,7 +400,7 @@ class Llama:
|
|||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
request.messages,
|
||||
request.tool_prompt_format,
|
||||
request.tool_config.tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
|
|
|
@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
@ -270,6 +271,7 @@ class MetaReferenceInferenceImpl(
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
@ -280,11 +282,10 @@ class MetaReferenceInferenceImpl(
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
self.check_model(request)
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
|
|||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
|
@ -71,5 +72,6 @@ class SentenceTransformersInferenceImpl(
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support chat completion")
|
||||
|
|
|
@ -30,6 +30,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -159,6 +160,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
assert self.engine is not None
|
||||
|
||||
|
@ -167,10 +169,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
log.info("Sampling params: %s", sampling_params)
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -102,6 +103,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
|
@ -111,11 +113,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -130,6 +131,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
|
@ -142,6 +144,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
|
@ -89,16 +89,16 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
|
|
|
@ -25,6 +25,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -208,6 +209,7 @@ class FireworksInferenceAdapter(
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
|
@ -215,11 +217,10 @@ class FireworksInferenceAdapter(
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
|
@ -99,6 +99,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
|
@ -117,10 +118,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
|
|||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ def convert_chat_completion_request(
|
|||
# so we exclude it for now
|
||||
warnings.warn("repetition_penalty is not supported")
|
||||
|
||||
if request.tool_prompt_format != ToolPromptFormat.json:
|
||||
if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
|
||||
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
|
||||
|
||||
sampling_options = get_sampling_strategy_options(request.sampling_params)
|
||||
|
@ -93,7 +93,11 @@ def convert_chat_completion_request(
|
|||
temperature=sampling_options.get("temperature", 1.0),
|
||||
top_p=sampling_options.get("top_p", 1.0),
|
||||
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
|
||||
tool_choice=request.tool_choice.value if request.tool_choice else None,
|
||||
tool_choice=(
|
||||
request.tool_config.tool_choice.value
|
||||
if request.tool_config.tool_choice
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -178,6 +178,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
|
@ -193,10 +194,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
),
|
||||
n=1,
|
||||
)
|
||||
|
|
|
@ -253,9 +253,9 @@ def convert_chat_completion_request(
|
|||
payload.update(
|
||||
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
|
||||
)
|
||||
if request.tool_choice:
|
||||
if request.tool_config.tool_choice:
|
||||
payload.update(
|
||||
tool_choice=request.tool_choice.value
|
||||
tool_choice=request.tool_config.tool_choice.value
|
||||
) # we cannot include tool_choice w/o tools, server will complain
|
||||
|
||||
if request.logprobs:
|
||||
|
|
|
@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -224,6 +225,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
|
@ -231,11 +233,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
response_format=response_format,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
|
@ -322,6 +323,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
print(params)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
if "messages" in params:
|
||||
|
|
|
@ -85,10 +85,9 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
|
|
|
@ -125,10 +125,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
request_sambanova = await self.convert_chat_completion_request(request)
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -213,6 +214,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
|
@ -220,11 +222,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -198,6 +199,7 @@ class TogetherInferenceAdapter(
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
|
@ -205,11 +207,10 @@ class TogetherInferenceAdapter(
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
|
@ -27,6 +27,7 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -119,6 +120,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
|
@ -126,11 +128,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
response_format=response_format,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, self.client)
|
||||
|
|
|
@ -181,7 +181,7 @@ class TestConvertChatCompletionRequest:
|
|||
|
||||
def test_includes_tool_choice(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.tool_choice = ToolChoice.required
|
||||
request.tool_config.tool_choice = ToolChoice.required
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
|
|
|
@ -13,7 +13,12 @@ from llama_models.llama3.api.datatypes import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
SystemMessage,
|
||||
ToolConfig,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
@ -73,7 +78,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
|||
},
|
||||
)
|
||||
],
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
|
|
@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
|
|||
SystemMessage,
|
||||
ToolChoice,
|
||||
UserMessage,
|
||||
SystemMessageBehavior,
|
||||
)
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
@ -319,7 +320,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
|
|||
def augment_messages_for_tools_llama_3_1(
|
||||
request: ChatCompletionRequest,
|
||||
) -> List[Message]:
|
||||
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
|
||||
assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
|
||||
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
|
@ -368,7 +369,7 @@ def augment_messages_for_tools_llama_3_1(
|
|||
|
||||
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
|
||||
if has_custom_tools:
|
||||
fmt = request.tool_prompt_format or ToolPromptFormat.json
|
||||
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
|
||||
if fmt == ToolPromptFormat.json:
|
||||
tool_gen = JsonCustomToolGenerator()
|
||||
elif fmt == ToolPromptFormat.function_tag:
|
||||
|
@ -389,7 +390,7 @@ def augment_messages_for_tools_llama_3_1(
|
|||
def augment_messages_for_tools_llama_3_2(
|
||||
request: ChatCompletionRequest,
|
||||
) -> List[Message]:
|
||||
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
|
||||
assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
|
||||
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
|
@ -419,19 +420,24 @@ def augment_messages_for_tools_llama_3_2(
|
|||
|
||||
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
|
||||
if custom_tools:
|
||||
fmt = request.tool_prompt_format or ToolPromptFormat.python_list
|
||||
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
|
||||
if fmt != ToolPromptFormat.python_list:
|
||||
raise ValueError(
|
||||
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
|
||||
f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
|
||||
)
|
||||
|
||||
tool_gen = PythonListCustomToolGenerator()
|
||||
tool_template = tool_gen.gen(custom_tools)
|
||||
|
||||
system_prompt = None
|
||||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
||||
system_prompt = existing_system_message.content
|
||||
|
||||
tool_template = tool_gen.gen(custom_tools, system_prompt)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
||||
if existing_system_message:
|
||||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.append:
|
||||
sys_content += interleaved_content_as_str(
|
||||
existing_system_message.content, sep="\n"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue