fix agentic_system utils

This commit is contained in:
Ashwin Bharambe 2024-08-24 22:56:43 -07:00
parent 8efe614719
commit 830252257b
2 changed files with 11 additions and 24 deletions

View file

@ -7,20 +7,15 @@
import uuid import uuid
from typing import Any, List, Optional from typing import Any, List, Optional
from llama_models.llama3.api.datatypes import ( from llama_models.llama3.api.datatypes import * # noqa: F403
BuiltinTool, from llama_toolchain.agentic_system.api import * # noqa: F403
Message, from llama_toolchain.memory.api import * # noqa: F403
SamplingParams, from llama_toolchain.safety.api import * # noqa: F403
ToolPromptFormat,
)
from llama_toolchain.agentic_system.api import AgentConfig, AgenticSystemToolDefinition
from llama_toolchain.agentic_system.client import AgenticSystemClient from llama_toolchain.agentic_system.client import AgenticSystemClient
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import ( from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import (
execute_with_custom_tools, execute_with_custom_tools,
) )
from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
# TODO: this should move back to the llama-agentic-system repo # TODO: this should move back to the llama-agentic-system repo
@ -69,18 +64,10 @@ async def get_agent_system_instance(
api = AgenticSystemClient(base_url=f"http://{host}:{port}") api = AgenticSystemClient(base_url=f"http://{host}:{port}")
tool_definitions = [ tool_definitions = [
AgenticSystemToolDefinition( BraveSearchToolDefinition(),
tool_name=BuiltinTool.brave_search, WolframAlphaToolDefinition(),
), PhotogenToolDefinition(),
AgenticSystemToolDefinition( CodeInterpreterToolDefinition(),
tool_name=BuiltinTool.wolfram_alpha,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.photogen,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.code_interpreter,
),
] + [t.get_tool_definition() for t in custom_tools] ] + [t.get_tool_definition() for t in custom_tools]
if not disable_safety: if not disable_safety:

View file

@ -54,9 +54,9 @@ class CustomTool:
} }
) )
def get_tool_definition(self) -> AgenticSystemToolDefinition: def get_tool_definition(self) -> FunctionCallToolDefinition:
return AgenticSystemToolDefinition( return FunctionCallToolDefinition(
tool_name=self.get_name(), function_name=self.get_name(),
description=self.get_description(), description=self.get_description(),
parameters=self.get_params_definition(), parameters=self.get_params_definition(),
) )