mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
add tools to chat completion request
This commit is contained in:
parent
9777639a1c
commit
68855ed218
26 changed files with 558 additions and 226 deletions
|
@ -56,10 +56,10 @@ from llama_toolchain.safety.api.datatypes import (
|
|||
)
|
||||
from llama_toolchain.agentic_system.api.endpoints import * # noqa
|
||||
|
||||
from llama_toolchain.tools.base import BaseTool
|
||||
from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
|
||||
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
from .system_prompt import get_agentic_prefix_messages
|
||||
from .tools.base import BaseTool
|
||||
from .tools.builtin import SingleMessageBuiltinTool
|
||||
|
||||
|
||||
class AgentInstance(ShieldRunnerMixin):
|
||||
|
@ -85,18 +85,6 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
self.inference_api = inference_api
|
||||
self.safety_api = safety_api
|
||||
|
||||
if prefix_messages is not None and len(prefix_messages) > 0:
|
||||
self.prefix_messages = prefix_messages
|
||||
else:
|
||||
self.prefix_messages = get_agentic_prefix_messages(
|
||||
builtin_tools,
|
||||
custom_tool_definitions,
|
||||
tool_prompt_format,
|
||||
)
|
||||
|
||||
for m in self.prefix_messages:
|
||||
print(m.content)
|
||||
|
||||
self.max_infer_iters = max_infer_iters
|
||||
self.tools_dict = {t.get_name(): t for t in builtin_tools}
|
||||
|
||||
|
@ -344,7 +332,7 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
stream: bool = False,
|
||||
max_gen_len: Optional[int] = None,
|
||||
) -> AsyncGenerator:
|
||||
input_messages = preprocess_dialog(input_messages, self.prefix_messages)
|
||||
input_messages = preprocess_dialog(input_messages)
|
||||
|
||||
attachments = []
|
||||
|
||||
|
@ -373,7 +361,8 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
req = ChatCompletionRequest(
|
||||
model=self.model,
|
||||
messages=input_messages,
|
||||
available_tools=self.instance_config.available_tools,
|
||||
tools=self.instance_config.available_tools,
|
||||
tool_prompt_format=self.instance_config.tool_prompt_format,
|
||||
stream=True,
|
||||
sampling_params=SamplingParams(
|
||||
temperature=temperature,
|
||||
|
@ -601,14 +590,12 @@ def attachment_message(url: URL) -> ToolResponseMessage:
|
|||
)
|
||||
|
||||
|
||||
def preprocess_dialog(
|
||||
messages: List[Message], prefix_messages: List[Message]
|
||||
) -> List[Message]:
|
||||
def preprocess_dialog(messages: List[Message]) -> List[Message]:
|
||||
"""
|
||||
Preprocesses the dialog by removing the system message and
|
||||
adding the system message to the beginning of the dialog.
|
||||
"""
|
||||
ret = prefix_messages.copy()
|
||||
ret = []
|
||||
|
||||
for m in messages:
|
||||
if m.role == Role.system.value:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue