use customtool's get_tool_definition to remove duplication in agentconfig

This commit is contained in:
Jeff Tang 2024-12-08 10:54:32 -08:00
parent a29013112f
commit ee65d95bb5
2 changed files with 12 additions and 37 deletions

View file

@ -286,6 +286,9 @@
" input_shields = [] if disable_safety else [\"llama_guard\"]\n", " input_shields = [] if disable_safety else [\"llama_guard\"]\n",
" output_shields = [] if disable_safety else [\"llama_guard\"]\n", " output_shields = [] if disable_safety else [\"llama_guard\"]\n",
"\n", "\n",
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
" \n",
" # Define the agent configuration, including the model and tool setup\n", " # Define the agent configuration, including the model and tool setup\n",
" agent_config = AgentConfig(\n", " agent_config = AgentConfig(\n",
" model=MODEL_NAME,\n", " model=MODEL_NAME,\n",
@ -296,18 +299,7 @@
" \"top_p\": 0.9,\n", " \"top_p\": 0.9,\n",
" },\n", " },\n",
" tools=[\n", " tools=[\n",
" {\n", " webSearchTool.get_tool_definition()\n",
" \"function_name\": \"web_search\", # Name of the tool being integrated\n",
" \"description\": \"Search the web for a given query\",\n",
" \"parameters\": {\n",
" \"query\": {\n",
" \"param_type\": \"str\",\n",
" \"description\": \"The query to search for\",\n",
" \"required\": True,\n",
" }\n",
" },\n",
" \"type\": \"function_call\",\n",
" },\n",
" ],\n", " ],\n",
" tool_choice=\"auto\",\n", " tool_choice=\"auto\",\n",
" tool_prompt_format=\"python_list\",\n", " tool_prompt_format=\"python_list\",\n",
@ -316,11 +308,8 @@
" enable_session_persistence=False,\n", " enable_session_persistence=False,\n",
" )\n", " )\n",
"\n", "\n",
" # Initialize custom tools (ensure `WebSearchTool` is defined earlier in the notebook)\n",
" custom_tools = [WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)]\n",
"\n",
" # Create an agent instance with the client and configuration\n", " # Create an agent instance with the client and configuration\n",
" agent = Agent(client, agent_config, custom_tools)\n", " agent = Agent(client, agent_config, [webSearchTool])\n",
"\n", "\n",
" # Create a session for interaction and print the session ID\n", " # Create a session for interaction and print the session ID\n",
" session_id = agent.create_session(\"test-session\")\n", " session_id = agent.create_session(\"test-session\")\n",

View file

@ -71,7 +71,8 @@
} }
], ],
"source": [ "source": [
"!pip install llama-stack-client==0.0.50" "!pip install llama-stack-client==0.0.50\n",
"!pip install -U httpx==0.27.2 # https://github.com/meta-llama/llama-stack-apps/issues/131"
] ]
}, },
{ {
@ -355,6 +356,9 @@
"async def create_weather_agent(client: LlamaStackClient) -> Agent:\n", "async def create_weather_agent(client: LlamaStackClient) -> Agent:\n",
" \"\"\"Create an agent with weather tool capability.\"\"\"\n", " \"\"\"Create an agent with weather tool capability.\"\"\"\n",
"\n", "\n",
" # Create the agent with the tool\n",
" weather_tool = WeatherTool()\n",
" \n",
" agent_config = AgentConfig(\n", " agent_config = AgentConfig(\n",
" model=LLAMA31_8B_INSTRUCT,\n", " model=LLAMA31_8B_INSTRUCT,\n",
" #model=model_name,\n", " #model=model_name,\n",
@ -369,23 +373,7 @@
" \"top_p\": 0.9,\n", " \"top_p\": 0.9,\n",
" },\n", " },\n",
" tools=[\n", " tools=[\n",
" {\n", " weather_tool.get_tool_definition()\n",
" \"function_name\": \"get_weather\",\n",
" \"description\": \"Get weather information for a location\",\n",
" \"parameters\": {\n",
" \"location\": {\n",
" \"param_type\": \"str\",\n",
" \"description\": \"City or location name\",\n",
" \"required\": True,\n",
" },\n",
" \"date\": {\n",
" \"param_type\": \"str\",\n",
" \"description\": \"Optional date (YYYY-MM-DD)\",\n",
" \"required\": False,\n",
" },\n",
" },\n",
" \"type\": \"function_call\",\n",
" }\n",
" ],\n", " ],\n",
" tool_choice=\"auto\",\n", " tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n", " tool_prompt_format=\"json\",\n",
@ -394,8 +382,6 @@
" enable_session_persistence=True\n", " enable_session_persistence=True\n",
" )\n", " )\n",
"\n", "\n",
" # Create the agent with the tool\n",
" weather_tool = WeatherTool()\n",
" agent = Agent(\n", " agent = Agent(\n",
" client=client,\n", " client=client,\n",
" agent_config=agent_config,\n", " agent_config=agent_config,\n",
@ -470,5 +456,5 @@
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 0 "nbformat_minor": 4
} }