re-work tool definitions, fix FastAPI issues, fix tool regressions

This commit is contained in:
Ashwin Bharambe 2024-08-24 22:07:06 -07:00
parent 8d14d4228b
commit 8efe614719
11 changed files with 144 additions and 104 deletions

View file

@ -12,7 +12,6 @@ from typing import AsyncGenerator, Dict
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import Inference
from llama_toolchain.inference.api.datatypes import BuiltinTool
from llama_toolchain.memory.api import Memory
from llama_toolchain.safety.api import Safety
from llama_toolchain.agentic_system.api import * # noqa: F403
@ -42,6 +41,7 @@ async def get_provider_impl(
impl = MetaReferenceAgenticSystemImpl(
config,
deps[Api.inference],
deps[Api.memory],
deps[Api.safety],
)
await impl.initialize()
@ -56,54 +56,55 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
self,
config: MetaReferenceImplConfig,
inference_api: Inference,
safety_api: Safety,
memory_api: Memory,
safety_api: Safety,
):
self.config = config
self.inference_api = inference_api
self.safety_api = safety_api
self.memory_api = memory_api
self.safety_api = safety_api
async def initialize(self) -> None:
pass
async def create_agentic_system(
self,
request: AgenticSystemCreateRequest,
agent_config: AgentConfig,
) -> AgenticSystemCreateResponse:
agent_id = str(uuid.uuid4())
builtin_tools = []
cfg = request.agent_config
for dfn in cfg.tools:
if isinstance(dfn.tool_name, BuiltinTool):
if dfn.tool_name == BuiltinTool.wolfram_alpha:
key = self.config.wolfram_api_key
if not key:
raise ValueError("Wolfram API key not defined in config")
tool = WolframAlphaTool(key)
elif dfn.tool_name == BuiltinTool.brave_search:
key = self.config.brave_search_api_key
if not key:
raise ValueError("Brave API key not defined in config")
tool = BraveSearchTool(key)
elif dfn.tool_name == BuiltinTool.code_interpreter:
tool = CodeInterpreterTool()
elif dfn.tool_name == BuiltinTool.photogen:
tool = PhotogenTool(
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
)
else:
raise ValueError(f"Unknown builtin tool: {dfn.tool_name}")
builtin_tools.append(
with_safety(
tool, self.safety_api, dfn.input_shields, dfn.output_shields
)
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
key = self.config.wolfram_api_key
if not key:
raise ValueError("Wolfram API key not defined in config")
tool = WolframAlphaTool(key)
elif isinstance(tool_defn, BraveSearchToolDefinition):
key = self.config.brave_search_api_key
if not key:
raise ValueError("Brave API key not defined in config")
tool = BraveSearchTool(key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
)
else:
continue
builtin_tools.append(
with_safety(
tool,
self.safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
agent_config=cfg,
agent_config=agent_config,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,
@ -116,13 +117,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
async def create_agentic_system_session(
self,
request: AgenticSystemSessionCreateRequest,
agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse:
agent_id = request.agent_id
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
session = agent.create_session(request.session_name)
session = agent.create_session(session_name)
return AgenticSystemSessionCreateResponse(
session_id=session.session_id,
)