mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 11:10:00 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@ -713,13 +713,15 @@
|
|||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from google.colab import userdata\n",
|
||||
"\n",
|
||||
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
|
||||
"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
|
||||
"\n",
|
||||
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
||||
"\n",
|
||||
"client = LlamaStackAsLibraryClient(\"together\")\n",
|
||||
"_ = client.initialize()"
|
||||
"_ = client.initialize()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -769,6 +771,7 @@
|
|||
],
|
||||
"source": [
|
||||
"from rich.pretty import pprint\n",
|
||||
"\n",
|
||||
"print(\"Available models:\")\n",
|
||||
"for m in client.models.list():\n",
|
||||
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
|
||||
|
|
@ -777,7 +780,7 @@
|
|||
"print(\"Available shields (safety models):\")\n",
|
||||
"for s in client.shields.list():\n",
|
||||
" print(s.identifier)\n",
|
||||
"print(\"----\")"
|
||||
"print(\"----\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -822,7 +825,7 @@
|
|||
"source": [
|
||||
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
|
||||
"\n",
|
||||
"model_id"
|
||||
"model_id\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -863,11 +866,11 @@
|
|||
" model_id=model_id,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
|
||||
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n",
|
||||
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(response.completion_message.content)"
|
||||
"print(response.completion_message.content)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -900,12 +903,13 @@
|
|||
"source": [
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def chat_loop():\n",
|
||||
" conversation_history = []\n",
|
||||
" while True:\n",
|
||||
" user_input = input('User> ')\n",
|
||||
" if user_input.lower() in ['exit', 'quit', 'bye']:\n",
|
||||
" cprint('Ending conversation. Goodbye!', 'yellow')\n",
|
||||
" user_input = input(\"User> \")\n",
|
||||
" if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
|
||||
" cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
|
||||
|
|
@ -915,14 +919,15 @@
|
|||
" messages=conversation_history,\n",
|
||||
" model_id=model_id,\n",
|
||||
" )\n",
|
||||
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
|
||||
" cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
|
||||
"\n",
|
||||
" assistant_message = {\n",
|
||||
" \"role\": \"assistant\", # was user\n",
|
||||
" \"role\": \"assistant\", # was user\n",
|
||||
" \"content\": response.completion_message.content,\n",
|
||||
" }\n",
|
||||
" conversation_history.append(assistant_message)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"chat_loop()\n"
|
||||
]
|
||||
},
|
||||
|
|
@ -978,21 +983,18 @@
|
|||
"source": [
|
||||
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
||||
"\n",
|
||||
"message = {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": 'Write me a sonnet about llama'\n",
|
||||
"}\n",
|
||||
"print(f'User> {message[\"content\"]}', 'green')\n",
|
||||
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
|
||||
"print(f'User> {message[\"content\"]}', \"green\")\n",
|
||||
"\n",
|
||||
"response = client.inference.chat_completion(\n",
|
||||
" messages=[message],\n",
|
||||
" model_id=model_id,\n",
|
||||
" stream=True, # <-----------\n",
|
||||
" stream=True, # <-----------\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Print the tokens while they are received\n",
|
||||
"for log in EventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1045,26 +1047,26 @@
|
|||
"source": [
|
||||
"from pydantic import BaseModel\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Output(BaseModel):\n",
|
||||
" name: str\n",
|
||||
" year_born: str\n",
|
||||
" year_retired: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
|
||||
"response = client.inference.completion(\n",
|
||||
" model_id=model_id,\n",
|
||||
" content=user_input,\n",
|
||||
" stream=False,\n",
|
||||
" sampling_params={\n",
|
||||
" \"max_tokens\": 50,\n",
|
||||
" },\n",
|
||||
" sampling_params={\"strategy\": {\"type\": \"greedy\"}, \"max_tokens\": 50},\n",
|
||||
" response_format={\n",
|
||||
" \"type\": \"json_schema\",\n",
|
||||
" \"json_schema\": Output.model_json_schema(),\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1220,7 +1222,7 @@
|
|||
" shield_id=available_shields[0],\n",
|
||||
" params={},\n",
|
||||
" )\n",
|
||||
" pprint(response)"
|
||||
" pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1489,8 +1491,8 @@
|
|||
"source": [
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from llama_stack_client.types import Attachment\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
|
||||
|
|
@ -1522,14 +1524,14 @@
|
|||
" ),\n",
|
||||
"]\n",
|
||||
"for prompt, attachments in user_prompts:\n",
|
||||
" cprint(f'User> {prompt}', 'green')\n",
|
||||
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||
" response = rag_agent.create_turn(\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" attachments=attachments,\n",
|
||||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1560,8 +1562,8 @@
|
|||
"search_tool = {\n",
|
||||
" \"type\": \"brave_search\",\n",
|
||||
" \"engine\": \"tavily\",\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
|
||||
"}"
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
|
||||
"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -1608,7 +1610,7 @@
|
|||
"\n",
|
||||
"session_id = agent.create_session(\"test-session\")\n",
|
||||
"for prompt in user_prompts:\n",
|
||||
" cprint(f'User> {prompt}', 'green')\n",
|
||||
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||
" response = agent.create_turn(\n",
|
||||
" messages=[\n",
|
||||
" {\n",
|
||||
|
|
@ -1758,7 +1760,7 @@
|
|||
" search_tool,\n",
|
||||
" {\n",
|
||||
" \"type\": \"code_interpreter\",\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" tool_choice=\"required\",\n",
|
||||
" input_shields=[],\n",
|
||||
|
|
@ -1788,7 +1790,7 @@
|
|||
"]\n",
|
||||
"\n",
|
||||
"for prompt in user_prompts:\n",
|
||||
" cprint(f'User> {prompt}', 'green')\n",
|
||||
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||
" response = codex_agent.create_turn(\n",
|
||||
" messages=[\n",
|
||||
" {\n",
|
||||
|
|
@ -1841,27 +1843,57 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# Read the CSV file\n",
|
||||
"df = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\n",
|
||||
"df = pd.read_csv(\"/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv\")\n",
|
||||
"\n",
|
||||
"# Extract the year and inflation rate from the CSV file\n",
|
||||
"df['Year'] = pd.to_datetime(df['Year'], format='%Y')\n",
|
||||
"df = df.rename(columns={'Jan': 'Jan Rate', 'Feb': 'Feb Rate', 'Mar': 'Mar Rate', 'Apr': 'Apr Rate', 'May': 'May Rate', 'Jun': 'Jun Rate', 'Jul': 'Jul Rate', 'Aug': 'Aug Rate', 'Sep': 'Sep Rate', 'Oct': 'Oct Rate', 'Nov': 'Nov Rate', 'Dec': 'Dec Rate'})\n",
|
||||
"df[\"Year\"] = pd.to_datetime(df[\"Year\"], format=\"%Y\")\n",
|
||||
"df = df.rename(\n",
|
||||
" columns={\n",
|
||||
" \"Jan\": \"Jan Rate\",\n",
|
||||
" \"Feb\": \"Feb Rate\",\n",
|
||||
" \"Mar\": \"Mar Rate\",\n",
|
||||
" \"Apr\": \"Apr Rate\",\n",
|
||||
" \"May\": \"May Rate\",\n",
|
||||
" \"Jun\": \"Jun Rate\",\n",
|
||||
" \"Jul\": \"Jul Rate\",\n",
|
||||
" \"Aug\": \"Aug Rate\",\n",
|
||||
" \"Sep\": \"Sep Rate\",\n",
|
||||
" \"Oct\": \"Oct Rate\",\n",
|
||||
" \"Nov\": \"Nov Rate\",\n",
|
||||
" \"Dec\": \"Dec Rate\",\n",
|
||||
" }\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Calculate the average yearly inflation rate\n",
|
||||
"df['Yearly Inflation'] = df[['Jan Rate', 'Feb Rate', 'Mar Rate', 'Apr Rate', 'May Rate', 'Jun Rate', 'Jul Rate', 'Aug Rate', 'Sep Rate', 'Oct Rate', 'Nov Rate', 'Dec Rate']].mean(axis=1)\n",
|
||||
"df[\"Yearly Inflation\"] = df[\n",
|
||||
" [\n",
|
||||
" \"Jan Rate\",\n",
|
||||
" \"Feb Rate\",\n",
|
||||
" \"Mar Rate\",\n",
|
||||
" \"Apr Rate\",\n",
|
||||
" \"May Rate\",\n",
|
||||
" \"Jun Rate\",\n",
|
||||
" \"Jul Rate\",\n",
|
||||
" \"Aug Rate\",\n",
|
||||
" \"Sep Rate\",\n",
|
||||
" \"Oct Rate\",\n",
|
||||
" \"Nov Rate\",\n",
|
||||
" \"Dec Rate\",\n",
|
||||
" ]\n",
|
||||
"].mean(axis=1)\n",
|
||||
"\n",
|
||||
"# Plot the average yearly inflation rate as a time series\n",
|
||||
"plt.figure(figsize=(10, 6))\n",
|
||||
"plt.plot(df['Year'], df['Yearly Inflation'], marker='o')\n",
|
||||
"plt.title('Average Yearly Inflation Rate')\n",
|
||||
"plt.xlabel('Year')\n",
|
||||
"plt.ylabel('Inflation Rate (%)')\n",
|
||||
"plt.plot(df[\"Year\"], df[\"Yearly Inflation\"], marker=\"o\")\n",
|
||||
"plt.title(\"Average Yearly Inflation Rate\")\n",
|
||||
"plt.xlabel(\"Year\")\n",
|
||||
"plt.ylabel(\"Inflation Rate (%)\")\n",
|
||||
"plt.grid(True)\n",
|
||||
"plt.show()"
|
||||
"plt.show()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -2035,6 +2067,8 @@
|
|||
"source": [
|
||||
"# disable logging for clean server logs\n",
|
||||
"import logging\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def remove_root_handlers():\n",
|
||||
" root_logger = logging.getLogger()\n",
|
||||
" for handler in root_logger.handlers[:]:\n",
|
||||
|
|
@ -2042,7 +2076,7 @@
|
|||
" print(f\"Removed handler {handler.__class__.__name__} from root logger\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"remove_root_handlers()"
|
||||
"remove_root_handlers()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -2083,10 +2117,10 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from google.colab import userdata\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from google.colab import userdata\n",
|
||||
"\n",
|
||||
"agent_config = AgentConfig(\n",
|
||||
" model=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
||||
|
|
@ -2096,7 +2130,7 @@
|
|||
" {\n",
|
||||
" \"type\": \"brave_search\",\n",
|
||||
" \"engine\": \"tavily\",\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
|
||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" ),\n",
|
||||
|
|
@ -2125,7 +2159,7 @@
|
|||
" )\n",
|
||||
"\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -2265,20 +2299,21 @@
|
|||
"source": [
|
||||
"print(f\"Getting traces for session_id={session_id}\")\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"from rich.pretty import pprint\n",
|
||||
"\n",
|
||||
"agent_logs = []\n",
|
||||
"\n",
|
||||
"for span in client.telemetry.query_spans(\n",
|
||||
" attribute_filters=[\n",
|
||||
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
||||
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
||||
" ],\n",
|
||||
" attributes_to_return=[\"input\", \"output\"]\n",
|
||||
" ):\n",
|
||||
" if span.attributes[\"output\"] != \"no shields\":\n",
|
||||
" agent_logs.append(span.attributes)\n",
|
||||
" attributes_to_return=[\"input\", \"output\"],\n",
|
||||
"):\n",
|
||||
" if span.attributes[\"output\"] != \"no shields\":\n",
|
||||
" agent_logs.append(span.attributes)\n",
|
||||
"\n",
|
||||
"pprint(agent_logs)"
|
||||
"pprint(agent_logs)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -2389,23 +2424,25 @@
|
|||
"eval_rows = []\n",
|
||||
"\n",
|
||||
"for log in agent_logs:\n",
|
||||
" last_msg = log['input'][-1]\n",
|
||||
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n",
|
||||
" eval_rows.append(\n",
|
||||
" {\n",
|
||||
" \"input_query\": last_msg,\n",
|
||||
" \"generated_answer\": log[\"output\"],\n",
|
||||
" # check if generated_answer uses tools brave_search\n",
|
||||
" \"expected_answer\": \"brave_search\",\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
" last_msg = log[\"input\"][-1]\n",
|
||||
" if '\"role\":\"user\"' in last_msg:\n",
|
||||
" eval_rows.append(\n",
|
||||
" {\n",
|
||||
" \"input_query\": last_msg,\n",
|
||||
" \"generated_answer\": log[\"output\"],\n",
|
||||
" # check if generated_answer uses tools brave_search\n",
|
||||
" \"expected_answer\": \"brave_search\",\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"pprint(eval_rows)\n",
|
||||
"scoring_params = {\n",
|
||||
" \"basic::subset_of\": None,\n",
|
||||
"}\n",
|
||||
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n",
|
||||
"pprint(scoring_response)"
|
||||
"scoring_response = client.scoring.score(\n",
|
||||
" input_rows=eval_rows, scoring_functions=scoring_params\n",
|
||||
")\n",
|
||||
"pprint(scoring_response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -2506,7 +2543,9 @@
|
|||
"EXPECTED_RESPONSE: {expected_answer}\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
||||
"input_query = (\n",
|
||||
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
||||
")\n",
|
||||
"generated_answer = \"\"\"\n",
|
||||
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
|
||||
"\n",
|
||||
|
|
@ -2537,7 +2576,7 @@
|
|||
"}\n",
|
||||
"\n",
|
||||
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
|
||||
"pprint(response)"
|
||||
"pprint(response)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue