standardized port and also included pre-req for all notebooks

Justin Lee 2024-11-05 16:38:46 -08:00
parent d0baf24999
commit b556cd91fd
8 changed files with 177 additions and 42 deletions


@@ -4,7 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Calling"
"## Tool Calling\n",
"\n",
"Before you begin, please ensure Llama Stack is installed and set up by following the [Getting Started Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)."
]
},
{
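The prerequisite note added above defers setup to the Getting Started Guide. As a rough pre-flight check along those lines, something like the following could sit at the top of each notebook (the `llama_stack_client` package name is an assumption taken from that guide, not something shown in this diff):

# Optional pre-flight check for the prerequisite called out above.
# Assumption: the client package installs as "llama_stack_client",
# per the Getting Started Guide rather than this diff.
import importlib.util

if importlib.util.find_spec("llama_stack_client") is None:
    raise RuntimeError(
        "Llama Stack client not found; follow the Getting Started Guide "
        "before running this notebook."
    )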
@@ -17,6 +19,23 @@
"3. Configuring tool prompts and safety settings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up your connection parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port"
]
},
{
"cell_type": "code",
"execution_count": 22,
@@ -206,13 +225,13 @@
"from datetime import datetime\n",
"class WeatherTool:\n",
" \"\"\"Example custom tool for weather information.\"\"\"\n",
" \n",
"\n",
" def get_name(self) -> str:\n",
" return \"get_weather\"\n",
" \n",
"\n",
" def get_description(self) -> str:\n",
" return \"Get weather information for a location\"\n",
" \n",
"\n",
" def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n",
" return {\n",
" \"location\": ToolParamDefinitionParam(\n",
@@ -226,7 +245,7 @@
" required=False\n",
" )\n",
" }\n",
" \n",
"\n",
" async def run_impl(self, location: str, date: Optional[str] = None) -> Dict[str, Any]:\n",
" \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n",
" # Mock implementation\n",
@@ -275,7 +294,7 @@
" output_shields=[],\n",
" enable_session_persistence=True\n",
" )\n",
" \n",
"\n",
" # Create the agent with the tool\n",
" weather_tool = WeatherTool()\n",
" agent = Agent(\n",
@@ -283,7 +302,7 @@
" agent_config=agent_config,\n",
" custom_tools=[weather_tool]\n",
" )\n",
" \n",
"\n",
" return agent\n",
"\n",
"# Example usage\n",
@@ -291,21 +310,21 @@
" client = LlamaStackClient(base_url=\"http://localhost:5001\")\n",
" agent = await create_weather_agent(client)\n",
" session_id = agent.create_session(\"weather-session\")\n",
" \n",
"\n",
" queries = [\n",
" \"What's the weather like in San Francisco?\",\n",
" \"Tell me the weather in Tokyo tomorrow\",\n",
" ]\n",
" \n",
"\n",
" for query in queries:\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
" \n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
" \n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",