Update Strategy in SamplingParams to be a union

This commit is contained in:
Hardik Shah 2025-01-14 15:56:02 -08:00 committed by Ashwin Bharambe
parent 300e6e2702
commit dea575c994
28 changed files with 600 additions and 377 deletions

View file

@ -26,27 +26,28 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import requests\n",
"import json\n",
"import asyncio\n",
"import nest_asyncio\n",
"import json\n",
"import os\n",
"from typing import Dict, List\n",
"\n",
"import nest_asyncio\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
"from llama_stack_client.types import CompletionMessage\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types import CompletionMessage\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
"\n",
"# Allow asyncio to run in Jupyter Notebook\n",
"nest_asyncio.apply()\n",
"\n",
"HOST='localhost'\n",
"PORT=5001\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
"HOST = \"localhost\"\n",
"PORT = 5001\n",
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
]
},
{
@ -69,7 +70,7 @@
"outputs": [],
"source": [
"load_dotenv()\n",
"BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']"
"BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
]
},
{
@ -118,7 +119,7 @@
" cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n",
" clean_response.append(cleaned)\n",
"\n",
" return {\"query\": query, \"top_k\": clean_response}"
" return {\"query\": query, \"top_k\": clean_response}\n"
]
},
{
@ -157,25 +158,29 @@
" for message in messages:\n",
" if isinstance(message, CompletionMessage) and message.tool_calls:\n",
" for tool_call in message.tool_calls:\n",
" if 'query' in tool_call.arguments:\n",
" query = tool_call.arguments['query']\n",
" if \"query\" in tool_call.arguments:\n",
" query = tool_call.arguments[\"query\"]\n",
" call_id = tool_call.call_id\n",
"\n",
" if query:\n",
" search_result = await self.run_impl(query)\n",
" return [ToolResponseMessage(\n",
" call_id=call_id,\n",
" role=\"ipython\",\n",
" content=self._format_response_for_agent(search_result),\n",
" tool_name=\"brave_search\"\n",
" )]\n",
" return [\n",
" ToolResponseMessage(\n",
" call_id=call_id,\n",
" role=\"ipython\",\n",
" content=self._format_response_for_agent(search_result),\n",
" tool_name=\"brave_search\",\n",
" )\n",
" ]\n",
"\n",
" return [ToolResponseMessage(\n",
" call_id=\"no_call_id\",\n",
" role=\"ipython\",\n",
" content=\"No query provided.\",\n",
" tool_name=\"brave_search\"\n",
" )]\n",
" return [\n",
" ToolResponseMessage(\n",
" call_id=\"no_call_id\",\n",
" role=\"ipython\",\n",
" content=\"No query provided.\",\n",
" tool_name=\"brave_search\",\n",
" )\n",
" ]\n",
"\n",
" def _format_response_for_agent(self, search_result):\n",
" parsed_result = json.loads(search_result)\n",
@ -186,7 +191,7 @@
" f\" URL: {result.get('url', 'No URL')}\\n\"\n",
" f\" Description: {result.get('description', 'No Description')}\\n\\n\"\n",
" )\n",
" return formatted_result"
" return formatted_result\n"
]
},
{
@ -209,7 +214,7 @@
"async def execute_search(query: str):\n",
" web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
" result = await web_search_tool.run_impl(query)\n",
" print(\"Search Results:\", result)"
" print(\"Search Results:\", result)\n"
]
},
{
@ -236,7 +241,7 @@
],
"source": [
"query = \"Latest developments in quantum computing\"\n",
"asyncio.run(execute_search(query))"
"asyncio.run(execute_search(query))\n"
]
},
{
@ -288,19 +293,17 @@
"\n",
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
" \n",
"\n",
" # Define the agent configuration, including the model and tool setup\n",
" agent_config = AgentConfig(\n",
" model=MODEL_NAME,\n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n",
" \"strategy\": \"greedy\",\n",
" \"temperature\": 1.0,\n",
" \"top_p\": 0.9,\n",
" \"strategy\": {\n",
" \"type\": \"greedy\",\n",
" },\n",
" },\n",
" tools=[\n",
" webSearchTool.get_tool_definition()\n",
" ],\n",
" tools=[webSearchTool.get_tool_definition()],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"python_list\",\n",
" input_shields=input_shields,\n",
@ -329,8 +332,9 @@
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
"\n",
"# Run the function asynchronously in a Jupyter Notebook cell\n",
"await run_main(disable_safety=True)"
"await run_main(disable_safety=True)\n"
]
}
],