mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-05 10:22:23 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@ -26,27 +26,28 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import asyncio\n",
|
||||
"import nest_asyncio\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"from typing import Dict, List\n",
|
||||
"\n",
|
||||
"import nest_asyncio\n",
|
||||
"import requests\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
|
||||
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
|
||||
"from llama_stack_client.types import CompletionMessage\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types import CompletionMessage\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
|
||||
"\n",
|
||||
"# Allow asyncio to run in Jupyter Notebook\n",
|
||||
"nest_asyncio.apply()\n",
|
||||
"\n",
|
||||
"HOST='localhost'\n",
|
||||
"PORT=5001\n",
|
||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
||||
"HOST = \"localhost\"\n",
|
||||
"PORT = 5001\n",
|
||||
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -69,7 +70,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"load_dotenv()\n",
|
||||
"BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']"
|
||||
"BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -118,7 +119,7 @@
|
|||
" cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n",
|
||||
" clean_response.append(cleaned)\n",
|
||||
"\n",
|
||||
" return {\"query\": query, \"top_k\": clean_response}"
|
||||
" return {\"query\": query, \"top_k\": clean_response}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -157,25 +158,29 @@
|
|||
" for message in messages:\n",
|
||||
" if isinstance(message, CompletionMessage) and message.tool_calls:\n",
|
||||
" for tool_call in message.tool_calls:\n",
|
||||
" if 'query' in tool_call.arguments:\n",
|
||||
" query = tool_call.arguments['query']\n",
|
||||
" if \"query\" in tool_call.arguments:\n",
|
||||
" query = tool_call.arguments[\"query\"]\n",
|
||||
" call_id = tool_call.call_id\n",
|
||||
"\n",
|
||||
" if query:\n",
|
||||
" search_result = await self.run_impl(query)\n",
|
||||
" return [ToolResponseMessage(\n",
|
||||
" call_id=call_id,\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=self._format_response_for_agent(search_result),\n",
|
||||
" tool_name=\"brave_search\"\n",
|
||||
" )]\n",
|
||||
" return [\n",
|
||||
" ToolResponseMessage(\n",
|
||||
" call_id=call_id,\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=self._format_response_for_agent(search_result),\n",
|
||||
" tool_name=\"brave_search\",\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" return [ToolResponseMessage(\n",
|
||||
" call_id=\"no_call_id\",\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=\"No query provided.\",\n",
|
||||
" tool_name=\"brave_search\"\n",
|
||||
" )]\n",
|
||||
" return [\n",
|
||||
" ToolResponseMessage(\n",
|
||||
" call_id=\"no_call_id\",\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=\"No query provided.\",\n",
|
||||
" tool_name=\"brave_search\",\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" def _format_response_for_agent(self, search_result):\n",
|
||||
" parsed_result = json.loads(search_result)\n",
|
||||
|
|
@ -186,7 +191,7 @@
|
|||
" f\" URL: {result.get('url', 'No URL')}\\n\"\n",
|
||||
" f\" Description: {result.get('description', 'No Description')}\\n\\n\"\n",
|
||||
" )\n",
|
||||
" return formatted_result"
|
||||
" return formatted_result\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -209,7 +214,7 @@
|
|||
"async def execute_search(query: str):\n",
|
||||
" web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
|
||||
" result = await web_search_tool.run_impl(query)\n",
|
||||
" print(\"Search Results:\", result)"
|
||||
" print(\"Search Results:\", result)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -236,7 +241,7 @@
|
|||
],
|
||||
"source": [
|
||||
"query = \"Latest developments in quantum computing\"\n",
|
||||
"asyncio.run(execute_search(query))"
|
||||
"asyncio.run(execute_search(query))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -288,19 +293,17 @@
|
|||
"\n",
|
||||
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
|
||||
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" # Define the agent configuration, including the model and tool setup\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" model=MODEL_NAME,\n",
|
||||
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
|
||||
" sampling_params={\n",
|
||||
" \"strategy\": \"greedy\",\n",
|
||||
" \"temperature\": 1.0,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" tools=[\n",
|
||||
" webSearchTool.get_tool_definition()\n",
|
||||
" ],\n",
|
||||
" tools=[webSearchTool.get_tool_definition()],\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" tool_prompt_format=\"python_list\",\n",
|
||||
" input_shields=input_shields,\n",
|
||||
|
|
@ -329,8 +332,9 @@
|
|||
" async for log in EventLogger().log(response):\n",
|
||||
" log.print()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run the function asynchronously in a Jupyter Notebook cell\n",
|
||||
"await run_main(disable_safety=True)"
|
||||
"await run_main(disable_safety=True)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue