Update Strategy in SamplingParams to be a union

This commit is contained in:
Hardik Shah 2025-01-14 15:56:02 -08:00 committed by Ashwin Bharambe
parent 300e6e2702
commit dea575c994
28 changed files with 600 additions and 377 deletions

View file

@ -618,11 +618,13 @@
],
"source": [
"import os\n",
"\n",
"from google.colab import userdata\n",
"\n",
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
"\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"\n",
"client = LlamaStackAsLibraryClient(\"together\")\n",
"_ = client.initialize()\n",
"\n",
@ -631,7 +633,7 @@
" model_id=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
" provider_model_id=\"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n",
" provider_id=\"together\",\n",
")"
")\n"
]
},
{
@ -668,7 +670,7 @@
"source": [
"name = \"llamastack/mmmu\"\n",
"subset = \"Agriculture\"\n",
"split = \"dev\""
"split = \"dev\"\n"
]
},
{
@ -914,9 +916,10 @@
],
"source": [
"import datasets\n",
"\n",
"ds = datasets.load_dataset(path=name, name=subset, split=split)\n",
"ds = ds.select_columns([\"chat_completion_input\", \"input_query\", \"expected_answer\"])\n",
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")"
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")\n"
]
},
{
@ -1014,8 +1017,8 @@
}
],
"source": [
"from tqdm import tqdm\n",
"from rich.pretty import pprint\n",
"from tqdm import tqdm\n",
"\n",
"SYSTEM_PROMPT_TEMPLATE = \"\"\"\n",
"You are an expert in {subject} whose job is to answer questions from the user using images.\n",
@ -1039,7 +1042,7 @@
"client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::mmmu\",\n",
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"]\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows(\n",
@ -1052,16 +1055,17 @@
" \"type\": \"model\",\n",
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
" \"sampling_params\": {\n",
" \"temperature\": 0.0,\n",
" \"strategy\": {\n",
" \"type\": \"greedy\",\n",
" },\n",
" \"max_tokens\": 4096,\n",
" \"top_p\": 0.9,\n",
" \"repeat_penalty\": 1.0,\n",
" },\n",
" \"system_message\": system_message\n",
" }\n",
" }\n",
" \"system_message\": system_message,\n",
" },\n",
" },\n",
")\n",
"pprint(response)"
"pprint(response)\n"
]
},
{
@ -1098,8 +1102,8 @@
" \"input_query\": {\"type\": \"string\"},\n",
" \"expected_answer\": {\"type\": \"string\"},\n",
" \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n",
" }\n",
")"
" },\n",
")\n"
]
},
{
@ -1113,7 +1117,7 @@
"eval_rows = client.datasetio.get_rows_paginated(\n",
" dataset_id=simpleqa_dataset_id,\n",
" rows_in_page=5,\n",
")"
")\n"
]
},
{
@ -1209,7 +1213,7 @@
"client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::simpleqa\",\n",
" dataset_id=simpleqa_dataset_id,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"]\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows(\n",
@ -1222,15 +1226,16 @@
" \"type\": \"model\",\n",
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
" \"sampling_params\": {\n",
" \"temperature\": 0.0,\n",
" \"strategy\": {\n",
" \"type\": \"greedy\",\n",
" },\n",
" \"max_tokens\": 4096,\n",
" \"top_p\": 0.9,\n",
" \"repeat_penalty\": 1.0,\n",
" },\n",
" }\n",
" }\n",
" },\n",
" },\n",
")\n",
"pprint(response)"
"pprint(response)\n"
]
},
{
@ -1347,23 +1352,19 @@
"agent_config = {\n",
" \"model\": \"meta-llama/Llama-3.1-405B-Instruct\",\n",
" \"instructions\": \"You are a helpful assistant\",\n",
" \"sampling_params\": {\n",
" \"strategy\": \"greedy\",\n",
" \"temperature\": 0.0,\n",
" \"top_p\": 0.95,\n",
" },\n",
" \"sampling_params\": {\"strategy\": {\"type\": \"greedy\"}},\n",
" \"tools\": [\n",
" {\n",
" \"type\": \"brave_search\",\n",
" \"engine\": \"tavily\",\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
" }\n",
" ],\n",
" \"tool_choice\": \"auto\",\n",
" \"tool_prompt_format\": \"json\",\n",
" \"input_shields\": [],\n",
" \"output_shields\": [],\n",
" \"enable_session_persistence\": False\n",
" \"enable_session_persistence\": False,\n",
"}\n",
"\n",
"response = client.eval.evaluate_rows(\n",
@ -1375,10 +1376,10 @@
" \"eval_candidate\": {\n",
" \"type\": \"agent\",\n",
" \"config\": agent_config,\n",
" }\n",
" }\n",
" },\n",
" },\n",
")\n",
"pprint(response)"
"pprint(response)\n"
]
}
],