Update Strategy in SamplingParams to be a union

This commit is contained in:
Hardik Shah 2025-01-14 15:56:02 -08:00 committed by Ashwin Bharambe
parent 300e6e2702
commit dea575c994
28 changed files with 600 additions and 377 deletions

View file

@ -713,13 +713,15 @@
],
"source": [
"import os\n",
"\n",
"from google.colab import userdata\n",
"\n",
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
"\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"\n",
"client = LlamaStackAsLibraryClient(\"together\")\n",
"_ = client.initialize()"
"_ = client.initialize()\n"
]
},
{
@ -769,6 +771,7 @@
],
"source": [
"from rich.pretty import pprint\n",
"\n",
"print(\"Available models:\")\n",
"for m in client.models.list():\n",
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
@ -777,7 +780,7 @@
"print(\"Available shields (safety models):\")\n",
"for s in client.shields.list():\n",
" print(s.identifier)\n",
"print(\"----\")"
"print(\"----\")\n"
]
},
{
@ -822,7 +825,7 @@
"source": [
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
"\n",
"model_id"
"model_id\n"
]
},
{
@ -863,11 +866,11 @@
" model_id=model_id,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n",
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
" ],\n",
")\n",
"\n",
"print(response.completion_message.content)"
"print(response.completion_message.content)\n"
]
},
{
@ -900,12 +903,13 @@
"source": [
"from termcolor import cprint\n",
"\n",
"\n",
"def chat_loop():\n",
" conversation_history = []\n",
" while True:\n",
" user_input = input('User> ')\n",
" if user_input.lower() in ['exit', 'quit', 'bye']:\n",
" cprint('Ending conversation. Goodbye!', 'yellow')\n",
" user_input = input(\"User> \")\n",
" if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
" cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
" break\n",
"\n",
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
@ -915,14 +919,15 @@
" messages=conversation_history,\n",
" model_id=model_id,\n",
" )\n",
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
" cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
"\n",
" assistant_message = {\n",
" \"role\": \"assistant\", # was user\n",
" \"role\": \"assistant\", # was user\n",
" \"content\": response.completion_message.content,\n",
" }\n",
" conversation_history.append(assistant_message)\n",
"\n",
"\n",
"chat_loop()\n"
]
},
@ -978,21 +983,18 @@
"source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
"\n",
"message = {\n",
" \"role\": \"user\",\n",
" \"content\": 'Write me a sonnet about llama'\n",
"}\n",
"print(f'User> {message[\"content\"]}', 'green')\n",
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
"print(f'User> {message[\"content\"]}', \"green\")\n",
"\n",
"response = client.inference.chat_completion(\n",
" messages=[message],\n",
" model_id=model_id,\n",
" stream=True, # <-----------\n",
" stream=True, # <-----------\n",
")\n",
"\n",
"# Print the tokens while they are received\n",
"for log in EventLogger().log(response):\n",
" log.print()"
" log.print()\n"
]
},
{
@ -1045,26 +1047,26 @@
"source": [
"from pydantic import BaseModel\n",
"\n",
"\n",
"class Output(BaseModel):\n",
" name: str\n",
" year_born: str\n",
" year_retired: str\n",
"\n",
"\n",
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
"response = client.inference.completion(\n",
" model_id=model_id,\n",
" content=user_input,\n",
" stream=False,\n",
" sampling_params={\n",
" \"max_tokens\": 50,\n",
" },\n",
" sampling_params={\"strategy\": {\"type\": \"greedy\"}, \"max_tokens\": 50},\n",
" response_format={\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": Output.model_json_schema(),\n",
" },\n",
")\n",
"\n",
"pprint(response)"
"pprint(response)\n"
]
},
{
@ -1220,7 +1222,7 @@
" shield_id=available_shields[0],\n",
" params={},\n",
" )\n",
" pprint(response)"
" pprint(response)\n"
]
},
{
@ -1489,8 +1491,8 @@
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.types import Attachment\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n",
"\n",
"urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
@ -1522,14 +1524,14 @@
" ),\n",
"]\n",
"for prompt, attachments in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n",
" cprint(f\"User> {prompt}\", \"green\")\n",
" response = rag_agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" attachments=attachments,\n",
" session_id=session_id,\n",
" )\n",
" for log in EventLogger().log(response):\n",
" log.print()"
" log.print()\n"
]
},
{
@ -1560,8 +1562,8 @@
"search_tool = {\n",
" \"type\": \"brave_search\",\n",
" \"engine\": \"tavily\",\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
"}"
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
"}\n"
]
},
{
@ -1608,7 +1610,7 @@
"\n",
"session_id = agent.create_session(\"test-session\")\n",
"for prompt in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n",
" cprint(f\"User> {prompt}\", \"green\")\n",
" response = agent.create_turn(\n",
" messages=[\n",
" {\n",
@ -1758,7 +1760,7 @@
" search_tool,\n",
" {\n",
" \"type\": \"code_interpreter\",\n",
" }\n",
" },\n",
" ],\n",
" tool_choice=\"required\",\n",
" input_shields=[],\n",
@ -1788,7 +1790,7 @@
"]\n",
"\n",
"for prompt in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n",
" cprint(f\"User> {prompt}\", \"green\")\n",
" response = codex_agent.create_turn(\n",
" messages=[\n",
" {\n",
@ -1841,27 +1843,57 @@
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Read the CSV file\n",
"df = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\n",
"df = pd.read_csv(\"/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv\")\n",
"\n",
"# Extract the year and inflation rate from the CSV file\n",
"df['Year'] = pd.to_datetime(df['Year'], format='%Y')\n",
"df = df.rename(columns={'Jan': 'Jan Rate', 'Feb': 'Feb Rate', 'Mar': 'Mar Rate', 'Apr': 'Apr Rate', 'May': 'May Rate', 'Jun': 'Jun Rate', 'Jul': 'Jul Rate', 'Aug': 'Aug Rate', 'Sep': 'Sep Rate', 'Oct': 'Oct Rate', 'Nov': 'Nov Rate', 'Dec': 'Dec Rate'})\n",
"df[\"Year\"] = pd.to_datetime(df[\"Year\"], format=\"%Y\")\n",
"df = df.rename(\n",
" columns={\n",
" \"Jan\": \"Jan Rate\",\n",
" \"Feb\": \"Feb Rate\",\n",
" \"Mar\": \"Mar Rate\",\n",
" \"Apr\": \"Apr Rate\",\n",
" \"May\": \"May Rate\",\n",
" \"Jun\": \"Jun Rate\",\n",
" \"Jul\": \"Jul Rate\",\n",
" \"Aug\": \"Aug Rate\",\n",
" \"Sep\": \"Sep Rate\",\n",
" \"Oct\": \"Oct Rate\",\n",
" \"Nov\": \"Nov Rate\",\n",
" \"Dec\": \"Dec Rate\",\n",
" }\n",
")\n",
"\n",
"# Calculate the average yearly inflation rate\n",
"df['Yearly Inflation'] = df[['Jan Rate', 'Feb Rate', 'Mar Rate', 'Apr Rate', 'May Rate', 'Jun Rate', 'Jul Rate', 'Aug Rate', 'Sep Rate', 'Oct Rate', 'Nov Rate', 'Dec Rate']].mean(axis=1)\n",
"df[\"Yearly Inflation\"] = df[\n",
" [\n",
" \"Jan Rate\",\n",
" \"Feb Rate\",\n",
" \"Mar Rate\",\n",
" \"Apr Rate\",\n",
" \"May Rate\",\n",
" \"Jun Rate\",\n",
" \"Jul Rate\",\n",
" \"Aug Rate\",\n",
" \"Sep Rate\",\n",
" \"Oct Rate\",\n",
" \"Nov Rate\",\n",
" \"Dec Rate\",\n",
" ]\n",
"].mean(axis=1)\n",
"\n",
"# Plot the average yearly inflation rate as a time series\n",
"plt.figure(figsize=(10, 6))\n",
"plt.plot(df['Year'], df['Yearly Inflation'], marker='o')\n",
"plt.title('Average Yearly Inflation Rate')\n",
"plt.xlabel('Year')\n",
"plt.ylabel('Inflation Rate (%)')\n",
"plt.plot(df[\"Year\"], df[\"Yearly Inflation\"], marker=\"o\")\n",
"plt.title(\"Average Yearly Inflation Rate\")\n",
"plt.xlabel(\"Year\")\n",
"plt.ylabel(\"Inflation Rate (%)\")\n",
"plt.grid(True)\n",
"plt.show()"
"plt.show()\n"
]
},
{
@ -2035,6 +2067,8 @@
"source": [
"# disable logging for clean server logs\n",
"import logging\n",
"\n",
"\n",
"def remove_root_handlers():\n",
" root_logger = logging.getLogger()\n",
" for handler in root_logger.handlers[:]:\n",
@ -2042,7 +2076,7 @@
" print(f\"Removed handler {handler.__class__.__name__} from root logger\")\n",
"\n",
"\n",
"remove_root_handlers()"
"remove_root_handlers()\n"
]
},
{
@ -2083,10 +2117,10 @@
}
],
"source": [
"from google.colab import userdata\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from google.colab import userdata\n",
"\n",
"agent_config = AgentConfig(\n",
" model=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
@ -2096,7 +2130,7 @@
" {\n",
" \"type\": \"brave_search\",\n",
" \"engine\": \"tavily\",\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
" }\n",
" ]\n",
" ),\n",
@ -2125,7 +2159,7 @@
" )\n",
"\n",
" for log in EventLogger().log(response):\n",
" log.print()"
" log.print()\n"
]
},
{
@ -2265,20 +2299,21 @@
"source": [
"print(f\"Getting traces for session_id={session_id}\")\n",
"import json\n",
"\n",
"from rich.pretty import pprint\n",
"\n",
"agent_logs = []\n",
"\n",
"for span in client.telemetry.query_spans(\n",
" attribute_filters=[\n",
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
" ],\n",
" attributes_to_return=[\"input\", \"output\"]\n",
" ):\n",
" if span.attributes[\"output\"] != \"no shields\":\n",
" agent_logs.append(span.attributes)\n",
" attributes_to_return=[\"input\", \"output\"],\n",
"):\n",
" if span.attributes[\"output\"] != \"no shields\":\n",
" agent_logs.append(span.attributes)\n",
"\n",
"pprint(agent_logs)"
"pprint(agent_logs)\n"
]
},
{
@ -2389,23 +2424,25 @@
"eval_rows = []\n",
"\n",
"for log in agent_logs:\n",
" last_msg = log['input'][-1]\n",
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n",
" eval_rows.append(\n",
" {\n",
" \"input_query\": last_msg,\n",
" \"generated_answer\": log[\"output\"],\n",
" # check if generated_answer uses tools brave_search\n",
" \"expected_answer\": \"brave_search\",\n",
" },\n",
" )\n",
" last_msg = log[\"input\"][-1]\n",
" if '\"role\":\"user\"' in last_msg:\n",
" eval_rows.append(\n",
" {\n",
" \"input_query\": last_msg,\n",
" \"generated_answer\": log[\"output\"],\n",
" # check if generated_answer uses tools brave_search\n",
" \"expected_answer\": \"brave_search\",\n",
" },\n",
" )\n",
"\n",
"pprint(eval_rows)\n",
"scoring_params = {\n",
" \"basic::subset_of\": None,\n",
"}\n",
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n",
"pprint(scoring_response)"
"scoring_response = client.scoring.score(\n",
" input_rows=eval_rows, scoring_functions=scoring_params\n",
")\n",
"pprint(scoring_response)\n"
]
},
{
@ -2506,7 +2543,9 @@
"EXPECTED_RESPONSE: {expected_answer}\n",
"\"\"\"\n",
"\n",
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
"input_query = (\n",
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
")\n",
"generated_answer = \"\"\"\n",
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
"\n",
@ -2537,7 +2576,7 @@
"}\n",
"\n",
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
"pprint(response)"
"pprint(response)\n"
]
},
{