add tools to chat completion request

This commit is contained in:
Hardik Shah 2024-08-21 17:48:48 -07:00
parent 863bb915e1
commit f3f7af7b8a
26 changed files with 558 additions and 226 deletions

View file

@ -56,10 +56,10 @@ from llama_toolchain.safety.api.datatypes import (
)
from llama_toolchain.agentic_system.api.endpoints import * # noqa
from llama_toolchain.tools.base import BaseTool
from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
from .safety import SafetyException, ShieldRunnerMixin
from .system_prompt import get_agentic_prefix_messages
from .tools.base import BaseTool
from .tools.builtin import SingleMessageBuiltinTool
class AgentInstance(ShieldRunnerMixin):
@ -85,18 +85,6 @@ class AgentInstance(ShieldRunnerMixin):
self.inference_api = inference_api
self.safety_api = safety_api
if prefix_messages is not None and len(prefix_messages) > 0:
self.prefix_messages = prefix_messages
else:
self.prefix_messages = get_agentic_prefix_messages(
builtin_tools,
custom_tool_definitions,
tool_prompt_format,
)
for m in self.prefix_messages:
print(m.content)
self.max_infer_iters = max_infer_iters
self.tools_dict = {t.get_name(): t for t in builtin_tools}
@ -344,7 +332,7 @@ class AgentInstance(ShieldRunnerMixin):
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
input_messages = preprocess_dialog(input_messages, self.prefix_messages)
input_messages = preprocess_dialog(input_messages)
attachments = []
@ -373,7 +361,8 @@ class AgentInstance(ShieldRunnerMixin):
req = ChatCompletionRequest(
model=self.model,
messages=input_messages,
available_tools=self.instance_config.available_tools,
tools=self.instance_config.available_tools,
tool_prompt_format=self.instance_config.tool_prompt_format,
stream=True,
sampling_params=SamplingParams(
temperature=temperature,
@ -601,14 +590,12 @@ def attachment_message(url: URL) -> ToolResponseMessage:
)
def preprocess_dialog(
messages: List[Message], prefix_messages: List[Message]
) -> List[Message]:
def preprocess_dialog(messages: List[Message]) -> List[Message]:
"""
Preprocesses the dialog by removing the system message and
adding the system message to the beginning of the dialog.
"""
ret = prefix_messages.copy()
ret = []
for m in messages:
if m.role == Role.system.value: