mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 05:38:38 +00:00
re-work tool definitions, fix FastAPI issues, fix tool regressions
This commit is contained in:
parent
8d14d4228b
commit
8efe614719
11 changed files with 144 additions and 104 deletions
|
@ -10,8 +10,6 @@ import uuid
|
|||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||
|
@ -20,7 +18,10 @@ from llama_toolchain.memory.api import * # noqa: F403
|
|||
from llama_toolchain.safety.api import * # noqa: F403
|
||||
|
||||
from llama_toolchain.tools.base import BaseTool
|
||||
from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
|
||||
from llama_toolchain.tools.builtin import (
|
||||
interpret_content_as_attachment,
|
||||
SingleMessageBuiltinTool,
|
||||
)
|
||||
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
||||
|
@ -192,7 +193,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
turn_id, session, input_messages, attachments, sampling_params, stream
|
||||
session, turn_id, input_messages, attachments, sampling_params, stream
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
@ -358,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
req = ChatCompletionRequest(
|
||||
model=self.agent_config.model,
|
||||
messages=input_messages,
|
||||
tools=self.agent_config.tools,
|
||||
tools=self._get_tools(),
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
|
@ -555,17 +556,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield False
|
||||
return
|
||||
|
||||
if isinstance(result_message.content, Attachment):
|
||||
if out_attachment := interpret_content_as_attachment(
|
||||
result_message.content
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
output_attachments.append(result_message.content)
|
||||
elif isinstance(result_message.content, list) or isinstance(
|
||||
result_message.content, tuple
|
||||
):
|
||||
for c in result_message.content:
|
||||
if isinstance(c, Attachment):
|
||||
output_attachments.append(c)
|
||||
output_attachments.append(out_attachment)
|
||||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
|
@ -667,6 +664,27 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
]
|
||||
|
||||
def _get_tools(self) -> List[ToolDefinition]:
|
||||
ret = []
|
||||
for t in self.agent_config.tools:
|
||||
if isinstance(t, BraveSearchToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
|
||||
elif isinstance(t, WolframAlphaToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
|
||||
elif isinstance(t, PhotogenToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
|
||||
elif isinstance(t, CodeInterpreterToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
|
||||
elif isinstance(t, FunctionCallToolDefinition):
|
||||
ret.append(
|
||||
ToolDefinition(
|
||||
tool_name=t.function_name,
|
||||
description=t.description,
|
||||
parameters=t.parameters,
|
||||
)
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
def attachment_message(urls: List[URL]) -> ToolResponseMessage:
|
||||
content = []
|
||||
|
|
|
@ -12,7 +12,6 @@ from typing import AsyncGenerator, Dict
|
|||
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_toolchain.inference.api import Inference
|
||||
from llama_toolchain.inference.api.datatypes import BuiltinTool
|
||||
from llama_toolchain.memory.api import Memory
|
||||
from llama_toolchain.safety.api import Safety
|
||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||
|
@ -42,6 +41,7 @@ async def get_provider_impl(
|
|||
impl = MetaReferenceAgenticSystemImpl(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.memory],
|
||||
deps[Api.safety],
|
||||
)
|
||||
await impl.initialize()
|
||||
|
@ -56,54 +56,55 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
|||
self,
|
||||
config: MetaReferenceImplConfig,
|
||||
inference_api: Inference,
|
||||
safety_api: Safety,
|
||||
memory_api: Memory,
|
||||
safety_api: Safety,
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.safety_api = safety_api
|
||||
self.memory_api = memory_api
|
||||
self.safety_api = safety_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def create_agentic_system(
|
||||
self,
|
||||
request: AgenticSystemCreateRequest,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgenticSystemCreateResponse:
|
||||
agent_id = str(uuid.uuid4())
|
||||
|
||||
builtin_tools = []
|
||||
cfg = request.agent_config
|
||||
for dfn in cfg.tools:
|
||||
if isinstance(dfn.tool_name, BuiltinTool):
|
||||
if dfn.tool_name == BuiltinTool.wolfram_alpha:
|
||||
key = self.config.wolfram_api_key
|
||||
if not key:
|
||||
raise ValueError("Wolfram API key not defined in config")
|
||||
tool = WolframAlphaTool(key)
|
||||
elif dfn.tool_name == BuiltinTool.brave_search:
|
||||
key = self.config.brave_search_api_key
|
||||
if not key:
|
||||
raise ValueError("Brave API key not defined in config")
|
||||
tool = BraveSearchTool(key)
|
||||
elif dfn.tool_name == BuiltinTool.code_interpreter:
|
||||
tool = CodeInterpreterTool()
|
||||
elif dfn.tool_name == BuiltinTool.photogen:
|
||||
tool = PhotogenTool(
|
||||
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown builtin tool: {dfn.tool_name}")
|
||||
|
||||
builtin_tools.append(
|
||||
with_safety(
|
||||
tool, self.safety_api, dfn.input_shields, dfn.output_shields
|
||||
)
|
||||
for tool_defn in agent_config.tools:
|
||||
if isinstance(tool_defn, WolframAlphaToolDefinition):
|
||||
key = self.config.wolfram_api_key
|
||||
if not key:
|
||||
raise ValueError("Wolfram API key not defined in config")
|
||||
tool = WolframAlphaTool(key)
|
||||
elif isinstance(tool_defn, BraveSearchToolDefinition):
|
||||
key = self.config.brave_search_api_key
|
||||
if not key:
|
||||
raise ValueError("Brave API key not defined in config")
|
||||
tool = BraveSearchTool(key)
|
||||
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
||||
tool = CodeInterpreterTool()
|
||||
elif isinstance(tool_defn, PhotogenToolDefinition):
|
||||
tool = PhotogenTool(
|
||||
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
builtin_tools.append(
|
||||
with_safety(
|
||||
tool,
|
||||
self.safety_api,
|
||||
tool_defn.input_shields,
|
||||
tool_defn.output_shields,
|
||||
)
|
||||
)
|
||||
|
||||
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
|
||||
agent_config=cfg,
|
||||
agent_config=agent_config,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
memory_api=self.memory_api,
|
||||
|
@ -116,13 +117,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
|||
|
||||
async def create_agentic_system_session(
|
||||
self,
|
||||
request: AgenticSystemSessionCreateRequest,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgenticSystemSessionCreateResponse:
|
||||
agent_id = request.agent_id
|
||||
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||
|
||||
session = agent.create_session(request.session_name)
|
||||
session = agent.create_session(session_name)
|
||||
return AgenticSystemSessionCreateResponse(
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue