mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-03 10:42:16 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@@ -618,11 +618,13 @@
|
|||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from google.colab import userdata\n",
|
||||
"\n",
|
||||
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
|
||||
"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
|
||||
"\n",
|
||||
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
||||
"\n",
|
||||
"client = LlamaStackAsLibraryClient(\"together\")\n",
|
||||
"_ = client.initialize()\n",
|
||||
"\n",
|
||||
|
|
@@ -631,7 +633,7 @@
|
|||
" model_id=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
||||
" provider_model_id=\"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n",
|
||||
" provider_id=\"together\",\n",
|
||||
")"
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@@ -668,7 +670,7 @@
|
|||
"source": [
|
||||
"name = \"llamastack/mmmu\"\n",
|
||||
"subset = \"Agriculture\"\n",
|
||||
"split = \"dev\""
|
||||
"split = \"dev\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@@ -914,9 +916,10 @@
|
|||
],
|
||||
"source": [
|
||||
"import datasets\n",
|
||||
"\n",
|
||||
"ds = datasets.load_dataset(path=name, name=subset, split=split)\n",
|
||||
"ds = ds.select_columns([\"chat_completion_input\", \"input_query\", \"expected_answer\"])\n",
|
||||
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")"
|
||||
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@@ -1014,8 +1017,8 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from tqdm import tqdm\n",
|
||||
"from rich.pretty import pprint\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"SYSTEM_PROMPT_TEMPLATE = \"\"\"\n",
|
||||
"You are an expert in {subject} whose job is to answer questions from the user using images.\n",
|
||||
|
|
@@ -1039,7 +1042,7 @@
|
|||
"client.eval_tasks.register(\n",
|
||||
" eval_task_id=\"meta-reference::mmmu\",\n",
|
||||
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
|
||||
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"]\n",
|
||||
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = client.eval.evaluate_rows(\n",
|
||||
|
|
@@ -1052,16 +1055,17 @@
|
|||
" \"type\": \"model\",\n",
|
||||
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
||||
" \"sampling_params\": {\n",
|
||||
" \"temperature\": 0.0,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" \"max_tokens\": 4096,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"repeat_penalty\": 1.0,\n",
|
||||
" },\n",
|
||||
" \"system_message\": system_message\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" \"system_message\": system_message,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@@ -1098,8 +1102,8 @@
|
|||
" \"input_query\": {\"type\": \"string\"},\n",
|
||||
" \"expected_answer\": {\"type\": \"string\"},\n",
|
||||
" \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n",
|
||||
" }\n",
|
||||
")"
|
||||
" },\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@@ -1113,7 +1117,7 @@
|
|||
"eval_rows = client.datasetio.get_rows_paginated(\n",
|
||||
" dataset_id=simpleqa_dataset_id,\n",
|
||||
" rows_in_page=5,\n",
|
||||
")"
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@@ -1209,7 +1213,7 @@
|
|||
"client.eval_tasks.register(\n",
|
||||
" eval_task_id=\"meta-reference::simpleqa\",\n",
|
||||
" dataset_id=simpleqa_dataset_id,\n",
|
||||
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"]\n",
|
||||
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = client.eval.evaluate_rows(\n",
|
||||
|
|
@@ -1222,15 +1226,16 @@
|
|||
" \"type\": \"model\",\n",
|
||||
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
||||
" \"sampling_params\": {\n",
|
||||
" \"temperature\": 0.0,\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"type\": \"greedy\",\n",
|
||||
" },\n",
|
||||
" \"max_tokens\": 4096,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"repeat_penalty\": 1.0,\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@@ -1347,23 +1352,19 @@
|
|||
"agent_config = {\n",
|
||||
" \"model\": \"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
||||
" \"instructions\": \"You are a helpful assistant\",\n",
|
||||
" \"sampling_params\": {\n",
|
||||
" \"strategy\": \"greedy\",\n",
|
||||
" \"temperature\": 0.0,\n",
|
||||
" \"top_p\": 0.95,\n",
|
||||
" },\n",
|
||||
" \"sampling_params\": {\"strategy\": {\"type\": \"greedy\"}},\n",
|
||||
" \"tools\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"brave_search\",\n",
|
||||
" \"engine\": \"tavily\",\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"tool_choice\": \"auto\",\n",
|
||||
" \"tool_prompt_format\": \"json\",\n",
|
||||
" \"input_shields\": [],\n",
|
||||
" \"output_shields\": [],\n",
|
||||
" \"enable_session_persistence\": False\n",
|
||||
" \"enable_session_persistence\": False,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"response = client.eval.evaluate_rows(\n",
|
||||
|
|
@@ -1375,10 +1376,10 @@
|
|||
" \"eval_candidate\": {\n",
|
||||
" \"type\": \"agent\",\n",
|
||||
" \"config\": agent_config,\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue