mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Convert SamplingParams.strategy
to a union (#767)
# What does this PR do? Cleans up how we provide sampling params. Earlier, strategy was an enum and all params (top_p, temperature, top_k) across all strategies were grouped. We now have a strategy union object with each strategy (greedy, top_p, top_k) having its corresponding params. Earlier, ``` class SamplingParams: strategy: enum () top_p, temperature, top_k and other params ``` However, the `strategy` field was not being used in any providers making it confusing to know the exact sampling behavior purely based on the params since you could pass temperature, top_p, top_k and how the provider would interpret those would not be clear. Hence we introduced -- a union where the strategy and relevant params are all clubbed together to avoid this confusion. Have updated all providers, tests, notebooks, readme and other places where sampling params was being used to use the new format. ## Test Plan `pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py` // inference on ollama, fireworks and together `with-proxy pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/inference/test_text_inference.py ` // agents on fireworks `pytest -v -s -k 'fireworks and create_agent' --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/agents/test_agents.py --safety-shield="meta-llama/Llama-Guard-3-8B"` ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [X] Ran pre-commit to handle lint / formatting issues. - [X] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [X] Updated relevant documentation. - [X] Wrote necessary unit or integration tests. --------- Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
parent
300e6e2702
commit
a51c8b4efc
29 changed files with 611 additions and 388 deletions
|
@ -26,27 +26,28 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import asyncio\n",
|
||||
"import nest_asyncio\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"from typing import Dict, List\n",
|
||||
"\n",
|
||||
"import nest_asyncio\n",
|
||||
"import requests\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
|
||||
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
|
||||
"from llama_stack_client.types import CompletionMessage\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types import CompletionMessage\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
|
||||
"\n",
|
||||
"# Allow asyncio to run in Jupyter Notebook\n",
|
||||
"nest_asyncio.apply()\n",
|
||||
"\n",
|
||||
"HOST='localhost'\n",
|
||||
"PORT=5001\n",
|
||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
||||
"HOST = \"localhost\"\n",
|
||||
"PORT = 5001\n",
|
||||
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -69,7 +70,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"load_dotenv()\n",
|
||||
"BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']"
|
||||
"BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -118,7 +119,7 @@
|
|||
" cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n",
|
||||
" clean_response.append(cleaned)\n",
|
||||
"\n",
|
||||
" return {\"query\": query, \"top_k\": clean_response}"
|
||||
" return {\"query\": query, \"top_k\": clean_response}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -157,25 +158,29 @@
|
|||
" for message in messages:\n",
|
||||
" if isinstance(message, CompletionMessage) and message.tool_calls:\n",
|
||||
" for tool_call in message.tool_calls:\n",
|
||||
" if 'query' in tool_call.arguments:\n",
|
||||
" query = tool_call.arguments['query']\n",
|
||||
" if \"query\" in tool_call.arguments:\n",
|
||||
" query = tool_call.arguments[\"query\"]\n",
|
||||
" call_id = tool_call.call_id\n",
|
||||
"\n",
|
||||
" if query:\n",
|
||||
" search_result = await self.run_impl(query)\n",
|
||||
" return [ToolResponseMessage(\n",
|
||||
" call_id=call_id,\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=self._format_response_for_agent(search_result),\n",
|
||||
" tool_name=\"brave_search\"\n",
|
||||
" )]\n",
|
||||
" return [\n",
|
||||
" ToolResponseMessage(\n",
|
||||
" call_id=call_id,\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=self._format_response_for_agent(search_result),\n",
|
||||
" tool_name=\"brave_search\",\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" return [ToolResponseMessage(\n",
|
||||
" call_id=\"no_call_id\",\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=\"No query provided.\",\n",
|
||||
" tool_name=\"brave_search\"\n",
|
||||
" )]\n",
|
||||
" return [\n",
|
||||
" ToolResponseMessage(\n",
|
||||
" call_id=\"no_call_id\",\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=\"No query provided.\",\n",
|
||||
" tool_name=\"brave_search\",\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" def _format_response_for_agent(self, search_result):\n",
|
||||
" parsed_result = json.loads(search_result)\n",
|
||||
|
@ -186,7 +191,7 @@
|
|||
" f\" URL: {result.get('url', 'No URL')}\\n\"\n",
|
||||
" f\" Description: {result.get('description', 'No Description')}\\n\\n\"\n",
|
||||
" )\n",
|
||||
" return formatted_result"
|
||||
" return formatted_result\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -209,7 +214,7 @@
|
|||
"async def execute_search(query: str):\n",
|
||||
" web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
|
||||
" result = await web_search_tool.run_impl(query)\n",
|
||||
" print(\"Search Results:\", result)"
|
||||
" print(\"Search Results:\", result)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -236,7 +241,7 @@
|
|||
],
|
||||
"source": [
|
||||
"query = \"Latest developments in quantum computing\"\n",
|
||||
"asyncio.run(execute_search(query))"
|
||||
"asyncio.run(execute_search(query))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -288,19 +293,17 @@
|
|||
"\n",
|
||||
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
|
||||
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" # Define the agent configuration, including the model and tool setup\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" model=MODEL_NAME,\n",
|
||||
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
|
||||
" sampling_params={\n",
|
||||
" \"strategy\": \"greedy\",\n",
|
||||
" \"temperature\": 1.0,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" tools=[\n",
|
||||
" webSearchTool.get_tool_definition()\n",
|
||||
" ],\n",
|
||||
" tools=[webSearchTool.get_tool_definition()],\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" tool_prompt_format=\"python_list\",\n",
|
||||
" input_shields=input_shields,\n",
|
||||
|
@ -329,8 +332,9 @@
|
|||
" async for log in EventLogger().log(response):\n",
|
||||
" log.print()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run the function asynchronously in a Jupyter Notebook cell\n",
|
||||
"await run_main(disable_safety=True)"
|
||||
"await run_main(disable_safety=True)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue