Convert SamplingParams.strategy to a union (#767)
# What does this PR do?

Cleans up how we provide sampling params. Earlier, `strategy` was an enum and all params (`top_p`, `temperature`, `top_k`) across all strategies were grouped together. We now have a `strategy` union object, with each strategy (greedy, top_p, top_k) carrying its own params.

Earlier:

```
class SamplingParams:
    strategy: enum()
    top_p, temperature, top_k and other params
```

However, the `strategy` field was not being used by any provider, which made it hard to know the exact sampling behavior from the params alone: you could pass `temperature`, `top_p`, and `top_k` together, and how the provider would interpret them was unclear. Hence we introduced a union in which each strategy is clubbed together with its relevant params to avoid this confusion.

All providers, tests, notebooks, the README, and other places that used sampling params have been updated to the new format.

## Test Plan

`pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py`

// inference on ollama, fireworks and together
`with-proxy pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/inference/test_text_inference.py`

// agents on fireworks
`pytest -v -s -k 'fireworks and create_agent' --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/agents/test_agents.py --safety-shield="meta-llama/Llama-Guard-3-8B"`

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [X] Ran pre-commit to handle lint / formatting issues.
- [X] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [X] Updated relevant documentation.
- [X] Wrote necessary unit or integration tests.

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in: parent `300e6e2702`, commit `a51c8b4efc`. 29 changed files with 611 additions and 388 deletions.
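Before diving into the file changes, here is a minimal sketch of what the change means for callers. It is assembled from the notebook snippets updated below and assumes the Together distribution those notebooks use; the `top_k` value is an arbitrary placeholder:

```python
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("together")
_ = client.initialize()

# Old format: `strategy` was a bare enum value, and temperature/top_p/top_k
# all sat side by side, so it was unclear which knobs actually applied:
# sampling_params = {"strategy": "greedy", "temperature": 0.0, "top_p": 0.95, "top_k": 0}

# New format: a union object in which each strategy carries only its own params.
greedy = {"strategy": {"type": "greedy"}}
top_p = {"strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.95}}
top_k = {"strategy": {"type": "top_k", "top_k": 40}}  # placeholder top_k

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-70B-Instruct",
    messages=[{"role": "user", "content": "Write a two-sentence poem about llama."}],
    sampling_params=top_p,
)
print(response.completion_message.content)
```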
|
@ -713,13 +713,15 @@
|
|||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from google.colab import userdata\n",
|
||||
"\n",
|
||||
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
|
||||
"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
|
||||
"\n",
|
||||
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
||||
"\n",
|
||||
"client = LlamaStackAsLibraryClient(\"together\")\n",
|
||||
"_ = client.initialize()"
|
||||
"_ = client.initialize()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -769,6 +771,7 @@
|
|||
],
|
||||
"source": [
|
||||
"from rich.pretty import pprint\n",
|
||||
"\n",
|
||||
"print(\"Available models:\")\n",
|
||||
"for m in client.models.list():\n",
|
||||
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
|
||||
|
@ -777,7 +780,7 @@
|
|||
"print(\"Available shields (safety models):\")\n",
|
||||
"for s in client.shields.list():\n",
|
||||
" print(s.identifier)\n",
|
||||
"print(\"----\")"
|
||||
"print(\"----\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -822,7 +825,7 @@
|
|||
"source": [
|
||||
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
|
||||
"\n",
|
||||
"model_id"
|
||||
"model_id\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -863,11 +866,11 @@
|
|||
" model_id=model_id,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
|
||||
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n",
|
||||
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(response.completion_message.content)"
|
||||
"print(response.completion_message.content)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -900,12 +903,13 @@
|
|||
"source": [
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def chat_loop():\n",
|
||||
" conversation_history = []\n",
|
||||
" while True:\n",
|
||||
" user_input = input('User> ')\n",
|
||||
" if user_input.lower() in ['exit', 'quit', 'bye']:\n",
|
||||
" cprint('Ending conversation. Goodbye!', 'yellow')\n",
|
||||
" user_input = input(\"User> \")\n",
|
||||
" if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
|
||||
" cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
|
||||
|
@ -915,14 +919,15 @@
|
|||
" messages=conversation_history,\n",
|
||||
" model_id=model_id,\n",
|
||||
" )\n",
|
||||
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
|
||||
" cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
|
||||
"\n",
|
||||
" assistant_message = {\n",
|
||||
" \"role\": \"assistant\", # was user\n",
|
||||
" \"role\": \"assistant\", # was user\n",
|
||||
" \"content\": response.completion_message.content,\n",
|
||||
" }\n",
|
||||
" conversation_history.append(assistant_message)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"chat_loop()\n"
|
||||
]
|
||||
},
|
||||
|
@ -978,21 +983,18 @@
|
|||
"source": [
|
||||
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
||||
"\n",
|
||||
"message = {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": 'Write me a sonnet about llama'\n",
|
||||
"}\n",
|
||||
"print(f'User> {message[\"content\"]}', 'green')\n",
|
||||
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
|
||||
"print(f'User> {message[\"content\"]}', \"green\")\n",
|
||||
"\n",
|
||||
"response = client.inference.chat_completion(\n",
|
||||
" messages=[message],\n",
|
||||
" model_id=model_id,\n",
|
||||
" stream=True, # <-----------\n",
|
||||
" stream=True, # <-----------\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Print the tokens while they are received\n",
|
||||
"for log in EventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1045,26 +1047,26 @@
|
|||
"source": [
|
||||
"from pydantic import BaseModel\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Output(BaseModel):\n",
|
||||
" name: str\n",
|
||||
" year_born: str\n",
|
||||
" year_retired: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
|
||||
"response = client.inference.completion(\n",
|
||||
" model_id=model_id,\n",
|
||||
" content=user_input,\n",
|
||||
" stream=False,\n",
|
||||
" sampling_params={\n",
|
||||
" \"max_tokens\": 50,\n",
|
||||
" },\n",
|
||||
" sampling_params={\"strategy\": {\"type\": \"greedy\"}, \"max_tokens\": 50},\n",
|
||||
" response_format={\n",
|
||||
" \"type\": \"json_schema\",\n",
|
||||
" \"json_schema\": Output.model_json_schema(),\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1220,7 +1222,7 @@
|
|||
" shield_id=available_shields[0],\n",
|
||||
" params={},\n",
|
||||
" )\n",
|
||||
" pprint(response)"
|
||||
" pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1489,8 +1491,8 @@
|
|||
"source": [
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from llama_stack_client.types import Attachment\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
|
||||
|
@ -1522,14 +1524,14 @@
|
|||
" ),\n",
|
||||
"]\n",
|
||||
"for prompt, attachments in user_prompts:\n",
|
||||
" cprint(f'User> {prompt}', 'green')\n",
|
||||
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||
" response = rag_agent.create_turn(\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" attachments=attachments,\n",
|
||||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1560,8 +1562,8 @@
|
|||
"search_tool = {\n",
|
||||
" \"type\": \"brave_search\",\n",
|
||||
" \"engine\": \"tavily\",\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
|
||||
"}"
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
|
||||
"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1608,7 +1610,7 @@
|
|||
"\n",
|
||||
"session_id = agent.create_session(\"test-session\")\n",
|
||||
"for prompt in user_prompts:\n",
|
||||
" cprint(f'User> {prompt}', 'green')\n",
|
||||
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||
" response = agent.create_turn(\n",
|
||||
" messages=[\n",
|
||||
" {\n",
|
||||
|
@ -1758,7 +1760,7 @@
|
|||
" search_tool,\n",
|
||||
" {\n",
|
||||
" \"type\": \"code_interpreter\",\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" tool_choice=\"required\",\n",
|
||||
" input_shields=[],\n",
|
||||
|
@ -1788,7 +1790,7 @@
|
|||
"]\n",
|
||||
"\n",
|
||||
"for prompt in user_prompts:\n",
|
||||
" cprint(f'User> {prompt}', 'green')\n",
|
||||
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||
" response = codex_agent.create_turn(\n",
|
||||
" messages=[\n",
|
||||
" {\n",
|
||||
|
@ -1841,27 +1843,57 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# Read the CSV file\n",
|
||||
"df = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\n",
|
||||
"df = pd.read_csv(\"/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv\")\n",
|
||||
"\n",
|
||||
"# Extract the year and inflation rate from the CSV file\n",
|
||||
"df['Year'] = pd.to_datetime(df['Year'], format='%Y')\n",
|
||||
"df = df.rename(columns={'Jan': 'Jan Rate', 'Feb': 'Feb Rate', 'Mar': 'Mar Rate', 'Apr': 'Apr Rate', 'May': 'May Rate', 'Jun': 'Jun Rate', 'Jul': 'Jul Rate', 'Aug': 'Aug Rate', 'Sep': 'Sep Rate', 'Oct': 'Oct Rate', 'Nov': 'Nov Rate', 'Dec': 'Dec Rate'})\n",
|
||||
"df[\"Year\"] = pd.to_datetime(df[\"Year\"], format=\"%Y\")\n",
|
||||
"df = df.rename(\n",
|
||||
" columns={\n",
|
||||
" \"Jan\": \"Jan Rate\",\n",
|
||||
" \"Feb\": \"Feb Rate\",\n",
|
||||
" \"Mar\": \"Mar Rate\",\n",
|
||||
" \"Apr\": \"Apr Rate\",\n",
|
||||
" \"May\": \"May Rate\",\n",
|
||||
" \"Jun\": \"Jun Rate\",\n",
|
||||
" \"Jul\": \"Jul Rate\",\n",
|
||||
" \"Aug\": \"Aug Rate\",\n",
|
||||
" \"Sep\": \"Sep Rate\",\n",
|
||||
" \"Oct\": \"Oct Rate\",\n",
|
||||
" \"Nov\": \"Nov Rate\",\n",
|
||||
" \"Dec\": \"Dec Rate\",\n",
|
||||
" }\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Calculate the average yearly inflation rate\n",
|
||||
"df['Yearly Inflation'] = df[['Jan Rate', 'Feb Rate', 'Mar Rate', 'Apr Rate', 'May Rate', 'Jun Rate', 'Jul Rate', 'Aug Rate', 'Sep Rate', 'Oct Rate', 'Nov Rate', 'Dec Rate']].mean(axis=1)\n",
|
||||
"df[\"Yearly Inflation\"] = df[\n",
|
||||
" [\n",
|
||||
" \"Jan Rate\",\n",
|
||||
" \"Feb Rate\",\n",
|
||||
" \"Mar Rate\",\n",
|
||||
" \"Apr Rate\",\n",
|
||||
" \"May Rate\",\n",
|
||||
" \"Jun Rate\",\n",
|
||||
" \"Jul Rate\",\n",
|
||||
" \"Aug Rate\",\n",
|
||||
" \"Sep Rate\",\n",
|
||||
" \"Oct Rate\",\n",
|
||||
" \"Nov Rate\",\n",
|
||||
" \"Dec Rate\",\n",
|
||||
" ]\n",
|
||||
"].mean(axis=1)\n",
|
||||
"\n",
|
||||
"# Plot the average yearly inflation rate as a time series\n",
|
||||
"plt.figure(figsize=(10, 6))\n",
|
||||
"plt.plot(df['Year'], df['Yearly Inflation'], marker='o')\n",
|
||||
"plt.title('Average Yearly Inflation Rate')\n",
|
||||
"plt.xlabel('Year')\n",
|
||||
"plt.ylabel('Inflation Rate (%)')\n",
|
||||
"plt.plot(df[\"Year\"], df[\"Yearly Inflation\"], marker=\"o\")\n",
|
||||
"plt.title(\"Average Yearly Inflation Rate\")\n",
|
||||
"plt.xlabel(\"Year\")\n",
|
||||
"plt.ylabel(\"Inflation Rate (%)\")\n",
|
||||
"plt.grid(True)\n",
|
||||
"plt.show()"
|
||||
"plt.show()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2035,6 +2067,8 @@
|
|||
"source": [
|
||||
"# disable logging for clean server logs\n",
|
||||
"import logging\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def remove_root_handlers():\n",
|
||||
" root_logger = logging.getLogger()\n",
|
||||
" for handler in root_logger.handlers[:]:\n",
|
||||
|
@ -2042,7 +2076,7 @@
|
|||
" print(f\"Removed handler {handler.__class__.__name__} from root logger\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"remove_root_handlers()"
|
||||
"remove_root_handlers()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2083,10 +2117,10 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from google.colab import userdata\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from google.colab import userdata\n",
|
||||
"\n",
|
||||
"agent_config = AgentConfig(\n",
|
||||
" model=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
||||
|
@ -2096,7 +2130,7 @@
|
|||
" {\n",
|
||||
" \"type\": \"brave_search\",\n",
|
||||
" \"engine\": \"tavily\",\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" ),\n",
|
||||
|
@ -2125,7 +2159,7 @@
|
|||
" )\n",
|
||||
"\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2265,20 +2299,21 @@
|
|||
"source": [
|
||||
"print(f\"Getting traces for session_id={session_id}\")\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"from rich.pretty import pprint\n",
|
||||
"\n",
|
||||
"agent_logs = []\n",
|
||||
"\n",
|
||||
"for span in client.telemetry.query_spans(\n",
|
||||
" attribute_filters=[\n",
|
||||
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
||||
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
||||
" ],\n",
|
||||
" attributes_to_return=[\"input\", \"output\"]\n",
|
||||
" ):\n",
|
||||
" if span.attributes[\"output\"] != \"no shields\":\n",
|
||||
" agent_logs.append(span.attributes)\n",
|
||||
" attributes_to_return=[\"input\", \"output\"],\n",
|
||||
"):\n",
|
||||
" if span.attributes[\"output\"] != \"no shields\":\n",
|
||||
" agent_logs.append(span.attributes)\n",
|
||||
"\n",
|
||||
"pprint(agent_logs)"
|
||||
"pprint(agent_logs)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2389,23 +2424,25 @@
|
|||
"eval_rows = []\n",
|
||||
"\n",
|
||||
"for log in agent_logs:\n",
|
||||
" last_msg = log['input'][-1]\n",
|
||||
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n",
|
||||
" eval_rows.append(\n",
|
||||
" {\n",
|
||||
" \"input_query\": last_msg,\n",
|
||||
" \"generated_answer\": log[\"output\"],\n",
|
||||
" # check if generated_answer uses tools brave_search\n",
|
||||
" \"expected_answer\": \"brave_search\",\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
" last_msg = log[\"input\"][-1]\n",
|
||||
" if '\"role\":\"user\"' in last_msg:\n",
|
||||
" eval_rows.append(\n",
|
||||
" {\n",
|
||||
" \"input_query\": last_msg,\n",
|
||||
" \"generated_answer\": log[\"output\"],\n",
|
||||
" # check if generated_answer uses tools brave_search\n",
|
||||
" \"expected_answer\": \"brave_search\",\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"pprint(eval_rows)\n",
|
||||
"scoring_params = {\n",
|
||||
" \"basic::subset_of\": None,\n",
|
||||
"}\n",
|
||||
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n",
|
||||
"pprint(scoring_response)"
|
||||
"scoring_response = client.scoring.score(\n",
|
||||
" input_rows=eval_rows, scoring_functions=scoring_params\n",
|
||||
")\n",
|
||||
"pprint(scoring_response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2506,7 +2543,9 @@
|
|||
"EXPECTED_RESPONSE: {expected_answer}\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
||||
"input_query = (\n",
|
||||
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
||||
")\n",
|
||||
"generated_answer = \"\"\"\n",
|
||||
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
|
||||
"\n",
|
||||
|
@ -2537,7 +2576,7 @@
|
|||
"}\n",
|
||||
"\n",
|
||||
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -618,11 +618,13 @@
|
|||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from google.colab import userdata\n",
|
||||
"\n",
|
||||
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
|
||||
"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
|
||||
"\n",
|
||||
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
||||
"\n",
|
||||
"client = LlamaStackAsLibraryClient(\"together\")\n",
|
||||
"_ = client.initialize()\n",
|
||||
"\n",
|
||||
|
@ -631,7 +633,7 @@
|
|||
" model_id=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
||||
" provider_model_id=\"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n",
|
||||
" provider_id=\"together\",\n",
|
||||
")"
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -668,7 +670,7 @@
|
|||
"source": [
|
||||
"name = \"llamastack/mmmu\"\n",
|
||||
"subset = \"Agriculture\"\n",
|
||||
"split = \"dev\""
|
||||
"split = \"dev\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -914,9 +916,10 @@
|
|||
],
|
||||
"source": [
|
||||
"import datasets\n",
|
||||
"\n",
|
||||
"ds = datasets.load_dataset(path=name, name=subset, split=split)\n",
|
||||
"ds = ds.select_columns([\"chat_completion_input\", \"input_query\", \"expected_answer\"])\n",
|
||||
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")"
|
||||
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1014,8 +1017,8 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from tqdm import tqdm\n",
|
||||
"from rich.pretty import pprint\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"SYSTEM_PROMPT_TEMPLATE = \"\"\"\n",
|
||||
"You are an expert in {subject} whose job is to answer questions from the user using images.\n",
|
||||
|
@ -1039,7 +1042,7 @@
|
|||
"client.eval_tasks.register(\n",
|
||||
" eval_task_id=\"meta-reference::mmmu\",\n",
|
||||
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
|
||||
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"]\n",
|
||||
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = client.eval.evaluate_rows(\n",
|
||||
|
@ -1052,16 +1055,17 @@
|
|||
" \"type\": \"model\",\n",
|
||||
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
||||
" \"sampling_params\": {\n",
|
||||
" \"temperature\": 0.0,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" \"max_tokens\": 4096,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"repeat_penalty\": 1.0,\n",
|
||||
" },\n",
|
||||
" \"system_message\": system_message\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" \"system_message\": system_message,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1098,8 +1102,8 @@
|
|||
" \"input_query\": {\"type\": \"string\"},\n",
|
||||
" \"expected_answer\": {\"type\": \"string\"},\n",
|
||||
" \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n",
|
||||
" }\n",
|
||||
")"
|
||||
" },\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1113,7 +1117,7 @@
|
|||
"eval_rows = client.datasetio.get_rows_paginated(\n",
|
||||
" dataset_id=simpleqa_dataset_id,\n",
|
||||
" rows_in_page=5,\n",
|
||||
")"
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1209,7 +1213,7 @@
|
|||
"client.eval_tasks.register(\n",
|
||||
" eval_task_id=\"meta-reference::simpleqa\",\n",
|
||||
" dataset_id=simpleqa_dataset_id,\n",
|
||||
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"]\n",
|
||||
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = client.eval.evaluate_rows(\n",
|
||||
|
@ -1222,15 +1226,16 @@
|
|||
" \"type\": \"model\",\n",
|
||||
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
||||
" \"sampling_params\": {\n",
|
||||
" \"temperature\": 0.0,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" \"max_tokens\": 4096,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"repeat_penalty\": 1.0,\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1347,23 +1352,19 @@
|
|||
"agent_config = {\n",
|
||||
" \"model\": \"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
||||
" \"instructions\": \"You are a helpful assistant\",\n",
|
||||
" \"sampling_params\": {\n",
|
||||
" \"strategy\": \"greedy\",\n",
|
||||
" \"temperature\": 0.0,\n",
|
||||
" \"top_p\": 0.95,\n",
|
||||
" },\n",
|
||||
" \"sampling_params\": {\"strategy\": {\"type\": \"greedy\"}},\n",
|
||||
" \"tools\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"brave_search\",\n",
|
||||
" \"engine\": \"tavily\",\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"tool_choice\": \"auto\",\n",
|
||||
" \"tool_prompt_format\": \"json\",\n",
|
||||
" \"input_shields\": [],\n",
|
||||
" \"output_shields\": [],\n",
|
||||
" \"enable_session_persistence\": False\n",
|
||||
" \"enable_session_persistence\": False,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"response = client.eval.evaluate_rows(\n",
|
||||
|
@ -1375,10 +1376,10 @@
|
|||
" \"eval_candidate\": {\n",
|
||||
" \"type\": \"agent\",\n",
|
||||
" \"config\": agent_config,\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@ -1336,6 +1336,7 @@
|
|||
],
|
||||
"source": [
|
||||
"from rich.pretty import pprint\n",
|
||||
"\n",
|
||||
"print(\"Available models:\")\n",
|
||||
"for m in client.models.list():\n",
|
||||
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
|
||||
|
@ -1344,7 +1345,7 @@
|
|||
"print(\"Available shields (safety models):\")\n",
|
||||
"for s in client.shields.list():\n",
|
||||
" print(s.identifier)\n",
|
||||
"print(\"----\")"
|
||||
"print(\"----\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1389,7 +1390,7 @@
|
|||
"source": [
|
||||
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
|
||||
"\n",
|
||||
"model_id"
|
||||
"model_id\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1432,11 +1433,11 @@
|
|||
" model_id=model_id,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
|
||||
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n",
|
||||
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(response.completion_message.content)"
|
||||
"print(response.completion_message.content)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1489,12 +1490,13 @@
|
|||
"source": [
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def chat_loop():\n",
|
||||
" conversation_history = []\n",
|
||||
" while True:\n",
|
||||
" user_input = input('User> ')\n",
|
||||
" if user_input.lower() in ['exit', 'quit', 'bye']:\n",
|
||||
" cprint('Ending conversation. Goodbye!', 'yellow')\n",
|
||||
" user_input = input(\"User> \")\n",
|
||||
" if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
|
||||
" cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
|
||||
|
@ -1504,15 +1506,16 @@
|
|||
" messages=conversation_history,\n",
|
||||
" model_id=model_id,\n",
|
||||
" )\n",
|
||||
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
|
||||
" cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
|
||||
"\n",
|
||||
" assistant_message = {\n",
|
||||
" \"role\": \"assistant\", # was user\n",
|
||||
" \"role\": \"assistant\", # was user\n",
|
||||
" \"content\": response.completion_message.content,\n",
|
||||
" \"stop_reason\": response.completion_message.stop_reason,\n",
|
||||
" }\n",
|
||||
" conversation_history.append(assistant_message)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"chat_loop()\n"
|
||||
]
|
||||
},
|
||||
|
@ -1568,21 +1571,18 @@
|
|||
"source": [
|
||||
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
||||
"\n",
|
||||
"message = {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": 'Write me a sonnet about llama'\n",
|
||||
"}\n",
|
||||
"print(f'User> {message[\"content\"]}', 'green')\n",
|
||||
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
|
||||
"print(f'User> {message[\"content\"]}', \"green\")\n",
|
||||
"\n",
|
||||
"response = client.inference.chat_completion(\n",
|
||||
" messages=[message],\n",
|
||||
" model_id=model_id,\n",
|
||||
" stream=True, # <-----------\n",
|
||||
" stream=True, # <-----------\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Print the tokens while they are received\n",
|
||||
"for log in EventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1648,17 +1648,22 @@
|
|||
"source": [
|
||||
"from pydantic import BaseModel\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Output(BaseModel):\n",
|
||||
" name: str\n",
|
||||
" year_born: str\n",
|
||||
" year_retired: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
|
||||
"response = client.inference.completion(\n",
|
||||
" model_id=model_id,\n",
|
||||
" content=user_input,\n",
|
||||
" stream=False,\n",
|
||||
" sampling_params={\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" \"max_tokens\": 50,\n",
|
||||
" },\n",
|
||||
" response_format={\n",
|
||||
|
@ -1667,7 +1672,7 @@
|
|||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1823,7 +1828,7 @@
|
|||
" shield_id=available_shields[0],\n",
|
||||
" params={},\n",
|
||||
" )\n",
|
||||
" pprint(response)"
|
||||
" pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2025,7 +2030,7 @@
|
|||
"\n",
|
||||
"session_id = agent.create_session(\"test-session\")\n",
|
||||
"for prompt in user_prompts:\n",
|
||||
" cprint(f'User> {prompt}', 'green')\n",
|
||||
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||
" response = agent.create_turn(\n",
|
||||
" messages=[\n",
|
||||
" {\n",
|
||||
|
@ -2451,8 +2456,8 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# Load data\n",
|
||||
"df = pd.read_csv(\"/tmp/tmpvzjigv7g/n2OzlTWhinflation.csv\")\n",
|
||||
|
@ -2536,10 +2541,10 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from google.colab import userdata\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from google.colab import userdata\n",
|
||||
"\n",
|
||||
"agent_config = AgentConfig(\n",
|
||||
" model=\"meta-llama/Llama-3.1-405B-Instruct-FP8\",\n",
|
||||
|
@ -2570,7 +2575,7 @@
|
|||
" )\n",
|
||||
"\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2790,20 +2795,21 @@
|
|||
"source": [
|
||||
"print(f\"Getting traces for session_id={session_id}\")\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"from rich.pretty import pprint\n",
|
||||
"\n",
|
||||
"agent_logs = []\n",
|
||||
"\n",
|
||||
"for span in client.telemetry.query_spans(\n",
|
||||
" attribute_filters=[\n",
|
||||
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
||||
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
||||
" ],\n",
|
||||
" attributes_to_return=[\"input\", \"output\"]\n",
|
||||
" ):\n",
|
||||
" if span.attributes[\"output\"] != \"no shields\":\n",
|
||||
" agent_logs.append(span.attributes)\n",
|
||||
" attributes_to_return=[\"input\", \"output\"],\n",
|
||||
"):\n",
|
||||
" if span.attributes[\"output\"] != \"no shields\":\n",
|
||||
" agent_logs.append(span.attributes)\n",
|
||||
"\n",
|
||||
"pprint(agent_logs)"
|
||||
"pprint(agent_logs)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2914,23 +2920,25 @@
|
|||
"eval_rows = []\n",
|
||||
"\n",
|
||||
"for log in agent_logs:\n",
|
||||
" last_msg = log['input'][-1]\n",
|
||||
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n",
|
||||
" eval_rows.append(\n",
|
||||
" {\n",
|
||||
" \"input_query\": last_msg,\n",
|
||||
" \"generated_answer\": log[\"output\"],\n",
|
||||
" # check if generated_answer uses tools brave_search\n",
|
||||
" \"expected_answer\": \"brave_search\",\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
" last_msg = log[\"input\"][-1]\n",
|
||||
" if '\"role\":\"user\"' in last_msg:\n",
|
||||
" eval_rows.append(\n",
|
||||
" {\n",
|
||||
" \"input_query\": last_msg,\n",
|
||||
" \"generated_answer\": log[\"output\"],\n",
|
||||
" # check if generated_answer uses tools brave_search\n",
|
||||
" \"expected_answer\": \"brave_search\",\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"pprint(eval_rows)\n",
|
||||
"scoring_params = {\n",
|
||||
" \"basic::subset_of\": None,\n",
|
||||
"}\n",
|
||||
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n",
|
||||
"pprint(scoring_response)"
|
||||
"scoring_response = client.scoring.score(\n",
|
||||
" input_rows=eval_rows, scoring_functions=scoring_params\n",
|
||||
")\n",
|
||||
"pprint(scoring_response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -3031,7 +3039,9 @@
|
|||
"EXPECTED_RESPONSE: {expected_answer}\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
||||
"input_query = (\n",
|
||||
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
||||
")\n",
|
||||
"generated_answer = \"\"\"\n",
|
||||
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
|
||||
"\n",
|
||||
|
@ -3062,7 +3072,7 @@
|
|||
"}\n",
|
||||
"\n",
|
||||
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
OpenAPI spec (JSON):

```diff
@@ -3514,6 +3514,20 @@
         "tool_calls"
       ]
     },
+    "GreedySamplingStrategy": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "greedy",
+          "default": "greedy"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ]
+    },
     "ImageContentItem": {
       "type": "object",
       "properties": {
@@ -3581,20 +3595,17 @@
       "type": "object",
       "properties": {
         "strategy": {
-          "$ref": "#/components/schemas/SamplingStrategy",
-          "default": "greedy"
-        },
-        "temperature": {
-          "type": "number",
-          "default": 0.0
-        },
-        "top_p": {
-          "type": "number",
-          "default": 0.95
-        },
-        "top_k": {
-          "type": "integer",
-          "default": 0
+          "oneOf": [
+            {
+              "$ref": "#/components/schemas/GreedySamplingStrategy"
+            },
+            {
+              "$ref": "#/components/schemas/TopPSamplingStrategy"
+            },
+            {
+              "$ref": "#/components/schemas/TopKSamplingStrategy"
+            }
+          ]
         },
         "max_tokens": {
           "type": "integer",
@@ -3610,14 +3621,6 @@
         "strategy"
       ]
     },
-    "SamplingStrategy": {
-      "type": "string",
-      "enum": [
-        "greedy",
-        "top_p",
-        "top_k"
-      ]
-    },
     "StopReason": {
       "type": "string",
       "enum": [
@@ -3871,6 +3874,45 @@
         "content"
       ]
     },
+    "TopKSamplingStrategy": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "top_k",
+          "default": "top_k"
+        },
+        "top_k": {
+          "type": "integer"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type",
+        "top_k"
+      ]
+    },
+    "TopPSamplingStrategy": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "top_p",
+          "default": "top_p"
+        },
+        "temperature": {
+          "type": "number"
+        },
+        "top_p": {
+          "type": "number",
+          "default": 0.95
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ]
+    },
     "URL": {
       "type": "object",
       "properties": {
@@ -8887,6 +8929,10 @@
       "name": "GraphMemoryBankParams",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/GraphMemoryBankParams\" />"
     },
+    {
+      "name": "GreedySamplingStrategy",
+      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/GreedySamplingStrategy\" />"
+    },
     {
       "name": "HealthInfo",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />"
@@ -9136,10 +9182,6 @@
       "name": "SamplingParams",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingParams\" />"
     },
-    {
-      "name": "SamplingStrategy",
-      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingStrategy\" />"
-    },
     {
       "name": "SaveSpansToDatasetRequest",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/SaveSpansToDatasetRequest\" />"
@@ -9317,6 +9359,14 @@
     {
       "name": "ToolRuntime"
     },
+    {
+      "name": "TopKSamplingStrategy",
+      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/TopKSamplingStrategy\" />"
+    },
+    {
+      "name": "TopPSamplingStrategy",
+      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/TopPSamplingStrategy\" />"
+    },
     {
       "name": "Trace",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />"
@@ -9456,6 +9506,7 @@
         "GetSpanTreeRequest",
         "GraphMemoryBank",
         "GraphMemoryBankParams",
+        "GreedySamplingStrategy",
         "HealthInfo",
         "ImageContentItem",
         "InferenceStep",
@@ -9513,7 +9564,6 @@
         "RunShieldResponse",
         "SafetyViolation",
         "SamplingParams",
-        "SamplingStrategy",
         "SaveSpansToDatasetRequest",
         "ScoreBatchRequest",
         "ScoreBatchResponse",
@@ -9553,6 +9603,8 @@
         "ToolPromptFormat",
         "ToolResponse",
         "ToolResponseMessage",
+        "TopKSamplingStrategy",
+        "TopPSamplingStrategy",
         "Trace",
         "TrainingConfig",
         "Turn",
```
OpenAPI spec (YAML):

```diff
@@ -937,6 +937,16 @@ components:
     required:
     - memory_bank_type
     type: object
+  GreedySamplingStrategy:
+    additionalProperties: false
+    properties:
+      type:
+        const: greedy
+        default: greedy
+        type: string
+    required:
+    - type
+    type: object
   HealthInfo:
     additionalProperties: false
     properties:
@@ -2064,26 +2074,13 @@ components:
         default: 1.0
         type: number
       strategy:
-        $ref: '#/components/schemas/SamplingStrategy'
-        default: greedy
-      temperature:
-        default: 0.0
-        type: number
-      top_k:
-        default: 0
-        type: integer
-      top_p:
-        default: 0.95
-        type: number
+        oneOf:
+        - $ref: '#/components/schemas/GreedySamplingStrategy'
+        - $ref: '#/components/schemas/TopPSamplingStrategy'
+        - $ref: '#/components/schemas/TopKSamplingStrategy'
     required:
     - strategy
     type: object
-  SamplingStrategy:
-    enum:
-    - greedy
-    - top_p
-    - top_k
-    type: string
   SaveSpansToDatasetRequest:
     additionalProperties: false
     properties:
@@ -2931,6 +2928,34 @@ components:
     - tool_name
     - content
     type: object
+  TopKSamplingStrategy:
+    additionalProperties: false
+    properties:
+      top_k:
+        type: integer
+      type:
+        const: top_k
+        default: top_k
+        type: string
+    required:
+    - type
+    - top_k
+    type: object
+  TopPSamplingStrategy:
+    additionalProperties: false
+    properties:
+      temperature:
+        type: number
+      top_p:
+        default: 0.95
+        type: number
+      type:
+        const: top_p
+        default: top_p
+        type: string
+    required:
+    - type
+    type: object
   Trace:
     additionalProperties: false
     properties:
@@ -5587,6 +5612,9 @@ tags:
 - description: <SchemaDefinition schemaRef="#/components/schemas/GraphMemoryBankParams"
   />
   name: GraphMemoryBankParams
+- description: <SchemaDefinition schemaRef="#/components/schemas/GreedySamplingStrategy"
+  />
+  name: GreedySamplingStrategy
 - description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" />
   name: HealthInfo
 - description: <SchemaDefinition schemaRef="#/components/schemas/ImageContentItem"
@@ -5753,9 +5781,6 @@ tags:
   name: SafetyViolation
 - description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" />
   name: SamplingParams
-- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingStrategy"
-  />
-  name: SamplingStrategy
 - description: <SchemaDefinition schemaRef="#/components/schemas/SaveSpansToDatasetRequest"
   />
   name: SaveSpansToDatasetRequest
@@ -5874,6 +5899,12 @@ tags:
   />
   name: ToolResponseMessage
 - name: ToolRuntime
+- description: <SchemaDefinition schemaRef="#/components/schemas/TopKSamplingStrategy"
+  />
+  name: TopKSamplingStrategy
+- description: <SchemaDefinition schemaRef="#/components/schemas/TopPSamplingStrategy"
+  />
+  name: TopPSamplingStrategy
 - description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
   name: Trace
 - description: <SchemaDefinition schemaRef="#/components/schemas/TrainingConfig" />
@@ -5990,6 +6021,7 @@ x-tagGroups:
   - GetSpanTreeRequest
   - GraphMemoryBank
   - GraphMemoryBankParams
+  - GreedySamplingStrategy
   - HealthInfo
   - ImageContentItem
   - InferenceStep
@@ -6047,7 +6079,6 @@ x-tagGroups:
   - RunShieldResponse
   - SafetyViolation
   - SamplingParams
-  - SamplingStrategy
   - SaveSpansToDatasetRequest
   - ScoreBatchRequest
   - ScoreBatchResponse
@@ -6087,6 +6118,8 @@ x-tagGroups:
   - ToolPromptFormat
   - ToolResponse
   - ToolResponseMessage
+  - TopKSamplingStrategy
+  - TopPSamplingStrategy
   - Trace
   - TrainingConfig
   - Turn
```
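As a sanity check on the union semantics, here is a hedged sketch using the third-party `jsonschema` package (not part of this repo); the schema fragments mirror the spec definitions above:

```python
# Validate strategy payloads against the new union. Because each variant sets
# additionalProperties: false, knobs from one strategy cannot leak into another.
from jsonschema import ValidationError, validate

strategy_schema = {
    "oneOf": [
        {"$ref": "#/definitions/GreedySamplingStrategy"},
        {"$ref": "#/definitions/TopPSamplingStrategy"},
        {"$ref": "#/definitions/TopKSamplingStrategy"},
    ],
    "definitions": {
        "GreedySamplingStrategy": {
            "type": "object",
            "properties": {"type": {"const": "greedy"}},
            "required": ["type"],
            "additionalProperties": False,
        },
        "TopPSamplingStrategy": {
            "type": "object",
            "properties": {
                "type": {"const": "top_p"},
                "temperature": {"type": "number"},
                "top_p": {"type": "number", "default": 0.95},
            },
            "required": ["type"],
            "additionalProperties": False,
        },
        "TopKSamplingStrategy": {
            "type": "object",
            "properties": {
                "type": {"const": "top_k"},
                "top_k": {"type": "integer"},
            },
            "required": ["type", "top_k"],
            "additionalProperties": False,
        },
    },
}

validate({"type": "greedy"}, strategy_schema)                     # ok
validate({"type": "top_p", "temperature": 0.7}, strategy_schema)  # ok
try:
    # Mixing knobs across strategies is now rejected at the schema level.
    validate({"type": "greedy", "top_k": 40}, strategy_schema)
except ValidationError as err:
    print("rejected:", err.message)
```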
Docs, evaluation examples:

```diff
@@ -56,9 +56,10 @@ response = client.eval.evaluate_rows(
         "type": "model",
         "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
         "sampling_params": {
-            "temperature": 0.0,
+            "strategy": {
+                "type": "greedy",
+            },
             "max_tokens": 4096,
             "top_p": 0.9,
             "repeat_penalty": 1.0,
         },
         "system_message": system_message
@@ -113,9 +114,10 @@ response = client.eval.evaluate_rows(
         "type": "model",
         "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
         "sampling_params": {
-            "temperature": 0.0,
+            "strategy": {
+                "type": "greedy",
+            },
             "max_tokens": 4096,
             "top_p": 0.9,
             "repeat_penalty": 1.0,
         },
     }
@@ -134,9 +136,9 @@ agent_config = {
     "model": "meta-llama/Llama-3.1-405B-Instruct",
     "instructions": "You are a helpful assistant",
     "sampling_params": {
-        "strategy": "greedy",
-        "temperature": 0.0,
-        "top_p": 0.95,
+        "strategy": {
+            "type": "greedy",
+        },
     },
     "tools": [
         {
```
Docs, agents guide:

```diff
@@ -189,7 +189,11 @@ agent_config = AgentConfig(
     # Control the inference loop
     max_infer_iters=5,
     sampling_params={
-        "temperature": 0.7,
+        "strategy": {
+            "type": "top_p",
+            "temperature": 0.7,
+            "top_p": 0.95
+        },
         "max_tokens": 2048
     }
 )
```
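For agent users, the union slots into `AgentConfig` with everything else unchanged. A hedged end-to-end sketch assembled from the notebook snippets in this diff; the base URL and model name are placeholders:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

client = LlamaStackClient(base_url="http://localhost:5001")  # placeholder URL

agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    instructions="You are a helpful assistant",
    # New union format: the strategy object replaces the old flat enum + knobs.
    sampling_params={
        "strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.95},
        "max_tokens": 2048,
    },
    tool_choice="auto",
    input_shields=[],
    output_shields=[],
    enable_session_persistence=False,
)
agent = Agent(client, agent_config)
session_id = agent.create_session("test-session")
```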
Docs, evals cookbook:

```diff
@@ -92,9 +92,10 @@ response = client.eval.evaluate_rows(
         "type": "model",
         "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
         "sampling_params": {
-            "temperature": 0.0,
+            "strategy": {
+                "type": "greedy",
+            },
             "max_tokens": 4096,
             "top_p": 0.9,
             "repeat_penalty": 1.0,
         },
         "system_message": system_message
@@ -149,9 +150,10 @@ response = client.eval.evaluate_rows(
         "type": "model",
         "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
         "sampling_params": {
-            "temperature": 0.0,
+            "strategy": {
+                "type": "greedy",
+            },
             "max_tokens": 4096,
             "top_p": 0.9,
             "repeat_penalty": 1.0,
         },
     }
@@ -170,9 +172,9 @@ agent_config = {
     "model": "meta-llama/Llama-3.1-405B-Instruct",
     "instructions": "You are a helpful assistant",
     "sampling_params": {
-        "strategy": "greedy",
-        "temperature": 0.0,
-        "top_p": 0.95,
+        "strategy": {
+            "type": "greedy",
+        },
     },
     "tools": [
         {
@@ -318,10 +320,9 @@ The `EvalTaskConfig` are user specified config to define:
     "type": "model",
     "model": "Llama3.2-3B-Instruct",
     "sampling_params": {
-        "strategy": "greedy",
-        "temperature": 0,
-        "top_p": 0.95,
-        "top_k": 0,
+        "strategy": {
+            "type": "greedy",
+        },
         "max_tokens": 0,
         "repetition_penalty": 1.0
     }
@@ -337,10 +338,9 @@ The `EvalTaskConfig` are user specified config to define:
     "type": "model",
     "model": "Llama3.1-405B-Instruct",
     "sampling_params": {
-        "strategy": "greedy",
-        "temperature": 0,
-        "top_p": 0.95,
-        "top_k": 0,
+        "strategy": {
+            "type": "greedy",
+        },
         "max_tokens": 0,
         "repetition_penalty": 1.0
     }
```
CLI reference (`llama model describe` output):

```diff
@@ -214,7 +214,6 @@ llama model describe -m Llama3.2-3B-Instruct
 | | } |
 +-----------------------------+----------------------------------+
 | Recommended sampling params | { |
-| | "strategy": "top_p", |
 | | "temperature": 1.0, |
 | | "top_p": 0.9, |
 | | "top_k": 0 |
```
Docs, eval CLI example:

```diff
@@ -200,10 +200,9 @@ Example eval_task_config.json:
     "type": "model",
     "model": "Llama3.1-405B-Instruct",
     "sampling_params": {
-        "strategy": "greedy",
-        "temperature": 0,
-        "top_p": 0.95,
-        "top_k": 0,
+        "strategy": {
+            "type": "greedy"
+        },
         "max_tokens": 0,
         "repetition_penalty": 1.0
     }
```
@ -26,27 +26,28 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import asyncio\n",
|
||||
"import nest_asyncio\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"from typing import Dict, List\n",
|
||||
"\n",
|
||||
"import nest_asyncio\n",
|
||||
"import requests\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
|
||||
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
|
||||
"from llama_stack_client.types import CompletionMessage\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types import CompletionMessage\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
|
||||
"\n",
|
||||
"# Allow asyncio to run in Jupyter Notebook\n",
|
||||
"nest_asyncio.apply()\n",
|
||||
"\n",
|
||||
"HOST='localhost'\n",
|
||||
"PORT=5001\n",
|
||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
||||
"HOST = \"localhost\"\n",
|
||||
"PORT = 5001\n",
|
||||
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -69,7 +70,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"load_dotenv()\n",
|
||||
"BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']"
|
||||
"BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -118,7 +119,7 @@
|
|||
" cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n",
|
||||
" clean_response.append(cleaned)\n",
|
||||
"\n",
|
||||
" return {\"query\": query, \"top_k\": clean_response}"
|
||||
" return {\"query\": query, \"top_k\": clean_response}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -157,25 +158,29 @@
|
|||
" for message in messages:\n",
|
||||
" if isinstance(message, CompletionMessage) and message.tool_calls:\n",
|
||||
" for tool_call in message.tool_calls:\n",
|
||||
" if 'query' in tool_call.arguments:\n",
|
||||
" query = tool_call.arguments['query']\n",
|
||||
" if \"query\" in tool_call.arguments:\n",
|
||||
" query = tool_call.arguments[\"query\"]\n",
|
||||
" call_id = tool_call.call_id\n",
|
||||
"\n",
|
||||
" if query:\n",
|
||||
" search_result = await self.run_impl(query)\n",
|
||||
" return [ToolResponseMessage(\n",
|
||||
" call_id=call_id,\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=self._format_response_for_agent(search_result),\n",
|
||||
" tool_name=\"brave_search\"\n",
|
||||
" )]\n",
|
||||
" return [\n",
|
||||
" ToolResponseMessage(\n",
|
||||
" call_id=call_id,\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=self._format_response_for_agent(search_result),\n",
|
||||
" tool_name=\"brave_search\",\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" return [ToolResponseMessage(\n",
|
||||
" call_id=\"no_call_id\",\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=\"No query provided.\",\n",
|
||||
" tool_name=\"brave_search\"\n",
|
||||
" )]\n",
|
||||
" return [\n",
|
||||
" ToolResponseMessage(\n",
|
||||
" call_id=\"no_call_id\",\n",
|
||||
" role=\"ipython\",\n",
|
||||
" content=\"No query provided.\",\n",
|
||||
" tool_name=\"brave_search\",\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" def _format_response_for_agent(self, search_result):\n",
|
||||
" parsed_result = json.loads(search_result)\n",
|
||||
|
@ -186,7 +191,7 @@
|
|||
" f\" URL: {result.get('url', 'No URL')}\\n\"\n",
|
||||
" f\" Description: {result.get('description', 'No Description')}\\n\\n\"\n",
|
||||
" )\n",
|
||||
" return formatted_result"
|
||||
" return formatted_result\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -209,7 +214,7 @@
|
|||
"async def execute_search(query: str):\n",
|
||||
" web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
|
||||
" result = await web_search_tool.run_impl(query)\n",
|
||||
" print(\"Search Results:\", result)"
|
||||
" print(\"Search Results:\", result)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -236,7 +241,7 @@
|
|||
],
|
||||
"source": [
|
||||
"query = \"Latest developments in quantum computing\"\n",
|
||||
"asyncio.run(execute_search(query))"
|
||||
"asyncio.run(execute_search(query))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -288,19 +293,17 @@
|
|||
"\n",
|
||||
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
|
||||
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" # Define the agent configuration, including the model and tool setup\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" model=MODEL_NAME,\n",
|
||||
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
|
||||
" sampling_params={\n",
|
||||
" \"strategy\": \"greedy\",\n",
|
||||
" \"temperature\": 1.0,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" tools=[\n",
|
||||
" webSearchTool.get_tool_definition()\n",
|
||||
" ],\n",
|
||||
" tools=[webSearchTool.get_tool_definition()],\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" tool_prompt_format=\"python_list\",\n",
|
||||
" input_shields=input_shields,\n",
|
||||
|
@ -329,8 +332,9 @@
|
|||
" async for log in EventLogger().log(response):\n",
|
||||
" log.print()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run the function asynchronously in a Jupyter Notebook cell\n",
|
||||
"await run_main(disable_safety=True)"
|
||||
"await run_main(disable_safety=True)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@ -50,8 +50,8 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"HOST = \"localhost\" # Replace with your host\n",
|
||||
"PORT = 5001 # Replace with your port\n",
|
||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
||||
"PORT = 5001 # Replace with your port\n",
|
||||
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -60,10 +60,12 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dotenv import load_dotenv\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"\n",
|
||||
"load_dotenv()\n",
|
||||
"BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']"
|
||||
"BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -104,20 +106,22 @@
|
|||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def agent_example():\n",
|
||||
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" model=MODEL_NAME,\n",
|
||||
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
|
||||
" sampling_params={\n",
|
||||
" \"strategy\": \"greedy\",\n",
|
||||
" \"temperature\": 1.0,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" tools=[\n",
|
||||
" {\n",
|
||||
|
@ -157,7 +161,7 @@
|
|||
" log.print()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"await agent_example()"
|
||||
"await agent_example()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -157,7 +157,15 @@ curl http://localhost:$LLAMA_STACK_PORT/alpha/inference/chat-completion
|
|||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Write me a 2-sentence poem about the moon"}
|
||||
],
|
||||
"sampling_params": {"temperature": 0.7, "seed": 42, "max_tokens": 512}
|
||||
"sampling_params": {
|
||||
"strategy": {
|
||||
"type": "top_p",
|
||||
"temperatrue": 0.7,
|
||||
"top_p": 0.95,
|
||||
},
|
||||
"seed": 42,
|
||||
"max_tokens": 512
|
||||
}
|
||||
}
|
||||
EOF
|
||||
```
|
||||
|
|
|
@ -83,8 +83,8 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"LLAMA_STACK_API_TOGETHER_URL=\"https://llama-stack.together.ai\"\n",
|
||||
"LLAMA31_8B_INSTRUCT = \"Llama3.1-8B-Instruct\""
|
||||
"LLAMA_STACK_API_TOGETHER_URL = \"https://llama-stack.together.ai\"\n",
|
||||
"LLAMA31_8B_INSTRUCT = \"Llama3.1-8B-Instruct\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -107,12 +107,13 @@
|
|||
" AgentConfigToolSearchToolDefinition,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Helper function to create an agent with tools\n",
|
||||
"async def create_tool_agent(\n",
|
||||
" client: LlamaStackClient,\n",
|
||||
" tools: List[Dict],\n",
|
||||
" instructions: str = \"You are a helpful assistant\",\n",
|
||||
" model: str = LLAMA31_8B_INSTRUCT\n",
|
||||
" model: str = LLAMA31_8B_INSTRUCT,\n",
|
||||
") -> Agent:\n",
|
||||
" \"\"\"Create an agent with specified tools.\"\"\"\n",
|
||||
" print(\"Using the following model: \", model)\n",
|
||||
|
@ -120,9 +121,9 @@
|
|||
" model=model,\n",
|
||||
" instructions=instructions,\n",
|
||||
" sampling_params={\n",
|
||||
" \"strategy\": \"greedy\",\n",
|
||||
" \"temperature\": 1.0,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" tools=tools,\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
|
@ -130,7 +131,7 @@
|
|||
" enable_session_persistence=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return Agent(client, agent_config)"
|
||||
" return Agent(client, agent_config)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -172,7 +173,8 @@
|
|||
],
|
||||
"source": [
|
||||
"# comment this if you don't have a BRAVE_SEARCH_API_KEY\n",
|
||||
"os.environ[\"BRAVE_SEARCH_API_KEY\"] = 'YOUR_BRAVE_SEARCH_API_KEY'\n",
|
||||
"os.environ[\"BRAVE_SEARCH_API_KEY\"] = \"YOUR_BRAVE_SEARCH_API_KEY\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def create_search_agent(client: LlamaStackClient) -> Agent:\n",
|
||||
" \"\"\"Create an agent with Brave Search capability.\"\"\"\n",
|
||||
|
@ -186,8 +188,8 @@
|
|||
"\n",
|
||||
" return await create_tool_agent(\n",
|
||||
" client=client,\n",
|
||||
" tools=[search_tool], # set this to [] if you don't have a BRAVE_SEARCH_API_KEY\n",
|
||||
" model = LLAMA31_8B_INSTRUCT,\n",
|
||||
" tools=[search_tool], # set this to [] if you don't have a BRAVE_SEARCH_API_KEY\n",
|
||||
" model=LLAMA31_8B_INSTRUCT,\n",
|
||||
" instructions=\"\"\"\n",
|
||||
" You are a research assistant that can search the web.\n",
|
||||
" Always cite your sources with URLs when providing information.\n",
|
||||
|
@ -198,9 +200,10 @@
|
|||
"\n",
|
||||
" SOURCES:\n",
|
||||
" - [Source title](URL)\n",
|
||||
" \"\"\"\n",
|
||||
" \"\"\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Example usage\n",
|
||||
"async def search_example():\n",
|
||||
" client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n",
|
||||
|
@ -212,7 +215,7 @@
|
|||
" # Example queries\n",
|
||||
" queries = [\n",
|
||||
" \"What are the latest developments in quantum computing?\",\n",
|
||||
" #\"Who won the most recent Super Bowl?\",\n",
|
||||
" # \"Who won the most recent Super Bowl?\",\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" for query in queries:\n",
|
||||
|
@ -227,8 +230,9 @@
|
|||
" async for log in EventLogger().log(response):\n",
|
||||
" log.print()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run the example (in Jupyter, use asyncio.run())\n",
|
||||
"await search_example()"
|
||||
"await search_example()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -286,12 +290,16 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from typing import TypedDict, Optional, Dict, Any\n",
|
||||
"from datetime import datetime\n",
|
||||
"import json\n",
|
||||
"from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam\n",
|
||||
"from llama_stack_client.types import CompletionMessage,ToolResponseMessage\n",
|
||||
"from datetime import datetime\n",
|
||||
"from typing import Any, Dict, Optional, TypedDict\n",
|
||||
"\n",
|
||||
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
|
||||
"from llama_stack_client.types import CompletionMessage, ToolResponseMessage\n",
|
||||
"from llama_stack_client.types.tool_param_definition_param import (\n",
|
||||
" ToolParamDefinitionParam,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class WeatherTool(CustomTool):\n",
|
||||
" \"\"\"Example custom tool for weather information.\"\"\"\n",
|
||||
|
@ -305,16 +313,15 @@
|
|||
" def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n",
|
||||
" return {\n",
|
||||
" \"location\": ToolParamDefinitionParam(\n",
|
||||
" param_type=\"str\",\n",
|
||||
" description=\"City or location name\",\n",
|
||||
" required=True\n",
|
||||
" param_type=\"str\", description=\"City or location name\", required=True\n",
|
||||
" ),\n",
|
||||
" \"date\": ToolParamDefinitionParam(\n",
|
||||
" param_type=\"str\",\n",
|
||||
" description=\"Optional date (YYYY-MM-DD)\",\n",
|
||||
" required=False\n",
|
||||
" )\n",
|
||||
" required=False,\n",
|
||||
" ),\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:\n",
|
||||
" assert len(messages) == 1, \"Expected single message\"\n",
|
||||
"\n",
|
||||
|
@ -337,20 +344,14 @@
|
|||
" )\n",
|
||||
" return [message]\n",
|
||||
"\n",
|
||||
" async def run_impl(self, location: str, date: Optional[str] = None) -> Dict[str, Any]:\n",
|
||||
" async def run_impl(\n",
|
||||
" self, location: str, date: Optional[str] = None\n",
|
||||
" ) -> Dict[str, Any]:\n",
|
||||
" \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n",
|
||||
" # Mock implementation\n",
|
||||
" if date:\n",
|
||||
" return {\n",
|
||||
" \"temperature\": 90.1,\n",
|
||||
" \"conditions\": \"sunny\",\n",
|
||||
" \"humidity\": 40.0\n",
|
||||
" }\n",
|
||||
" return {\n",
|
||||
" \"temperature\": 72.5,\n",
|
||||
" \"conditions\": \"partly cloudy\",\n",
|
||||
" \"humidity\": 65.0\n",
|
||||
" }\n",
|
||||
" return {\"temperature\": 90.1, \"conditions\": \"sunny\", \"humidity\": 40.0}\n",
|
||||
" return {\"temperature\": 72.5, \"conditions\": \"partly cloudy\", \"humidity\": 65.0}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def create_weather_agent(client: LlamaStackClient) -> Agent:\n",
|
||||
|
@ -358,38 +359,33 @@
|
|||
"\n",
|
||||
" # Create the agent with the tool\n",
|
||||
" weather_tool = WeatherTool()\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" agent_config = AgentConfig(\n",
|
||||
" model=LLAMA31_8B_INSTRUCT,\n",
|
||||
" #model=model_name,\n",
|
||||
" # model=model_name,\n",
|
||||
" instructions=\"\"\"\n",
|
||||
" You are a weather assistant that can provide weather information.\n",
|
||||
" Always specify the location clearly in your responses.\n",
|
||||
" Include both temperature and conditions in your summaries.\n",
|
||||
" \"\"\",\n",
|
||||
" sampling_params={\n",
|
||||
" \"strategy\": \"greedy\",\n",
|
||||
" \"temperature\": 1.0,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" tools=[\n",
|
||||
" weather_tool.get_tool_definition()\n",
|
||||
" ],\n",
|
||||
" tools=[weather_tool.get_tool_definition()],\n",
|
||||
" tool_choice=\"auto\",\n",
|
||||
" tool_prompt_format=\"json\",\n",
|
||||
" input_shields=[],\n",
|
||||
" output_shields=[],\n",
|
||||
" enable_session_persistence=True\n",
|
||||
" enable_session_persistence=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" agent = Agent(\n",
|
||||
" client=client,\n",
|
||||
" agent_config=agent_config,\n",
|
||||
" custom_tools=[weather_tool]\n",
|
||||
" )\n",
|
||||
" agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
|
||||
"\n",
|
||||
" return agent\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Example usage\n",
|
||||
"async def weather_example():\n",
|
||||
" client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n",
|
||||
|
@ -413,12 +409,14 @@
|
|||
" async for log in EventLogger().log(response):\n",
|
||||
" log.print()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# For Jupyter notebooks\n",
|
||||
"import nest_asyncio\n",
|
||||
"\n",
|
||||
"nest_asyncio.apply()\n",
|
||||
"\n",
|
||||
"# Run the example\n",
|
||||
"await weather_example()"
|
||||
"await weather_example()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|