Update Strategy in SamplingParams to be a union

This commit is contained in:
Hardik Shah 2025-01-14 15:56:02 -08:00 committed by Ashwin Bharambe
parent 300e6e2702
commit dea575c994
28 changed files with 600 additions and 377 deletions

View file

@ -618,11 +618,13 @@
],
"source": [
"import os\n",
"\n",
"from google.colab import userdata\n",
"\n",
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
"\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"\n",
"client = LlamaStackAsLibraryClient(\"together\")\n",
"_ = client.initialize()\n",
"\n",
@ -631,7 +633,7 @@
" model_id=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
" provider_model_id=\"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n",
" provider_id=\"together\",\n",
")"
")\n"
]
},
{
@ -668,7 +670,7 @@
"source": [
"name = \"llamastack/mmmu\"\n",
"subset = \"Agriculture\"\n",
"split = \"dev\""
"split = \"dev\"\n"
]
},
{
@ -914,9 +916,10 @@
],
"source": [
"import datasets\n",
"\n",
"ds = datasets.load_dataset(path=name, name=subset, split=split)\n",
"ds = ds.select_columns([\"chat_completion_input\", \"input_query\", \"expected_answer\"])\n",
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")"
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")\n"
]
},
{
@ -1014,8 +1017,8 @@
}
],
"source": [
"from tqdm import tqdm\n",
"from rich.pretty import pprint\n",
"from tqdm import tqdm\n",
"\n",
"SYSTEM_PROMPT_TEMPLATE = \"\"\"\n",
"You are an expert in {subject} whose job is to answer questions from the user using images.\n",
@ -1039,7 +1042,7 @@
"client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::mmmu\",\n",
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"]\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows(\n",
@ -1052,16 +1055,17 @@
" \"type\": \"model\",\n",
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
" \"sampling_params\": {\n",
" \"temperature\": 0.0,\n",
" \"strategy\": {\n",
" \"type\": \"greedy\",\n",
" },\n",
" \"max_tokens\": 4096,\n",
" \"top_p\": 0.9,\n",
" \"repeat_penalty\": 1.0,\n",
" },\n",
" \"system_message\": system_message\n",
" }\n",
" }\n",
" \"system_message\": system_message,\n",
" },\n",
" },\n",
")\n",
"pprint(response)"
"pprint(response)\n"
]
},
{
@ -1098,8 +1102,8 @@
" \"input_query\": {\"type\": \"string\"},\n",
" \"expected_answer\": {\"type\": \"string\"},\n",
" \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n",
" }\n",
")"
" },\n",
")\n"
]
},
{
@ -1113,7 +1117,7 @@
"eval_rows = client.datasetio.get_rows_paginated(\n",
" dataset_id=simpleqa_dataset_id,\n",
" rows_in_page=5,\n",
")"
")\n"
]
},
{
@ -1209,7 +1213,7 @@
"client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::simpleqa\",\n",
" dataset_id=simpleqa_dataset_id,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"]\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows(\n",
@ -1222,15 +1226,16 @@
" \"type\": \"model\",\n",
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
" \"sampling_params\": {\n",
" \"temperature\": 0.0,\n",
" \"strategy\": {\n",
" \"type\": \"greedy\",\n",
" },\n",
" \"max_tokens\": 4096,\n",
" \"top_p\": 0.9,\n",
" \"repeat_penalty\": 1.0,\n",
" },\n",
" }\n",
" }\n",
" },\n",
" },\n",
")\n",
"pprint(response)"
"pprint(response)\n"
]
},
{
@ -1347,23 +1352,19 @@
"agent_config = {\n",
" \"model\": \"meta-llama/Llama-3.1-405B-Instruct\",\n",
" \"instructions\": \"You are a helpful assistant\",\n",
" \"sampling_params\": {\n",
" \"strategy\": \"greedy\",\n",
" \"temperature\": 0.0,\n",
" \"top_p\": 0.95,\n",
" },\n",
" \"sampling_params\": {\"strategy\": {\"type\": \"greedy\"}},\n",
" \"tools\": [\n",
" {\n",
" \"type\": \"brave_search\",\n",
" \"engine\": \"tavily\",\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
" }\n",
" ],\n",
" \"tool_choice\": \"auto\",\n",
" \"tool_prompt_format\": \"json\",\n",
" \"input_shields\": [],\n",
" \"output_shields\": [],\n",
" \"enable_session_persistence\": False\n",
" \"enable_session_persistence\": False,\n",
"}\n",
"\n",
"response = client.eval.evaluate_rows(\n",
@ -1375,10 +1376,10 @@
" \"eval_candidate\": {\n",
" \"type\": \"agent\",\n",
" \"config\": agent_config,\n",
" }\n",
" }\n",
" },\n",
" },\n",
")\n",
"pprint(response)"
"pprint(response)\n"
]
}
],

View file

@ -1336,6 +1336,7 @@
],
"source": [
"from rich.pretty import pprint\n",
"\n",
"print(\"Available models:\")\n",
"for m in client.models.list():\n",
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
@ -1344,7 +1345,7 @@
"print(\"Available shields (safety models):\")\n",
"for s in client.shields.list():\n",
" print(s.identifier)\n",
"print(\"----\")"
"print(\"----\")\n"
]
},
{
@ -1389,7 +1390,7 @@
"source": [
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
"\n",
"model_id"
"model_id\n"
]
},
{
@ -1432,11 +1433,11 @@
" model_id=model_id,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n",
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
" ],\n",
")\n",
"\n",
"print(response.completion_message.content)"
"print(response.completion_message.content)\n"
]
},
{
@ -1489,12 +1490,13 @@
"source": [
"from termcolor import cprint\n",
"\n",
"\n",
"def chat_loop():\n",
" conversation_history = []\n",
" while True:\n",
" user_input = input('User> ')\n",
" if user_input.lower() in ['exit', 'quit', 'bye']:\n",
" cprint('Ending conversation. Goodbye!', 'yellow')\n",
" user_input = input(\"User> \")\n",
" if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
" cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
" break\n",
"\n",
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
@ -1504,15 +1506,16 @@
" messages=conversation_history,\n",
" model_id=model_id,\n",
" )\n",
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
" cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
"\n",
" assistant_message = {\n",
" \"role\": \"assistant\", # was user\n",
" \"role\": \"assistant\", # was user\n",
" \"content\": response.completion_message.content,\n",
" \"stop_reason\": response.completion_message.stop_reason,\n",
" }\n",
" conversation_history.append(assistant_message)\n",
"\n",
"\n",
"chat_loop()\n"
]
},
@ -1568,21 +1571,18 @@
"source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
"\n",
"message = {\n",
" \"role\": \"user\",\n",
" \"content\": 'Write me a sonnet about llama'\n",
"}\n",
"print(f'User> {message[\"content\"]}', 'green')\n",
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
"print(f'User> {message[\"content\"]}', \"green\")\n",
"\n",
"response = client.inference.chat_completion(\n",
" messages=[message],\n",
" model_id=model_id,\n",
" stream=True, # <-----------\n",
" stream=True, # <-----------\n",
")\n",
"\n",
"# Print the tokens while they are received\n",
"for log in EventLogger().log(response):\n",
" log.print()"
" log.print()\n"
]
},
{
@ -1648,17 +1648,22 @@
"source": [
"from pydantic import BaseModel\n",
"\n",
"\n",
"class Output(BaseModel):\n",
" name: str\n",
" year_born: str\n",
" year_retired: str\n",
"\n",
"\n",
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
"response = client.inference.completion(\n",
" model_id=model_id,\n",
" content=user_input,\n",
" stream=False,\n",
" sampling_params={\n",
" \"strategy\": {\n",
" \"type\": \"greedy\",\n",
" },\n",
" \"max_tokens\": 50,\n",
" },\n",
" response_format={\n",
@ -1667,7 +1672,7 @@
" },\n",
")\n",
"\n",
"pprint(response)"
"pprint(response)\n"
]
},
{
@ -1823,7 +1828,7 @@
" shield_id=available_shields[0],\n",
" params={},\n",
" )\n",
" pprint(response)"
" pprint(response)\n"
]
},
{
@ -2025,7 +2030,7 @@
"\n",
"session_id = agent.create_session(\"test-session\")\n",
"for prompt in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n",
" cprint(f\"User> {prompt}\", \"green\")\n",
" response = agent.create_turn(\n",
" messages=[\n",
" {\n",
@ -2451,8 +2456,8 @@
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Load data\n",
"df = pd.read_csv(\"/tmp/tmpvzjigv7g/n2OzlTWhinflation.csv\")\n",
@ -2536,10 +2541,10 @@
}
],
"source": [
"from google.colab import userdata\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from google.colab import userdata\n",
"\n",
"agent_config = AgentConfig(\n",
" model=\"meta-llama/Llama-3.1-405B-Instruct-FP8\",\n",
@ -2570,7 +2575,7 @@
" )\n",
"\n",
" for log in EventLogger().log(response):\n",
" log.print()"
" log.print()\n"
]
},
{
@ -2790,20 +2795,21 @@
"source": [
"print(f\"Getting traces for session_id={session_id}\")\n",
"import json\n",
"\n",
"from rich.pretty import pprint\n",
"\n",
"agent_logs = []\n",
"\n",
"for span in client.telemetry.query_spans(\n",
" attribute_filters=[\n",
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
" ],\n",
" attributes_to_return=[\"input\", \"output\"]\n",
" ):\n",
" if span.attributes[\"output\"] != \"no shields\":\n",
" agent_logs.append(span.attributes)\n",
" attributes_to_return=[\"input\", \"output\"],\n",
"):\n",
" if span.attributes[\"output\"] != \"no shields\":\n",
" agent_logs.append(span.attributes)\n",
"\n",
"pprint(agent_logs)"
"pprint(agent_logs)\n"
]
},
{
@ -2914,23 +2920,25 @@
"eval_rows = []\n",
"\n",
"for log in agent_logs:\n",
" last_msg = log['input'][-1]\n",
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n",
" eval_rows.append(\n",
" {\n",
" \"input_query\": last_msg,\n",
" \"generated_answer\": log[\"output\"],\n",
" # check if generated_answer uses tools brave_search\n",
" \"expected_answer\": \"brave_search\",\n",
" },\n",
" )\n",
" last_msg = log[\"input\"][-1]\n",
" if '\"role\":\"user\"' in last_msg:\n",
" eval_rows.append(\n",
" {\n",
" \"input_query\": last_msg,\n",
" \"generated_answer\": log[\"output\"],\n",
" # check if generated_answer uses tools brave_search\n",
" \"expected_answer\": \"brave_search\",\n",
" },\n",
" )\n",
"\n",
"pprint(eval_rows)\n",
"scoring_params = {\n",
" \"basic::subset_of\": None,\n",
"}\n",
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n",
"pprint(scoring_response)"
"scoring_response = client.scoring.score(\n",
" input_rows=eval_rows, scoring_functions=scoring_params\n",
")\n",
"pprint(scoring_response)\n"
]
},
{
@ -3031,7 +3039,9 @@
"EXPECTED_RESPONSE: {expected_answer}\n",
"\"\"\"\n",
"\n",
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
"input_query = (\n",
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
")\n",
"generated_answer = \"\"\"\n",
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
"\n",
@ -3062,7 +3072,7 @@
"}\n",
"\n",
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
"pprint(response)"
"pprint(response)\n"
]
}
],