Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 02:32:40 +00:00)

Commit: Update Strategy in SamplingParams to be a union

Parent: 300e6e2702
Commit: dea575c994

28 changed files with 600 additions and 377 deletions
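For orientation, here is a minimal sketch of the caller-facing change, distilled from the notebook and docs updates in this diff; the parameter values are illustrative, not prescribed defaults.

# Before this change: SamplingParams carried a flat "strategy" enum string,
# with temperature/top_p/top_k sitting alongside it.
old_style = {
    "strategy": "greedy",
    "temperature": 0.0,
    "top_p": 0.95,
    "max_tokens": 50,
}

# After this change: "strategy" is a union discriminated by "type"
# (GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy),
# and the strategy-specific knobs move inside it.
greedy = {"strategy": {"type": "greedy"}, "max_tokens": 50}
top_p = {"strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.95}, "max_tokens": 2048}
top_k = {"strategy": {"type": "top_k", "top_k": 40}, "max_tokens": 50}  # top_k value here is illustrative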
@@ -713,13 +713,15 @@
 ],
 "source": [
 "import os\n",
+"\n",
 "from google.colab import userdata\n",
 "\n",
-"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
+"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
 "\n",
 "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
+"\n",
 "client = LlamaStackAsLibraryClient(\"together\")\n",
-"_ = client.initialize()"
+"_ = client.initialize()\n"
 ]
 },
 {
|
@ -769,6 +771,7 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from rich.pretty import pprint\n",
|
"from rich.pretty import pprint\n",
|
||||||
|
"\n",
|
||||||
"print(\"Available models:\")\n",
|
"print(\"Available models:\")\n",
|
||||||
"for m in client.models.list():\n",
|
"for m in client.models.list():\n",
|
||||||
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
|
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
|
||||||
|
@ -777,7 +780,7 @@
|
||||||
"print(\"Available shields (safety models):\")\n",
|
"print(\"Available shields (safety models):\")\n",
|
||||||
"for s in client.shields.list():\n",
|
"for s in client.shields.list():\n",
|
||||||
" print(s.identifier)\n",
|
" print(s.identifier)\n",
|
||||||
"print(\"----\")"
|
"print(\"----\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -822,7 +825,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
|
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model_id"
|
"model_id\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -863,11 +866,11 @@
|
||||||
" model_id=model_id,\n",
|
" model_id=model_id,\n",
|
||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
|
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
|
||||||
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n",
|
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(response.completion_message.content)"
|
"print(response.completion_message.content)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -900,12 +903,13 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from termcolor import cprint\n",
|
"from termcolor import cprint\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"def chat_loop():\n",
|
"def chat_loop():\n",
|
||||||
" conversation_history = []\n",
|
" conversation_history = []\n",
|
||||||
" while True:\n",
|
" while True:\n",
|
||||||
" user_input = input('User> ')\n",
|
" user_input = input(\"User> \")\n",
|
||||||
" if user_input.lower() in ['exit', 'quit', 'bye']:\n",
|
" if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
|
||||||
" cprint('Ending conversation. Goodbye!', 'yellow')\n",
|
" cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
|
||||||
" break\n",
|
" break\n",
|
||||||
"\n",
|
"\n",
|
||||||
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
|
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
|
||||||
|
@ -915,14 +919,15 @@
|
||||||
" messages=conversation_history,\n",
|
" messages=conversation_history,\n",
|
||||||
" model_id=model_id,\n",
|
" model_id=model_id,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
|
" cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
" assistant_message = {\n",
|
" assistant_message = {\n",
|
||||||
" \"role\": \"assistant\", # was user\n",
|
" \"role\": \"assistant\", # was user\n",
|
||||||
" \"content\": response.completion_message.content,\n",
|
" \"content\": response.completion_message.content,\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
" conversation_history.append(assistant_message)\n",
|
" conversation_history.append(assistant_message)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"chat_loop()\n"
|
"chat_loop()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -978,21 +983,18 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
||||||
"\n",
|
"\n",
|
||||||
"message = {\n",
|
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
|
||||||
" \"role\": \"user\",\n",
|
"print(f'User> {message[\"content\"]}', \"green\")\n",
|
||||||
" \"content\": 'Write me a sonnet about llama'\n",
|
|
||||||
"}\n",
|
|
||||||
"print(f'User> {message[\"content\"]}', 'green')\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"response = client.inference.chat_completion(\n",
|
"response = client.inference.chat_completion(\n",
|
||||||
" messages=[message],\n",
|
" messages=[message],\n",
|
||||||
" model_id=model_id,\n",
|
" model_id=model_id,\n",
|
||||||
" stream=True, # <-----------\n",
|
" stream=True, # <-----------\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Print the tokens while they are received\n",
|
"# Print the tokens while they are received\n",
|
||||||
"for log in EventLogger().log(response):\n",
|
"for log in EventLogger().log(response):\n",
|
||||||
" log.print()"
|
" log.print()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@@ -1045,26 +1047,26 @@
 "source": [
 "from pydantic import BaseModel\n",
 "\n",
+"\n",
 "class Output(BaseModel):\n",
 " name: str\n",
 " year_born: str\n",
 " year_retired: str\n",
 "\n",
+"\n",
 "user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
 "response = client.inference.completion(\n",
 " model_id=model_id,\n",
 " content=user_input,\n",
 " stream=False,\n",
-" sampling_params={\n",
-" \"max_tokens\": 50,\n",
-" },\n",
+" sampling_params={\"strategy\": {\"type\": \"greedy\"}, \"max_tokens\": 50},\n",
 " response_format={\n",
 " \"type\": \"json_schema\",\n",
 " \"json_schema\": Output.model_json_schema(),\n",
 " },\n",
 ")\n",
 "\n",
-"pprint(response)"
+"pprint(response)\n"
 ]
 },
 {
|
@ -1220,7 +1222,7 @@
|
||||||
" shield_id=available_shields[0],\n",
|
" shield_id=available_shields[0],\n",
|
||||||
" params={},\n",
|
" params={},\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" pprint(response)"
|
" pprint(response)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1489,8 +1491,8 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
|
||||||
"from llama_stack_client.types import Attachment\n",
|
"from llama_stack_client.types import Attachment\n",
|
||||||
|
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||||
"from termcolor import cprint\n",
|
"from termcolor import cprint\n",
|
||||||
"\n",
|
"\n",
|
||||||
"urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
|
"urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
|
||||||
|
@ -1522,14 +1524,14 @@
|
||||||
" ),\n",
|
" ),\n",
|
||||||
"]\n",
|
"]\n",
|
||||||
"for prompt, attachments in user_prompts:\n",
|
"for prompt, attachments in user_prompts:\n",
|
||||||
" cprint(f'User> {prompt}', 'green')\n",
|
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||||
" response = rag_agent.create_turn(\n",
|
" response = rag_agent.create_turn(\n",
|
||||||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
||||||
" attachments=attachments,\n",
|
" attachments=attachments,\n",
|
||||||
" session_id=session_id,\n",
|
" session_id=session_id,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" for log in EventLogger().log(response):\n",
|
" for log in EventLogger().log(response):\n",
|
||||||
" log.print()"
|
" log.print()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1560,8 +1562,8 @@
|
||||||
"search_tool = {\n",
|
"search_tool = {\n",
|
||||||
" \"type\": \"brave_search\",\n",
|
" \"type\": \"brave_search\",\n",
|
||||||
" \"engine\": \"tavily\",\n",
|
" \"engine\": \"tavily\",\n",
|
||||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
|
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
|
||||||
"}"
|
"}\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1608,7 +1610,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"session_id = agent.create_session(\"test-session\")\n",
|
"session_id = agent.create_session(\"test-session\")\n",
|
||||||
"for prompt in user_prompts:\n",
|
"for prompt in user_prompts:\n",
|
||||||
" cprint(f'User> {prompt}', 'green')\n",
|
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||||
" response = agent.create_turn(\n",
|
" response = agent.create_turn(\n",
|
||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
|
@ -1758,7 +1760,7 @@
|
||||||
" search_tool,\n",
|
" search_tool,\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"type\": \"code_interpreter\",\n",
|
" \"type\": \"code_interpreter\",\n",
|
||||||
" }\n",
|
" },\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
" tool_choice=\"required\",\n",
|
" tool_choice=\"required\",\n",
|
||||||
" input_shields=[],\n",
|
" input_shields=[],\n",
|
||||||
|
@ -1788,7 +1790,7 @@
|
||||||
"]\n",
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for prompt in user_prompts:\n",
|
"for prompt in user_prompts:\n",
|
||||||
" cprint(f'User> {prompt}', 'green')\n",
|
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||||
" response = codex_agent.create_turn(\n",
|
" response = codex_agent.create_turn(\n",
|
||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
|
@ -1841,27 +1843,57 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import pandas as pd\n",
|
|
||||||
"import matplotlib.pyplot as plt\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Read the CSV file\n",
|
"# Read the CSV file\n",
|
||||||
"df = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\n",
|
"df = pd.read_csv(\"/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Extract the year and inflation rate from the CSV file\n",
|
"# Extract the year and inflation rate from the CSV file\n",
|
||||||
"df['Year'] = pd.to_datetime(df['Year'], format='%Y')\n",
|
"df[\"Year\"] = pd.to_datetime(df[\"Year\"], format=\"%Y\")\n",
|
||||||
"df = df.rename(columns={'Jan': 'Jan Rate', 'Feb': 'Feb Rate', 'Mar': 'Mar Rate', 'Apr': 'Apr Rate', 'May': 'May Rate', 'Jun': 'Jun Rate', 'Jul': 'Jul Rate', 'Aug': 'Aug Rate', 'Sep': 'Sep Rate', 'Oct': 'Oct Rate', 'Nov': 'Nov Rate', 'Dec': 'Dec Rate'})\n",
|
"df = df.rename(\n",
|
||||||
|
" columns={\n",
|
||||||
|
" \"Jan\": \"Jan Rate\",\n",
|
||||||
|
" \"Feb\": \"Feb Rate\",\n",
|
||||||
|
" \"Mar\": \"Mar Rate\",\n",
|
||||||
|
" \"Apr\": \"Apr Rate\",\n",
|
||||||
|
" \"May\": \"May Rate\",\n",
|
||||||
|
" \"Jun\": \"Jun Rate\",\n",
|
||||||
|
" \"Jul\": \"Jul Rate\",\n",
|
||||||
|
" \"Aug\": \"Aug Rate\",\n",
|
||||||
|
" \"Sep\": \"Sep Rate\",\n",
|
||||||
|
" \"Oct\": \"Oct Rate\",\n",
|
||||||
|
" \"Nov\": \"Nov Rate\",\n",
|
||||||
|
" \"Dec\": \"Dec Rate\",\n",
|
||||||
|
" }\n",
|
||||||
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Calculate the average yearly inflation rate\n",
|
"# Calculate the average yearly inflation rate\n",
|
||||||
"df['Yearly Inflation'] = df[['Jan Rate', 'Feb Rate', 'Mar Rate', 'Apr Rate', 'May Rate', 'Jun Rate', 'Jul Rate', 'Aug Rate', 'Sep Rate', 'Oct Rate', 'Nov Rate', 'Dec Rate']].mean(axis=1)\n",
|
"df[\"Yearly Inflation\"] = df[\n",
|
||||||
|
" [\n",
|
||||||
|
" \"Jan Rate\",\n",
|
||||||
|
" \"Feb Rate\",\n",
|
||||||
|
" \"Mar Rate\",\n",
|
||||||
|
" \"Apr Rate\",\n",
|
||||||
|
" \"May Rate\",\n",
|
||||||
|
" \"Jun Rate\",\n",
|
||||||
|
" \"Jul Rate\",\n",
|
||||||
|
" \"Aug Rate\",\n",
|
||||||
|
" \"Sep Rate\",\n",
|
||||||
|
" \"Oct Rate\",\n",
|
||||||
|
" \"Nov Rate\",\n",
|
||||||
|
" \"Dec Rate\",\n",
|
||||||
|
" ]\n",
|
||||||
|
"].mean(axis=1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Plot the average yearly inflation rate as a time series\n",
|
"# Plot the average yearly inflation rate as a time series\n",
|
||||||
"plt.figure(figsize=(10, 6))\n",
|
"plt.figure(figsize=(10, 6))\n",
|
||||||
"plt.plot(df['Year'], df['Yearly Inflation'], marker='o')\n",
|
"plt.plot(df[\"Year\"], df[\"Yearly Inflation\"], marker=\"o\")\n",
|
||||||
"plt.title('Average Yearly Inflation Rate')\n",
|
"plt.title(\"Average Yearly Inflation Rate\")\n",
|
||||||
"plt.xlabel('Year')\n",
|
"plt.xlabel(\"Year\")\n",
|
||||||
"plt.ylabel('Inflation Rate (%)')\n",
|
"plt.ylabel(\"Inflation Rate (%)\")\n",
|
||||||
"plt.grid(True)\n",
|
"plt.grid(True)\n",
|
||||||
"plt.show()"
|
"plt.show()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2035,6 +2067,8 @@
|
||||||
"source": [
|
"source": [
|
||||||
"# disable logging for clean server logs\n",
|
"# disable logging for clean server logs\n",
|
||||||
"import logging\n",
|
"import logging\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
"def remove_root_handlers():\n",
|
"def remove_root_handlers():\n",
|
||||||
" root_logger = logging.getLogger()\n",
|
" root_logger = logging.getLogger()\n",
|
||||||
" for handler in root_logger.handlers[:]:\n",
|
" for handler in root_logger.handlers[:]:\n",
|
||||||
|
@ -2042,7 +2076,7 @@
|
||||||
" print(f\"Removed handler {handler.__class__.__name__} from root logger\")\n",
|
" print(f\"Removed handler {handler.__class__.__name__} from root logger\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"remove_root_handlers()"
|
"remove_root_handlers()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2083,10 +2117,10 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"from google.colab import userdata\n",
|
||||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||||
"from google.colab import userdata\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"agent_config = AgentConfig(\n",
|
"agent_config = AgentConfig(\n",
|
||||||
" model=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
" model=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
||||||
|
@ -2096,7 +2130,7 @@
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"type\": \"brave_search\",\n",
|
" \"type\": \"brave_search\",\n",
|
||||||
" \"engine\": \"tavily\",\n",
|
" \"engine\": \"tavily\",\n",
|
||||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
|
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
" ]\n",
|
" ]\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
|
@ -2125,7 +2159,7 @@
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" for log in EventLogger().log(response):\n",
|
" for log in EventLogger().log(response):\n",
|
||||||
" log.print()"
|
" log.print()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2265,20 +2299,21 @@
|
||||||
"source": [
|
"source": [
|
||||||
"print(f\"Getting traces for session_id={session_id}\")\n",
|
"print(f\"Getting traces for session_id={session_id}\")\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
|
"\n",
|
||||||
"from rich.pretty import pprint\n",
|
"from rich.pretty import pprint\n",
|
||||||
"\n",
|
"\n",
|
||||||
"agent_logs = []\n",
|
"agent_logs = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for span in client.telemetry.query_spans(\n",
|
"for span in client.telemetry.query_spans(\n",
|
||||||
" attribute_filters=[\n",
|
" attribute_filters=[\n",
|
||||||
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
" attributes_to_return=[\"input\", \"output\"]\n",
|
" attributes_to_return=[\"input\", \"output\"],\n",
|
||||||
" ):\n",
|
"):\n",
|
||||||
" if span.attributes[\"output\"] != \"no shields\":\n",
|
" if span.attributes[\"output\"] != \"no shields\":\n",
|
||||||
" agent_logs.append(span.attributes)\n",
|
" agent_logs.append(span.attributes)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"pprint(agent_logs)"
|
"pprint(agent_logs)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2389,23 +2424,25 @@
|
||||||
"eval_rows = []\n",
|
"eval_rows = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for log in agent_logs:\n",
|
"for log in agent_logs:\n",
|
||||||
" last_msg = log['input'][-1]\n",
|
" last_msg = log[\"input\"][-1]\n",
|
||||||
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n",
|
" if '\"role\":\"user\"' in last_msg:\n",
|
||||||
" eval_rows.append(\n",
|
" eval_rows.append(\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"input_query\": last_msg,\n",
|
" \"input_query\": last_msg,\n",
|
||||||
" \"generated_answer\": log[\"output\"],\n",
|
" \"generated_answer\": log[\"output\"],\n",
|
||||||
" # check if generated_answer uses tools brave_search\n",
|
" # check if generated_answer uses tools brave_search\n",
|
||||||
" \"expected_answer\": \"brave_search\",\n",
|
" \"expected_answer\": \"brave_search\",\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
"pprint(eval_rows)\n",
|
"pprint(eval_rows)\n",
|
||||||
"scoring_params = {\n",
|
"scoring_params = {\n",
|
||||||
" \"basic::subset_of\": None,\n",
|
" \"basic::subset_of\": None,\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n",
|
"scoring_response = client.scoring.score(\n",
|
||||||
"pprint(scoring_response)"
|
" input_rows=eval_rows, scoring_functions=scoring_params\n",
|
||||||
|
")\n",
|
||||||
|
"pprint(scoring_response)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2506,7 +2543,9 @@
|
||||||
"EXPECTED_RESPONSE: {expected_answer}\n",
|
"EXPECTED_RESPONSE: {expected_answer}\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
"input_query = (\n",
|
||||||
|
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
||||||
|
")\n",
|
||||||
"generated_answer = \"\"\"\n",
|
"generated_answer = \"\"\"\n",
|
||||||
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
|
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -2537,7 +2576,7 @@
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
|
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
|
||||||
"pprint(response)"
|
"pprint(response)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -618,11 +618,13 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
|
"\n",
|
||||||
"from google.colab import userdata\n",
|
"from google.colab import userdata\n",
|
||||||
"\n",
|
"\n",
|
||||||
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
|
"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
||||||
|
"\n",
|
||||||
"client = LlamaStackAsLibraryClient(\"together\")\n",
|
"client = LlamaStackAsLibraryClient(\"together\")\n",
|
||||||
"_ = client.initialize()\n",
|
"_ = client.initialize()\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -631,7 +633,7 @@
|
||||||
" model_id=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
" model_id=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
||||||
" provider_model_id=\"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n",
|
" provider_model_id=\"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n",
|
||||||
" provider_id=\"together\",\n",
|
" provider_id=\"together\",\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -668,7 +670,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"name = \"llamastack/mmmu\"\n",
|
"name = \"llamastack/mmmu\"\n",
|
||||||
"subset = \"Agriculture\"\n",
|
"subset = \"Agriculture\"\n",
|
||||||
"split = \"dev\""
|
"split = \"dev\"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -914,9 +916,10 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import datasets\n",
|
"import datasets\n",
|
||||||
|
"\n",
|
||||||
"ds = datasets.load_dataset(path=name, name=subset, split=split)\n",
|
"ds = datasets.load_dataset(path=name, name=subset, split=split)\n",
|
||||||
"ds = ds.select_columns([\"chat_completion_input\", \"input_query\", \"expected_answer\"])\n",
|
"ds = ds.select_columns([\"chat_completion_input\", \"input_query\", \"expected_answer\"])\n",
|
||||||
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")"
|
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1014,8 +1017,8 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from tqdm import tqdm\n",
|
|
||||||
"from rich.pretty import pprint\n",
|
"from rich.pretty import pprint\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
"\n",
|
"\n",
|
||||||
"SYSTEM_PROMPT_TEMPLATE = \"\"\"\n",
|
"SYSTEM_PROMPT_TEMPLATE = \"\"\"\n",
|
||||||
"You are an expert in {subject} whose job is to answer questions from the user using images.\n",
|
"You are an expert in {subject} whose job is to answer questions from the user using images.\n",
|
||||||
|
@ -1039,7 +1042,7 @@
|
||||||
"client.eval_tasks.register(\n",
|
"client.eval_tasks.register(\n",
|
||||||
" eval_task_id=\"meta-reference::mmmu\",\n",
|
" eval_task_id=\"meta-reference::mmmu\",\n",
|
||||||
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
|
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
|
||||||
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"]\n",
|
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"response = client.eval.evaluate_rows(\n",
|
"response = client.eval.evaluate_rows(\n",
|
||||||
|
@ -1052,16 +1055,17 @@
|
||||||
" \"type\": \"model\",\n",
|
" \"type\": \"model\",\n",
|
||||||
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
||||||
" \"sampling_params\": {\n",
|
" \"sampling_params\": {\n",
|
||||||
" \"temperature\": 0.0,\n",
|
" \"strategy\": {\n",
|
||||||
|
" \"type\": \"greedy\",\n",
|
||||||
|
" },\n",
|
||||||
" \"max_tokens\": 4096,\n",
|
" \"max_tokens\": 4096,\n",
|
||||||
" \"top_p\": 0.9,\n",
|
|
||||||
" \"repeat_penalty\": 1.0,\n",
|
" \"repeat_penalty\": 1.0,\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" \"system_message\": system_message\n",
|
" \"system_message\": system_message,\n",
|
||||||
" }\n",
|
" },\n",
|
||||||
" }\n",
|
" },\n",
|
||||||
")\n",
|
")\n",
|
||||||
"pprint(response)"
|
"pprint(response)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1098,8 +1102,8 @@
|
||||||
" \"input_query\": {\"type\": \"string\"},\n",
|
" \"input_query\": {\"type\": \"string\"},\n",
|
||||||
" \"expected_answer\": {\"type\": \"string\"},\n",
|
" \"expected_answer\": {\"type\": \"string\"},\n",
|
||||||
" \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n",
|
" \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n",
|
||||||
" }\n",
|
" },\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1113,7 +1117,7 @@
|
||||||
"eval_rows = client.datasetio.get_rows_paginated(\n",
|
"eval_rows = client.datasetio.get_rows_paginated(\n",
|
||||||
" dataset_id=simpleqa_dataset_id,\n",
|
" dataset_id=simpleqa_dataset_id,\n",
|
||||||
" rows_in_page=5,\n",
|
" rows_in_page=5,\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1209,7 +1213,7 @@
|
||||||
"client.eval_tasks.register(\n",
|
"client.eval_tasks.register(\n",
|
||||||
" eval_task_id=\"meta-reference::simpleqa\",\n",
|
" eval_task_id=\"meta-reference::simpleqa\",\n",
|
||||||
" dataset_id=simpleqa_dataset_id,\n",
|
" dataset_id=simpleqa_dataset_id,\n",
|
||||||
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"]\n",
|
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"response = client.eval.evaluate_rows(\n",
|
"response = client.eval.evaluate_rows(\n",
|
||||||
|
@ -1222,15 +1226,16 @@
|
||||||
" \"type\": \"model\",\n",
|
" \"type\": \"model\",\n",
|
||||||
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
||||||
" \"sampling_params\": {\n",
|
" \"sampling_params\": {\n",
|
||||||
" \"temperature\": 0.0,\n",
|
" \"strategy\": {\n",
|
||||||
|
" \"type\": \"greedy\",\n",
|
||||||
|
" },\n",
|
||||||
" \"max_tokens\": 4096,\n",
|
" \"max_tokens\": 4096,\n",
|
||||||
" \"top_p\": 0.9,\n",
|
|
||||||
" \"repeat_penalty\": 1.0,\n",
|
" \"repeat_penalty\": 1.0,\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" }\n",
|
" },\n",
|
||||||
" }\n",
|
" },\n",
|
||||||
")\n",
|
")\n",
|
||||||
"pprint(response)"
|
"pprint(response)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1347,23 +1352,19 @@
|
||||||
"agent_config = {\n",
|
"agent_config = {\n",
|
||||||
" \"model\": \"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
" \"model\": \"meta-llama/Llama-3.1-405B-Instruct\",\n",
|
||||||
" \"instructions\": \"You are a helpful assistant\",\n",
|
" \"instructions\": \"You are a helpful assistant\",\n",
|
||||||
" \"sampling_params\": {\n",
|
" \"sampling_params\": {\"strategy\": {\"type\": \"greedy\"}},\n",
|
||||||
" \"strategy\": \"greedy\",\n",
|
|
||||||
" \"temperature\": 0.0,\n",
|
|
||||||
" \"top_p\": 0.95,\n",
|
|
||||||
" },\n",
|
|
||||||
" \"tools\": [\n",
|
" \"tools\": [\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"type\": \"brave_search\",\n",
|
" \"type\": \"brave_search\",\n",
|
||||||
" \"engine\": \"tavily\",\n",
|
" \"engine\": \"tavily\",\n",
|
||||||
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
|
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
" \"tool_choice\": \"auto\",\n",
|
" \"tool_choice\": \"auto\",\n",
|
||||||
" \"tool_prompt_format\": \"json\",\n",
|
" \"tool_prompt_format\": \"json\",\n",
|
||||||
" \"input_shields\": [],\n",
|
" \"input_shields\": [],\n",
|
||||||
" \"output_shields\": [],\n",
|
" \"output_shields\": [],\n",
|
||||||
" \"enable_session_persistence\": False\n",
|
" \"enable_session_persistence\": False,\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"response = client.eval.evaluate_rows(\n",
|
"response = client.eval.evaluate_rows(\n",
|
||||||
|
@ -1375,10 +1376,10 @@
|
||||||
" \"eval_candidate\": {\n",
|
" \"eval_candidate\": {\n",
|
||||||
" \"type\": \"agent\",\n",
|
" \"type\": \"agent\",\n",
|
||||||
" \"config\": agent_config,\n",
|
" \"config\": agent_config,\n",
|
||||||
" }\n",
|
" },\n",
|
||||||
" }\n",
|
" },\n",
|
||||||
")\n",
|
")\n",
|
||||||
"pprint(response)"
|
"pprint(response)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
|
@ -1336,6 +1336,7 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from rich.pretty import pprint\n",
|
"from rich.pretty import pprint\n",
|
||||||
|
"\n",
|
||||||
"print(\"Available models:\")\n",
|
"print(\"Available models:\")\n",
|
||||||
"for m in client.models.list():\n",
|
"for m in client.models.list():\n",
|
||||||
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
|
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
|
||||||
|
@ -1344,7 +1345,7 @@
|
||||||
"print(\"Available shields (safety models):\")\n",
|
"print(\"Available shields (safety models):\")\n",
|
||||||
"for s in client.shields.list():\n",
|
"for s in client.shields.list():\n",
|
||||||
" print(s.identifier)\n",
|
" print(s.identifier)\n",
|
||||||
"print(\"----\")"
|
"print(\"----\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1389,7 +1390,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
|
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model_id"
|
"model_id\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1432,11 +1433,11 @@
|
||||||
" model_id=model_id,\n",
|
" model_id=model_id,\n",
|
||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
|
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
|
||||||
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n",
|
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(response.completion_message.content)"
|
"print(response.completion_message.content)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1489,12 +1490,13 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from termcolor import cprint\n",
|
"from termcolor import cprint\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"def chat_loop():\n",
|
"def chat_loop():\n",
|
||||||
" conversation_history = []\n",
|
" conversation_history = []\n",
|
||||||
" while True:\n",
|
" while True:\n",
|
||||||
" user_input = input('User> ')\n",
|
" user_input = input(\"User> \")\n",
|
||||||
" if user_input.lower() in ['exit', 'quit', 'bye']:\n",
|
" if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
|
||||||
" cprint('Ending conversation. Goodbye!', 'yellow')\n",
|
" cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
|
||||||
" break\n",
|
" break\n",
|
||||||
"\n",
|
"\n",
|
||||||
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
|
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
|
||||||
|
@ -1504,15 +1506,16 @@
|
||||||
" messages=conversation_history,\n",
|
" messages=conversation_history,\n",
|
||||||
" model_id=model_id,\n",
|
" model_id=model_id,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
|
" cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
" assistant_message = {\n",
|
" assistant_message = {\n",
|
||||||
" \"role\": \"assistant\", # was user\n",
|
" \"role\": \"assistant\", # was user\n",
|
||||||
" \"content\": response.completion_message.content,\n",
|
" \"content\": response.completion_message.content,\n",
|
||||||
" \"stop_reason\": response.completion_message.stop_reason,\n",
|
" \"stop_reason\": response.completion_message.stop_reason,\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
" conversation_history.append(assistant_message)\n",
|
" conversation_history.append(assistant_message)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"chat_loop()\n"
|
"chat_loop()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -1568,21 +1571,18 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
||||||
"\n",
|
"\n",
|
||||||
"message = {\n",
|
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
|
||||||
" \"role\": \"user\",\n",
|
"print(f'User> {message[\"content\"]}', \"green\")\n",
|
||||||
" \"content\": 'Write me a sonnet about llama'\n",
|
|
||||||
"}\n",
|
|
||||||
"print(f'User> {message[\"content\"]}', 'green')\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"response = client.inference.chat_completion(\n",
|
"response = client.inference.chat_completion(\n",
|
||||||
" messages=[message],\n",
|
" messages=[message],\n",
|
||||||
" model_id=model_id,\n",
|
" model_id=model_id,\n",
|
||||||
" stream=True, # <-----------\n",
|
" stream=True, # <-----------\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Print the tokens while they are received\n",
|
"# Print the tokens while they are received\n",
|
||||||
"for log in EventLogger().log(response):\n",
|
"for log in EventLogger().log(response):\n",
|
||||||
" log.print()"
|
" log.print()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1648,17 +1648,22 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from pydantic import BaseModel\n",
|
"from pydantic import BaseModel\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"class Output(BaseModel):\n",
|
"class Output(BaseModel):\n",
|
||||||
" name: str\n",
|
" name: str\n",
|
||||||
" year_born: str\n",
|
" year_born: str\n",
|
||||||
" year_retired: str\n",
|
" year_retired: str\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
|
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
|
||||||
"response = client.inference.completion(\n",
|
"response = client.inference.completion(\n",
|
||||||
" model_id=model_id,\n",
|
" model_id=model_id,\n",
|
||||||
" content=user_input,\n",
|
" content=user_input,\n",
|
||||||
" stream=False,\n",
|
" stream=False,\n",
|
||||||
" sampling_params={\n",
|
" sampling_params={\n",
|
||||||
|
" \"strategy\": {\n",
|
||||||
|
" \"type\": \"greedy\",\n",
|
||||||
|
" },\n",
|
||||||
" \"max_tokens\": 50,\n",
|
" \"max_tokens\": 50,\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" response_format={\n",
|
" response_format={\n",
|
||||||
|
@ -1667,7 +1672,7 @@
|
||||||
" },\n",
|
" },\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"pprint(response)"
|
"pprint(response)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1823,7 +1828,7 @@
|
||||||
" shield_id=available_shields[0],\n",
|
" shield_id=available_shields[0],\n",
|
||||||
" params={},\n",
|
" params={},\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" pprint(response)"
|
" pprint(response)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2025,7 +2030,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"session_id = agent.create_session(\"test-session\")\n",
|
"session_id = agent.create_session(\"test-session\")\n",
|
||||||
"for prompt in user_prompts:\n",
|
"for prompt in user_prompts:\n",
|
||||||
" cprint(f'User> {prompt}', 'green')\n",
|
" cprint(f\"User> {prompt}\", \"green\")\n",
|
||||||
" response = agent.create_turn(\n",
|
" response = agent.create_turn(\n",
|
||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
|
@ -2451,8 +2456,8 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import pandas as pd\n",
|
|
||||||
"import matplotlib.pyplot as plt\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Load data\n",
|
"# Load data\n",
|
||||||
"df = pd.read_csv(\"/tmp/tmpvzjigv7g/n2OzlTWhinflation.csv\")\n",
|
"df = pd.read_csv(\"/tmp/tmpvzjigv7g/n2OzlTWhinflation.csv\")\n",
|
||||||
|
@ -2536,10 +2541,10 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"from google.colab import userdata\n",
|
||||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||||
"from google.colab import userdata\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"agent_config = AgentConfig(\n",
|
"agent_config = AgentConfig(\n",
|
||||||
" model=\"meta-llama/Llama-3.1-405B-Instruct-FP8\",\n",
|
" model=\"meta-llama/Llama-3.1-405B-Instruct-FP8\",\n",
|
||||||
|
@ -2570,7 +2575,7 @@
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" for log in EventLogger().log(response):\n",
|
" for log in EventLogger().log(response):\n",
|
||||||
" log.print()"
|
" log.print()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2790,20 +2795,21 @@
|
||||||
"source": [
|
"source": [
|
||||||
"print(f\"Getting traces for session_id={session_id}\")\n",
|
"print(f\"Getting traces for session_id={session_id}\")\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
|
"\n",
|
||||||
"from rich.pretty import pprint\n",
|
"from rich.pretty import pprint\n",
|
||||||
"\n",
|
"\n",
|
||||||
"agent_logs = []\n",
|
"agent_logs = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for span in client.telemetry.query_spans(\n",
|
"for span in client.telemetry.query_spans(\n",
|
||||||
" attribute_filters=[\n",
|
" attribute_filters=[\n",
|
||||||
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
" attributes_to_return=[\"input\", \"output\"]\n",
|
" attributes_to_return=[\"input\", \"output\"],\n",
|
||||||
" ):\n",
|
"):\n",
|
||||||
" if span.attributes[\"output\"] != \"no shields\":\n",
|
" if span.attributes[\"output\"] != \"no shields\":\n",
|
||||||
" agent_logs.append(span.attributes)\n",
|
" agent_logs.append(span.attributes)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"pprint(agent_logs)"
|
"pprint(agent_logs)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2914,23 +2920,25 @@
|
||||||
"eval_rows = []\n",
|
"eval_rows = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for log in agent_logs:\n",
|
"for log in agent_logs:\n",
|
||||||
" last_msg = log['input'][-1]\n",
|
" last_msg = log[\"input\"][-1]\n",
|
||||||
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n",
|
" if '\"role\":\"user\"' in last_msg:\n",
|
||||||
" eval_rows.append(\n",
|
" eval_rows.append(\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"input_query\": last_msg,\n",
|
" \"input_query\": last_msg,\n",
|
||||||
" \"generated_answer\": log[\"output\"],\n",
|
" \"generated_answer\": log[\"output\"],\n",
|
||||||
" # check if generated_answer uses tools brave_search\n",
|
" # check if generated_answer uses tools brave_search\n",
|
||||||
" \"expected_answer\": \"brave_search\",\n",
|
" \"expected_answer\": \"brave_search\",\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
"pprint(eval_rows)\n",
|
"pprint(eval_rows)\n",
|
||||||
"scoring_params = {\n",
|
"scoring_params = {\n",
|
||||||
" \"basic::subset_of\": None,\n",
|
" \"basic::subset_of\": None,\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n",
|
"scoring_response = client.scoring.score(\n",
|
||||||
"pprint(scoring_response)"
|
" input_rows=eval_rows, scoring_functions=scoring_params\n",
|
||||||
|
")\n",
|
||||||
|
"pprint(scoring_response)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -3031,7 +3039,9 @@
|
||||||
"EXPECTED_RESPONSE: {expected_answer}\n",
|
"EXPECTED_RESPONSE: {expected_answer}\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
"input_query = (\n",
|
||||||
|
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
|
||||||
|
")\n",
|
||||||
"generated_answer = \"\"\"\n",
|
"generated_answer = \"\"\"\n",
|
||||||
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
|
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -3062,7 +3072,7 @@
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
|
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
|
||||||
"pprint(response)"
|
"pprint(response)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
|
@@ -3514,6 +3514,20 @@
 "tool_calls"
 ]
 },
+"GreedySamplingStrategy": {
+"type": "object",
+"properties": {
+"type": {
+"type": "string",
+"const": "greedy",
+"default": "greedy"
+}
+},
+"additionalProperties": false,
+"required": [
+"type"
+]
+},
 "ImageContentItem": {
 "type": "object",
 "properties": {
|
@@ -3581,20 +3595,17 @@
 "type": "object",
 "properties": {
 "strategy": {
-"$ref": "#/components/schemas/SamplingStrategy",
-"default": "greedy"
-},
-"temperature": {
-"type": "number",
-"default": 0.0
-},
-"top_p": {
-"type": "number",
-"default": 0.95
-},
-"top_k": {
-"type": "integer",
-"default": 0
+"oneOf": [
+{
+"$ref": "#/components/schemas/GreedySamplingStrategy"
+},
+{
+"$ref": "#/components/schemas/TopPSamplingStrategy"
+},
+{
+"$ref": "#/components/schemas/TopKSamplingStrategy"
+}
+]
 },
 "max_tokens": {
 "type": "integer",
|
@@ -3610,14 +3621,6 @@
 "strategy"
 ]
 },
-"SamplingStrategy": {
-"type": "string",
-"enum": [
-"greedy",
-"top_p",
-"top_k"
-]
-},
 "StopReason": {
 "type": "string",
 "enum": [
|
@@ -3871,6 +3874,45 @@
 "content"
 ]
 },
+"TopKSamplingStrategy": {
+"type": "object",
+"properties": {
+"type": {
+"type": "string",
+"const": "top_k",
+"default": "top_k"
+},
+"top_k": {
+"type": "integer"
+}
+},
+"additionalProperties": false,
+"required": [
+"type",
+"top_k"
+]
+},
+"TopPSamplingStrategy": {
+"type": "object",
+"properties": {
+"type": {
+"type": "string",
+"const": "top_p",
+"default": "top_p"
+},
+"temperature": {
+"type": "number"
+},
+"top_p": {
+"type": "number",
+"default": 0.95
+}
+},
+"additionalProperties": false,
+"required": [
+"type"
+]
+},
 "URL": {
 "type": "object",
 "properties": {
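The three strategy schemas above follow the shape Pydantic emits for a discriminated union, so on the Python side the spec plausibly corresponds to models along these lines. This is a hedged sketch only: the actual source files are not part of the hunks shown here, and the field names and defaults are taken from the schema excerpt above.

from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field


class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = None
    top_p: Optional[float] = 0.95


class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int


# "strategy" in SamplingParams then becomes a tagged union over these models,
# matching the oneOf with a "type" discriminator in the OpenAPI spec.
SamplingStrategy = Annotated[
    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
    Field(discriminator="type"),
]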
|
@ -8887,6 +8929,10 @@
|
||||||
"name": "GraphMemoryBankParams",
|
"name": "GraphMemoryBankParams",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/GraphMemoryBankParams\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/GraphMemoryBankParams\" />"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "GreedySamplingStrategy",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/GreedySamplingStrategy\" />"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "HealthInfo",
|
"name": "HealthInfo",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />"
|
||||||
|
@ -9136,10 +9182,6 @@
|
||||||
"name": "SamplingParams",
|
"name": "SamplingParams",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingParams\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingParams\" />"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "SamplingStrategy",
|
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingStrategy\" />"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "SaveSpansToDatasetRequest",
|
"name": "SaveSpansToDatasetRequest",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SaveSpansToDatasetRequest\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SaveSpansToDatasetRequest\" />"
|
||||||
|
@ -9317,6 +9359,14 @@
|
||||||
{
|
{
|
||||||
"name": "ToolRuntime"
|
"name": "ToolRuntime"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "TopKSamplingStrategy",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/TopKSamplingStrategy\" />"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "TopPSamplingStrategy",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/TopPSamplingStrategy\" />"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Trace",
|
"name": "Trace",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />"
|
||||||
|
@ -9456,6 +9506,7 @@
|
||||||
"GetSpanTreeRequest",
|
"GetSpanTreeRequest",
|
||||||
"GraphMemoryBank",
|
"GraphMemoryBank",
|
||||||
"GraphMemoryBankParams",
|
"GraphMemoryBankParams",
|
||||||
|
"GreedySamplingStrategy",
|
||||||
"HealthInfo",
|
"HealthInfo",
|
||||||
"ImageContentItem",
|
"ImageContentItem",
|
||||||
"InferenceStep",
|
"InferenceStep",
|
||||||
|
@ -9513,7 +9564,6 @@
|
||||||
"RunShieldResponse",
|
"RunShieldResponse",
|
||||||
"SafetyViolation",
|
"SafetyViolation",
|
||||||
"SamplingParams",
|
"SamplingParams",
|
||||||
"SamplingStrategy",
|
|
||||||
"SaveSpansToDatasetRequest",
|
"SaveSpansToDatasetRequest",
|
||||||
"ScoreBatchRequest",
|
"ScoreBatchRequest",
|
||||||
"ScoreBatchResponse",
|
"ScoreBatchResponse",
|
||||||
|
@ -9553,6 +9603,8 @@
|
||||||
"ToolPromptFormat",
|
"ToolPromptFormat",
|
||||||
"ToolResponse",
|
"ToolResponse",
|
||||||
"ToolResponseMessage",
|
"ToolResponseMessage",
|
||||||
|
"TopKSamplingStrategy",
|
||||||
|
"TopPSamplingStrategy",
|
||||||
"Trace",
|
"Trace",
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"Turn",
|
"Turn",
|
||||||
|
|
|
@@ -937,6 +937,16 @@ components:
 required:
 - memory_bank_type
 type: object
+GreedySamplingStrategy:
+additionalProperties: false
+properties:
+type:
+const: greedy
+default: greedy
+type: string
+required:
+- type
+type: object
 HealthInfo:
 additionalProperties: false
 properties:
|
@@ -2064,26 +2074,13 @@ components:
 default: 1.0
 type: number
 strategy:
-$ref: '#/components/schemas/SamplingStrategy'
-default: greedy
-temperature:
-default: 0.0
-type: number
-top_k:
-default: 0
-type: integer
-top_p:
-default: 0.95
-type: number
+oneOf:
+- $ref: '#/components/schemas/GreedySamplingStrategy'
+- $ref: '#/components/schemas/TopPSamplingStrategy'
+- $ref: '#/components/schemas/TopKSamplingStrategy'
 required:
 - strategy
 type: object
-SamplingStrategy:
-enum:
-- greedy
-- top_p
-- top_k
-type: string
 SaveSpansToDatasetRequest:
 additionalProperties: false
 properties:
|
@@ -2931,6 +2928,34 @@ components:
 - tool_name
 - content
 type: object
+TopKSamplingStrategy:
+additionalProperties: false
+properties:
+top_k:
+type: integer
+type:
+const: top_k
+default: top_k
+type: string
+required:
+- type
+- top_k
+type: object
+TopPSamplingStrategy:
+additionalProperties: false
+properties:
+temperature:
+type: number
+top_p:
+default: 0.95
+type: number
+type:
+const: top_p
+default: top_p
+type: string
+required:
+- type
+type: object
 Trace:
 additionalProperties: false
 properties:
|
@ -5587,6 +5612,9 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/GraphMemoryBankParams"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/GraphMemoryBankParams"
|
||||||
/>
|
/>
|
||||||
name: GraphMemoryBankParams
|
name: GraphMemoryBankParams
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/GreedySamplingStrategy"
|
||||||
|
/>
|
||||||
|
name: GreedySamplingStrategy
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" />
|
||||||
name: HealthInfo
|
name: HealthInfo
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ImageContentItem"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/ImageContentItem"
|
||||||
|
@ -5753,9 +5781,6 @@ tags:
|
||||||
name: SafetyViolation
|
name: SafetyViolation
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" />
|
||||||
name: SamplingParams
|
name: SamplingParams
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingStrategy"
|
|
||||||
/>
|
|
||||||
name: SamplingStrategy
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/SaveSpansToDatasetRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/SaveSpansToDatasetRequest"
|
||||||
/>
|
/>
|
||||||
name: SaveSpansToDatasetRequest
|
name: SaveSpansToDatasetRequest
|
||||||
|
@ -5874,6 +5899,12 @@ tags:
|
||||||
/>
|
/>
|
||||||
name: ToolResponseMessage
|
name: ToolResponseMessage
|
||||||
- name: ToolRuntime
|
- name: ToolRuntime
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/TopKSamplingStrategy"
|
||||||
|
/>
|
||||||
|
name: TopKSamplingStrategy
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/TopPSamplingStrategy"
|
||||||
|
/>
|
||||||
|
name: TopPSamplingStrategy
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
|
||||||
name: Trace
|
name: Trace
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/TrainingConfig" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/TrainingConfig" />
|
||||||
|
@ -5990,6 +6021,7 @@ x-tagGroups:
|
||||||
- GetSpanTreeRequest
|
- GetSpanTreeRequest
|
||||||
- GraphMemoryBank
|
- GraphMemoryBank
|
||||||
- GraphMemoryBankParams
|
- GraphMemoryBankParams
|
||||||
|
- GreedySamplingStrategy
|
||||||
- HealthInfo
|
- HealthInfo
|
||||||
- ImageContentItem
|
- ImageContentItem
|
||||||
- InferenceStep
|
- InferenceStep
|
||||||
|
@ -6047,7 +6079,6 @@ x-tagGroups:
|
||||||
- RunShieldResponse
|
- RunShieldResponse
|
||||||
- SafetyViolation
|
- SafetyViolation
|
||||||
- SamplingParams
|
- SamplingParams
|
||||||
- SamplingStrategy
|
|
||||||
- SaveSpansToDatasetRequest
|
- SaveSpansToDatasetRequest
|
||||||
- ScoreBatchRequest
|
- ScoreBatchRequest
|
||||||
- ScoreBatchResponse
|
- ScoreBatchResponse
|
||||||
|
@ -6087,6 +6118,8 @@ x-tagGroups:
|
||||||
- ToolPromptFormat
|
- ToolPromptFormat
|
||||||
- ToolResponse
|
- ToolResponse
|
||||||
- ToolResponseMessage
|
- ToolResponseMessage
|
||||||
|
- TopKSamplingStrategy
|
||||||
|
- TopPSamplingStrategy
|
||||||
- Trace
|
- Trace
|
||||||
- TrainingConfig
|
- TrainingConfig
|
||||||
- Turn
|
- Turn
|
||||||
|
|
|
@@ -56,9 +56,10 @@ response = client.eval.evaluate_rows(
 "type": "model",
 "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
 "sampling_params": {
-"temperature": 0.0,
+"strategy": {
+"type": "greedy",
+},
 "max_tokens": 4096,
-"top_p": 0.9,
 "repeat_penalty": 1.0,
 },
 "system_message": system_message
@@ -113,9 +114,10 @@ response = client.eval.evaluate_rows(
 "type": "model",
 "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
 "sampling_params": {
-"temperature": 0.0,
+"strategy": {
+"type": "greedy",
+},
 "max_tokens": 4096,
-"top_p": 0.9,
 "repeat_penalty": 1.0,
 },
 }
@@ -134,9 +136,9 @@ agent_config = {
 "model": "meta-llama/Llama-3.1-405B-Instruct",
 "instructions": "You are a helpful assistant",
 "sampling_params": {
-"strategy": "greedy",
-"temperature": 0.0,
-"top_p": 0.95,
+"strategy": {
+"type": "greedy",
+},
 },
 "tools": [
 {

@@ -189,7 +189,11 @@ agent_config = AgentConfig(
 # Control the inference loop
 max_infer_iters=5,
 sampling_params={
-"temperature": 0.7,
+"strategy": {
+"type": "top_p",
+"temperature": 0.7,
+"top_p": 0.95
+},
 "max_tokens": 2048
 }
 )

|
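As a cross-reference, the dict-based `sampling_params` above maps onto the typed models this commit introduces elsewhere. A minimal sketch, assuming the `SamplingParams` and `TopPSamplingStrategy` exports that the test changes later in this commit import from `llama_stack.apis.inference`:

```python
# Sketch only (not part of the commit): the typed equivalent of the dict above.
from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy

sampling_params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95),
    max_tokens=2048,
)
```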
@@ -92,9 +92,10 @@ response = client.eval.evaluate_rows(
         "type": "model",
         "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
         "sampling_params": {
-            "temperature": 0.0,
+            "strategy": {
+                "type": "greedy",
+            },
             "max_tokens": 4096,
-            "top_p": 0.9,
             "repeat_penalty": 1.0,
         },
         "system_message": system_message
@@ -149,9 +150,10 @@ response = client.eval.evaluate_rows(
         "type": "model",
         "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
         "sampling_params": {
-            "temperature": 0.0,
+            "strategy": {
+                "type": "greedy",
+            },
             "max_tokens": 4096,
-            "top_p": 0.9,
             "repeat_penalty": 1.0,
         },
     }
@@ -170,9 +172,9 @@ agent_config = {
     "model": "meta-llama/Llama-3.1-405B-Instruct",
     "instructions": "You are a helpful assistant",
     "sampling_params": {
-        "strategy": "greedy",
-        "temperature": 0.0,
-        "top_p": 0.95,
+        "strategy": {
+            "type": "greedy",
+        },
     },
     "tools": [
         {
@@ -318,10 +320,9 @@ The `EvalTaskConfig` are user specified config to define:
         "type": "model",
         "model": "Llama3.2-3B-Instruct",
         "sampling_params": {
-            "strategy": "greedy",
-            "temperature": 0,
-            "top_p": 0.95,
-            "top_k": 0,
+            "strategy": {
+                "type": "greedy",
+            },
             "max_tokens": 0,
             "repetition_penalty": 1.0
         }
@@ -337,10 +338,9 @@ The `EvalTaskConfig` are user specified config to define:
         "type": "model",
         "model": "Llama3.1-405B-Instruct",
         "sampling_params": {
-            "strategy": "greedy",
-            "temperature": 0,
-            "top_p": 0.95,
-            "top_k": 0,
+            "strategy": {
+                "type": "greedy",
+            },
             "max_tokens": 0,
             "repetition_penalty": 1.0
         }

@@ -214,7 +214,6 @@ llama model describe -m Llama3.2-3B-Instruct
 | | } |
 +-----------------------------+----------------------------------+
 | Recommended sampling params | { |
-| | "strategy": "top_p", |
 | | "temperature": 1.0, |
 | | "top_p": 0.9, |
 | | "top_k": 0 |

@@ -200,10 +200,9 @@ Example eval_task_config.json:
         "type": "model",
         "model": "Llama3.1-405B-Instruct",
         "sampling_params": {
-            "strategy": "greedy",
-            "temperature": 0,
-            "top_p": 0.95,
-            "top_k": 0,
+            "strategy": {
+                "type": "greedy"
+            },
             "max_tokens": 0,
             "repetition_penalty": 1.0
         }

@@ -26,27 +26,28 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os\n",
-    "import requests\n",
-    "import json\n",
     "import asyncio\n",
-    "import nest_asyncio\n",
+    "import json\n",
+    "import os\n",
     "from typing import Dict, List\n",
+    "\n",
+    "import nest_asyncio\n",
+    "import requests\n",
     "from dotenv import load_dotenv\n",
     "from llama_stack_client import LlamaStackClient\n",
-    "from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
-    "from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
-    "from llama_stack_client.types import CompletionMessage\n",
     "from llama_stack_client.lib.agents.agent import Agent\n",
+    "from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
     "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
+    "from llama_stack_client.types import CompletionMessage\n",
     "from llama_stack_client.types.agent_create_params import AgentConfig\n",
+    "from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
     "\n",
     "# Allow asyncio to run in Jupyter Notebook\n",
     "nest_asyncio.apply()\n",
     "\n",
-    "HOST='localhost'\n",
-    "PORT=5001\n",
-    "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
+    "HOST = \"localhost\"\n",
+    "PORT = 5001\n",
+    "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
    ]
   },
   {
@@ -69,7 +70,7 @@
    "outputs": [],
    "source": [
     "load_dotenv()\n",
-    "BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']"
+    "BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
    ]
   },
   {
@@ -118,7 +119,7 @@
     "        cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n",
     "        clean_response.append(cleaned)\n",
     "\n",
-    "    return {\"query\": query, \"top_k\": clean_response}"
+    "    return {\"query\": query, \"top_k\": clean_response}\n"
    ]
   },
   {
@@ -157,25 +158,29 @@
     "        for message in messages:\n",
     "            if isinstance(message, CompletionMessage) and message.tool_calls:\n",
     "                for tool_call in message.tool_calls:\n",
-    "                    if 'query' in tool_call.arguments:\n",
-    "                        query = tool_call.arguments['query']\n",
+    "                    if \"query\" in tool_call.arguments:\n",
+    "                        query = tool_call.arguments[\"query\"]\n",
     "                        call_id = tool_call.call_id\n",
     "\n",
     "                        if query:\n",
     "                            search_result = await self.run_impl(query)\n",
-    "                            return [ToolResponseMessage(\n",
-    "                                call_id=call_id,\n",
-    "                                role=\"ipython\",\n",
-    "                                content=self._format_response_for_agent(search_result),\n",
-    "                                tool_name=\"brave_search\"\n",
-    "                            )]\n",
+    "                            return [\n",
+    "                                ToolResponseMessage(\n",
+    "                                    call_id=call_id,\n",
+    "                                    role=\"ipython\",\n",
+    "                                    content=self._format_response_for_agent(search_result),\n",
+    "                                    tool_name=\"brave_search\",\n",
+    "                                )\n",
+    "                            ]\n",
     "\n",
-    "        return [ToolResponseMessage(\n",
-    "            call_id=\"no_call_id\",\n",
-    "            role=\"ipython\",\n",
-    "            content=\"No query provided.\",\n",
-    "            tool_name=\"brave_search\"\n",
-    "        )]\n",
+    "        return [\n",
+    "            ToolResponseMessage(\n",
+    "                call_id=\"no_call_id\",\n",
+    "                role=\"ipython\",\n",
+    "                content=\"No query provided.\",\n",
+    "                tool_name=\"brave_search\",\n",
+    "            )\n",
+    "        ]\n",
     "\n",
     "    def _format_response_for_agent(self, search_result):\n",
     "        parsed_result = json.loads(search_result)\n",
@@ -186,7 +191,7 @@
     "            f\" URL: {result.get('url', 'No URL')}\\n\"\n",
     "            f\" Description: {result.get('description', 'No Description')}\\n\\n\"\n",
     "        )\n",
-    "    return formatted_result"
+    "    return formatted_result\n"
    ]
   },
   {
@@ -209,7 +214,7 @@
     "async def execute_search(query: str):\n",
     "    web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
     "    result = await web_search_tool.run_impl(query)\n",
-    "    print(\"Search Results:\", result)"
+    "    print(\"Search Results:\", result)\n"
    ]
   },
   {
@@ -236,7 +241,7 @@
    ],
    "source": [
     "query = \"Latest developments in quantum computing\"\n",
-    "asyncio.run(execute_search(query))"
+    "asyncio.run(execute_search(query))\n"
    ]
   },
   {
@@ -288,19 +293,17 @@
     "\n",
     "    # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
     "    webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
-    "    \n",
+    "\n",
     "    # Define the agent configuration, including the model and tool setup\n",
     "    agent_config = AgentConfig(\n",
     "        model=MODEL_NAME,\n",
     "        instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
     "        sampling_params={\n",
-    "            \"strategy\": \"greedy\",\n",
-    "            \"temperature\": 1.0,\n",
-    "            \"top_p\": 0.9,\n",
+    "            \"strategy\": {\n",
+    "                \"type\": \"greedy\",\n",
+    "            },\n",
     "        },\n",
-    "        tools=[\n",
-    "            webSearchTool.get_tool_definition()\n",
-    "        ],\n",
+    "        tools=[webSearchTool.get_tool_definition()],\n",
     "        tool_choice=\"auto\",\n",
     "        tool_prompt_format=\"python_list\",\n",
     "        input_shields=input_shields,\n",
@@ -329,8 +332,9 @@
     "    async for log in EventLogger().log(response):\n",
     "        log.print()\n",
     "\n",
+    "\n",
     "# Run the function asynchronously in a Jupyter Notebook cell\n",
-    "await run_main(disable_safety=True)"
+    "await run_main(disable_safety=True)\n"
    ]
   }
  ],

@@ -50,8 +50,8 @@
    "outputs": [],
    "source": [
     "HOST = \"localhost\" # Replace with your host\n",
     "PORT = 5001 # Replace with your port\n",
-    "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
+    "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
    ]
   },
   {
@@ -60,10 +60,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from dotenv import load_dotenv\n",
     "import os\n",
+    "\n",
+    "from dotenv import load_dotenv\n",
+    "\n",
     "load_dotenv()\n",
-    "BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']"
+    "BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
    ]
   },
   {
@@ -104,20 +106,22 @@
    ],
    "source": [
     "import os\n",
+    "\n",
     "from llama_stack_client import LlamaStackClient\n",
     "from llama_stack_client.lib.agents.agent import Agent\n",
     "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
     "from llama_stack_client.types.agent_create_params import AgentConfig\n",
     "\n",
+    "\n",
     "async def agent_example():\n",
     "    client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
     "    agent_config = AgentConfig(\n",
     "        model=MODEL_NAME,\n",
     "        instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
     "        sampling_params={\n",
-    "            \"strategy\": \"greedy\",\n",
-    "            \"temperature\": 1.0,\n",
-    "            \"top_p\": 0.9,\n",
+    "            \"strategy\": {\n",
+    "                \"type\": \"greedy\",\n",
+    "            },\n",
     "        },\n",
     "        tools=[\n",
     "            {\n",
@@ -157,7 +161,7 @@
     "        log.print()\n",
     "\n",
     "\n",
-    "await agent_example()"
+    "await agent_example()\n"
    ]
   },
   {

@@ -157,7 +157,15 @@ curl http://localhost:$LLAMA_STACK_PORT/alpha/inference/chat-completion
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Write me a 2-sentence poem about the moon"}
     ],
-    "sampling_params": {"temperature": 0.7, "seed": 42, "max_tokens": 512}
+    "sampling_params": {
+        "strategy": {
+            "type": "top_p",
+            "temperature": 0.7,
+            "top_p": 0.95
+        },
+        "seed": 42,
+        "max_tokens": 512
+    }
 }
 EOF
 ```

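For reference, `strategy` is now a tagged union keyed by `type`. A sketch of the three shapes exercised elsewhere in this commit, written as Python dicts; the `top_k` value shown is a hypothetical illustration, not taken from the commit:

```python
# Sketch of the accepted strategy payloads (greedy / top_p / top_k).
greedy_strategy = {"type": "greedy"}
top_p_strategy = {"type": "top_p", "temperature": 0.7, "top_p": 0.95}
top_k_strategy = {"type": "top_k", "top_k": 40}  # hypothetical value
```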
@@ -83,8 +83,8 @@
    },
    "outputs": [],
    "source": [
-    "LLAMA_STACK_API_TOGETHER_URL=\"https://llama-stack.together.ai\"\n",
-    "LLAMA31_8B_INSTRUCT = \"Llama3.1-8B-Instruct\""
+    "LLAMA_STACK_API_TOGETHER_URL = \"https://llama-stack.together.ai\"\n",
+    "LLAMA31_8B_INSTRUCT = \"Llama3.1-8B-Instruct\"\n"
    ]
   },
   {
@@ -107,12 +107,13 @@
     "    AgentConfigToolSearchToolDefinition,\n",
     ")\n",
     "\n",
+    "\n",
     "# Helper function to create an agent with tools\n",
     "async def create_tool_agent(\n",
     "    client: LlamaStackClient,\n",
     "    tools: List[Dict],\n",
     "    instructions: str = \"You are a helpful assistant\",\n",
-    "    model: str = LLAMA31_8B_INSTRUCT\n",
+    "    model: str = LLAMA31_8B_INSTRUCT,\n",
     ") -> Agent:\n",
     "    \"\"\"Create an agent with specified tools.\"\"\"\n",
     "    print(\"Using the following model: \", model)\n",
@@ -120,9 +121,9 @@
     "        model=model,\n",
     "        instructions=instructions,\n",
     "        sampling_params={\n",
-    "            \"strategy\": \"greedy\",\n",
-    "            \"temperature\": 1.0,\n",
-    "            \"top_p\": 0.9,\n",
+    "            \"strategy\": {\n",
+    "                \"type\": \"greedy\",\n",
+    "            },\n",
     "        },\n",
     "        tools=tools,\n",
     "        tool_choice=\"auto\",\n",
@@ -130,7 +131,7 @@
     "        enable_session_persistence=True,\n",
     "    )\n",
     "\n",
-    "    return Agent(client, agent_config)"
+    "    return Agent(client, agent_config)\n"
    ]
   },
   {
@@ -172,7 +173,8 @@
    ],
    "source": [
     "# comment this if you don't have a BRAVE_SEARCH_API_KEY\n",
-    "os.environ[\"BRAVE_SEARCH_API_KEY\"] = 'YOUR_BRAVE_SEARCH_API_KEY'\n",
+    "os.environ[\"BRAVE_SEARCH_API_KEY\"] = \"YOUR_BRAVE_SEARCH_API_KEY\"\n",
+    "\n",
     "\n",
     "async def create_search_agent(client: LlamaStackClient) -> Agent:\n",
     "    \"\"\"Create an agent with Brave Search capability.\"\"\"\n",
@@ -186,8 +188,8 @@
     "\n",
     "    return await create_tool_agent(\n",
     "        client=client,\n",
     "        tools=[search_tool], # set this to [] if you don't have a BRAVE_SEARCH_API_KEY\n",
-    "        model = LLAMA31_8B_INSTRUCT,\n",
+    "        model=LLAMA31_8B_INSTRUCT,\n",
     "        instructions=\"\"\"\n",
     "        You are a research assistant that can search the web.\n",
     "        Always cite your sources with URLs when providing information.\n",
@@ -198,9 +200,10 @@
     "\n",
     "        SOURCES:\n",
     "        - [Source title](URL)\n",
-    "        \"\"\"\n",
+    "        \"\"\",\n",
     "    )\n",
     "\n",
+    "\n",
     "# Example usage\n",
     "async def search_example():\n",
     "    client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n",
@@ -212,7 +215,7 @@
     "    # Example queries\n",
     "    queries = [\n",
     "        \"What are the latest developments in quantum computing?\",\n",
-    "        #\"Who won the most recent Super Bowl?\",\n",
+    "        # \"Who won the most recent Super Bowl?\",\n",
     "    ]\n",
     "\n",
     "    for query in queries:\n",
@@ -227,8 +230,9 @@
     "        async for log in EventLogger().log(response):\n",
     "            log.print()\n",
     "\n",
+    "\n",
     "# Run the example (in Jupyter, use asyncio.run())\n",
-    "await search_example()"
+    "await search_example()\n"
    ]
   },
   {
@@ -286,12 +290,16 @@
     }
    ],
    "source": [
-    "from typing import TypedDict, Optional, Dict, Any\n",
-    "from datetime import datetime\n",
     "import json\n",
-    "from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam\n",
-    "from llama_stack_client.types import CompletionMessage,ToolResponseMessage\n",
+    "from datetime import datetime\n",
+    "from typing import Any, Dict, Optional, TypedDict\n",
+    "\n",
     "from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
+    "from llama_stack_client.types import CompletionMessage, ToolResponseMessage\n",
+    "from llama_stack_client.types.tool_param_definition_param import (\n",
+    "    ToolParamDefinitionParam,\n",
+    ")\n",
+    "\n",
     "\n",
     "class WeatherTool(CustomTool):\n",
     "    \"\"\"Example custom tool for weather information.\"\"\"\n",
@@ -305,16 +313,15 @@
     "    def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n",
     "        return {\n",
     "            \"location\": ToolParamDefinitionParam(\n",
-    "                param_type=\"str\",\n",
-    "                description=\"City or location name\",\n",
-    "                required=True\n",
+    "                param_type=\"str\", description=\"City or location name\", required=True\n",
     "            ),\n",
     "            \"date\": ToolParamDefinitionParam(\n",
     "                param_type=\"str\",\n",
     "                description=\"Optional date (YYYY-MM-DD)\",\n",
-    "                required=False\n",
-    "            )\n",
+    "                required=False,\n",
+    "            ),\n",
     "        }\n",
+    "\n",
     "    async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:\n",
     "        assert len(messages) == 1, \"Expected single message\"\n",
     "\n",
@@ -337,20 +344,14 @@
     "        )\n",
     "        return [message]\n",
     "\n",
-    "    async def run_impl(self, location: str, date: Optional[str] = None) -> Dict[str, Any]:\n",
+    "    async def run_impl(\n",
+    "        self, location: str, date: Optional[str] = None\n",
+    "    ) -> Dict[str, Any]:\n",
     "        \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n",
     "        # Mock implementation\n",
     "        if date:\n",
-    "            return {\n",
-    "                \"temperature\": 90.1,\n",
-    "                \"conditions\": \"sunny\",\n",
-    "                \"humidity\": 40.0\n",
-    "            }\n",
-    "        return {\n",
-    "            \"temperature\": 72.5,\n",
-    "            \"conditions\": \"partly cloudy\",\n",
-    "            \"humidity\": 65.0\n",
-    "        }\n",
+    "            return {\"temperature\": 90.1, \"conditions\": \"sunny\", \"humidity\": 40.0}\n",
+    "        return {\"temperature\": 72.5, \"conditions\": \"partly cloudy\", \"humidity\": 65.0}\n",
     "\n",
     "\n",
     "async def create_weather_agent(client: LlamaStackClient) -> Agent:\n",
@@ -358,38 +359,33 @@
     "\n",
     "    # Create the agent with the tool\n",
     "    weather_tool = WeatherTool()\n",
-    "    \n",
+    "\n",
     "    agent_config = AgentConfig(\n",
     "        model=LLAMA31_8B_INSTRUCT,\n",
-    "        #model=model_name,\n",
+    "        # model=model_name,\n",
     "        instructions=\"\"\"\n",
     "        You are a weather assistant that can provide weather information.\n",
     "        Always specify the location clearly in your responses.\n",
     "        Include both temperature and conditions in your summaries.\n",
     "        \"\"\",\n",
     "        sampling_params={\n",
-    "            \"strategy\": \"greedy\",\n",
-    "            \"temperature\": 1.0,\n",
-    "            \"top_p\": 0.9,\n",
+    "            \"strategy\": {\n",
+    "                \"type\": \"greedy\",\n",
+    "            },\n",
     "        },\n",
-    "        tools=[\n",
-    "            weather_tool.get_tool_definition()\n",
-    "        ],\n",
+    "        tools=[weather_tool.get_tool_definition()],\n",
     "        tool_choice=\"auto\",\n",
     "        tool_prompt_format=\"json\",\n",
     "        input_shields=[],\n",
     "        output_shields=[],\n",
-    "        enable_session_persistence=True\n",
+    "        enable_session_persistence=True,\n",
     "    )\n",
     "\n",
-    "    agent = Agent(\n",
-    "        client=client,\n",
-    "        agent_config=agent_config,\n",
-    "        custom_tools=[weather_tool]\n",
-    "    )\n",
+    "    agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
     "\n",
     "    return agent\n",
     "\n",
+    "\n",
     "# Example usage\n",
     "async def weather_example():\n",
     "    client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n",
@@ -413,12 +409,14 @@
     "        async for log in EventLogger().log(response):\n",
     "            log.print()\n",
     "\n",
+    "\n",
     "# For Jupyter notebooks\n",
     "import nest_asyncio\n",
+    "\n",
     "nest_asyncio.apply()\n",
     "\n",
     "# Run the example\n",
-    "await weather_example()"
+    "await weather_example()\n"
    ]
   },
   {

@@ -13,7 +13,6 @@ from termcolor import colored

 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
-from llama_stack.distribution.utils.serialize import EnumEncoder


 class ModelDescribe(Subcommand):
@@ -72,7 +71,7 @@ class ModelDescribe(Subcommand):
         rows.append(
             (
                 "Recommended sampling params",
-                json.dumps(sampling_params, cls=EnumEncoder, indent=4),
+                json.dumps(sampling_params, indent=4),
             )
         )

@@ -58,11 +58,6 @@ def define_eval_candidate_2():

     # Sampling Parameters
     st.markdown("##### Sampling Parameters")
-    strategy = st.selectbox(
-        "Strategy",
-        ["greedy", "top_p", "top_k"],
-        index=0,
-    )
     temperature = st.slider(
         "Temperature",
         min_value=0.0,
@@ -95,13 +90,20 @@ def define_eval_candidate_2():
         help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
     )
     if candidate_type == "model":
+        if temperature > 0.0:
+            strategy = {
+                "type": "top_p",
+                "temperature": temperature,
+                "top_p": top_p,
+            }
+        else:
+            strategy = {"type": "greedy"}
+
         eval_candidate = {
             "type": "model",
             "model": selected_model,
             "sampling_params": {
                 "strategy": strategy,
-                "temperature": temperature,
-                "top_p": top_p,
                 "max_tokens": max_tokens,
                 "repetition_penalty": repetition_penalty,
             },

@@ -95,6 +95,15 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
         message_placeholder = st.empty()
         full_response = ""
+
+        if temperature > 0.0:
+            strategy = {
+                "type": "top_p",
+                "temperature": temperature,
+                "top_p": top_p,
+            }
+        else:
+            strategy = {"type": "greedy"}

         response = llama_stack_api.client.inference.chat_completion(
             messages=[
                 {"role": "system", "content": system_prompt},
@@ -103,8 +112,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
             model_id=selected_model,
             stream=stream,
             sampling_params={
-                "temperature": temperature,
-                "top_p": top_p,
+                "strategy": strategy,
                 "max_tokens": max_tokens,
                 "repetition_penalty": repetition_penalty,
             },

@@ -118,13 +118,20 @@ def rag_chat_page():
         with st.chat_message(message["role"]):
             st.markdown(message["content"])

+    if temperature > 0.0:
+        strategy = {
+            "type": "top_p",
+            "temperature": temperature,
+            "top_p": top_p,
+        }
+    else:
+        strategy = {"type": "greedy"}
+
     agent_config = AgentConfig(
         model=selected_model,
         instructions=system_prompt,
         sampling_params={
-            "strategy": "greedy",
-            "temperature": temperature,
-            "top_p": top_p,
+            "strategy": strategy,
         },
         tools=[
             {

@@ -23,6 +23,11 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
+from llama_models.datatypes import (
+    GreedySamplingStrategy,
+    SamplingParams,
+    TopPSamplingStrategy,
+)
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
 from llama_models.llama3.api.datatypes import Model
@@ -363,11 +368,12 @@ class Llama:
             max_gen_len = self.model.params.max_seq_len - 1

         model_input = self.formatter.encode_content(request.content)
+        temperature, top_p = _infer_sampling_params(sampling_params)
         yield from self.generate(
             model_input=model_input,
             max_gen_len=max_gen_len,
-            temperature=sampling_params.temperature,
-            top_p=sampling_params.top_p,
+            temperature=temperature,
+            top_p=top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
             logits_processor=get_logits_processor(
@@ -390,14 +396,15 @@ class Llama:
         ):
             max_gen_len = self.model.params.max_seq_len - 1

+        temperature, top_p = _infer_sampling_params(sampling_params)
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 request.messages,
                 request.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
-            temperature=sampling_params.temperature,
-            top_p=sampling_params.top_p,
+            temperature=temperature,
+            top_p=top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
             logits_processor=get_logits_processor(
@@ -492,3 +499,15 @@ def _build_regular_tokens_list(
         is_word_start_token = len(decoded_after_0) > len(decoded_regular)
         regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
     return regular_tokens
+
+
+def _infer_sampling_params(sampling_params: SamplingParams):
+    if isinstance(sampling_params.strategy, GreedySamplingStrategy):
+        temperature = 0.0
+        top_p = 1.0
+    elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
+        temperature = sampling_params.strategy.temperature
+        top_p = sampling_params.strategy.top_p
+    else:
+        raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
+    return temperature, top_p

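A usage sketch, assuming the `_infer_sampling_params` helper added above is in scope; it collapses the strategy union back into the `(temperature, top_p)` pair the generator still expects (values are illustrative):

```python
# Sketch: resolving generation knobs from the new strategy union.
from llama_models.datatypes import SamplingParams, TopPSamplingStrategy

params = SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95))
temperature, top_p = _infer_sampling_params(params)  # -> (0.7, 0.95)
```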
@@ -36,6 +36,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
     process_chat_completion_response,
@@ -126,21 +127,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if sampling_params is None:
             return VLLMSamplingParams(max_tokens=self.config.max_tokens)

-        # TODO convert what I saw in my first test ... but surely there's more to do here
-        kwargs = {
-            "temperature": sampling_params.temperature,
-            "max_tokens": self.config.max_tokens,
-        }
-        if sampling_params.top_k:
-            kwargs["top_k"] = sampling_params.top_k
-        if sampling_params.top_p:
-            kwargs["top_p"] = sampling_params.top_p
-        if sampling_params.max_tokens:
-            kwargs["max_tokens"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            kwargs["repetition_penalty"] = sampling_params.repetition_penalty
+        options = get_sampling_options(sampling_params)
+        if "repeat_penalty" in options:
+            options["repetition_penalty"] = options["repeat_penalty"]
+            del options["repeat_penalty"]

-        return VLLMSamplingParams(**kwargs)
+        return VLLMSamplingParams(**options)

     async def unregister_model(self, model_id: str) -> None:
         pass

@@ -34,6 +34,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_strategy_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
     process_chat_completion_response,
@@ -166,16 +167,13 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> Dict:
         bedrock_model = request.model

-        inference_config = {}
-        param_mapping = {
-            "max_tokens": "max_gen_len",
-            "temperature": "temperature",
-            "top_p": "top_p",
-        }
+        sampling_params = request.sampling_params
+        options = get_sampling_strategy_options(sampling_params)

-        for k, v in param_mapping.items():
-            if getattr(request.sampling_params, k):
-                inference_config[v] = getattr(request.sampling_params, k)
+        if sampling_params.max_tokens:
+            options["max_gen_len"] = sampling_params.max_tokens
+        if sampling_params.repetition_penalty > 0:
+            options["repetition_penalty"] = sampling_params.repetition_penalty

         prompt = await chat_completion_request_to_prompt(
             request, self.get_llama_model(request.model), self.formatter
@@ -185,7 +183,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             "body": json.dumps(
                 {
                     "prompt": prompt,
-                    **inference_config,
+                    **options,
                 }
             ),
         }

@@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from cerebras.cloud.sdk import AsyncCerebras
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import TopKSamplingStrategy
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -172,7 +173,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
-        if request.sampling_params and request.sampling_params.top_k:
+        if request.sampling_params and isinstance(
+            request.sampling_params.strategy, TopKSamplingStrategy
+        ):
             raise ValueError("`top_k` not supported by Cerebras")

         prompt = ""

@@ -48,6 +48,9 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_strategy_options,
+)


 def convert_chat_completion_request(
@@ -77,6 +80,7 @@ def convert_chat_completion_request(
     if request.tool_prompt_format != ToolPromptFormat.json:
         warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")

+    sampling_options = get_sampling_strategy_options(request.sampling_params)
     return CompletionCreateParams(
         model=request.model,
         messages=[_convert_message(message) for message in request.messages],
@@ -84,8 +88,8 @@
         frequency_penalty=None,
         stream=request.stream,
         max_tokens=request.sampling_params.max_tokens or None,
-        temperature=request.sampling_params.temperature,
-        top_p=request.sampling_params.top_p,
+        temperature=sampling_options.get("temperature", 1.0),
+        top_p=sampling_options.get("top_p", 1.0),
         tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
         tool_choice=request.tool_choice.value if request.tool_choice else None,
     )

@@ -263,19 +263,18 @@ def convert_chat_completion_request(
     if request.sampling_params.max_tokens:
         payload.update(max_tokens=request.sampling_params.max_tokens)

-    if request.sampling_params.strategy == "top_p":
+    strategy = request.sampling_params.strategy
+    if isinstance(strategy, TopPSamplingStrategy):
         nvext.update(top_k=-1)
-        payload.update(top_p=request.sampling_params.top_p)
-    elif request.sampling_params.strategy == "top_k":
-        if (
-            request.sampling_params.top_k != -1
-            and request.sampling_params.top_k < 1
-        ):
+        payload.update(top_p=strategy.top_p)
+        payload.update(temperature=strategy.temperature)
+    elif isinstance(strategy, TopKSamplingStrategy):
+        if strategy.top_k != -1 and strategy.top_k < 1:
             warnings.warn("top_k must be -1 or >= 1")
-        nvext.update(top_k=request.sampling_params.top_k)
-    elif request.sampling_params.strategy == "greedy":
+        nvext.update(top_k=strategy.top_k)
+    elif strategy.type == "greedy":
         nvext.update(top_k=-1)
-        payload.update(temperature=request.sampling_params.temperature)
+        payload.update(temperature=0.0)

     return payload

@@ -22,7 +22,12 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
+from llama_stack.apis.inference import (
+    CompletionMessage,
+    SamplingParams,
+    TopPSamplingStrategy,
+    UserMessage,
+)
 from llama_stack.apis.safety import ViolationLevel
 from llama_stack.providers.datatypes import Api

@@ -42,7 +47,9 @@ def common_params(inference_model):
         model=inference_model,
         instructions="You are a helpful assistant.",
         enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        sampling_params=SamplingParams(
+            strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95)
+        ),
         input_shields=[],
         output_shields=[],
         toolgroups=[],

@@ -21,6 +21,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
     Function,
 )
 from groq.types.shared.function_definition import FunctionDefinition
+from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
 from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -152,21 +153,30 @@ class TestConvertChatCompletionRequest:

         assert converted["max_tokens"] == 100

-    def test_includes_temperature(self):
+    def _dummy_chat_completion_request(self):
+        return ChatCompletionRequest(
+            model="Llama-3.2-3B",
+            messages=[UserMessage(content="Hello World")],
+        )
+
+    def test_includes_strategy(self):
         request = self._dummy_chat_completion_request()
-        request.sampling_params.temperature = 0.5
+        request.sampling_params.strategy = TopPSamplingStrategy(
+            temperature=0.5, top_p=0.95
+        )

         converted = convert_chat_completion_request(request)

         assert converted["temperature"] == 0.5
+        assert converted["top_p"] == 0.95

-    def test_includes_top_p(self):
+    def test_includes_greedy_strategy(self):
         request = self._dummy_chat_completion_request()
-        request.sampling_params.top_p = 0.95
+        request.sampling_params.strategy = GreedySamplingStrategy()

         converted = convert_chat_completion_request(request)

-        assert converted["top_p"] == 0.95
+        assert converted["temperature"] == 0.0

     def test_includes_tool_choice(self):
         request = self._dummy_chat_completion_request()

@@ -8,7 +8,13 @@ from typing import AsyncGenerator, List, Optional

 from llama_models.llama3.api.chat_format import ChatFormat

-from llama_models.llama3.api.datatypes import SamplingParams, StopReason
+from llama_models.llama3.api.datatypes import (
+    GreedySamplingStrategy,
+    SamplingParams,
+    StopReason,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from pydantic import BaseModel

 from llama_stack.apis.common.content_types import (
@@ -49,12 +55,26 @@ class OpenAICompatCompletionResponse(BaseModel):
     choices: List[OpenAICompatCompletionChoice]


+def get_sampling_strategy_options(params: SamplingParams) -> dict:
+    options = {}
+    if isinstance(params.strategy, GreedySamplingStrategy):
+        options["temperature"] = 0.0
+    elif isinstance(params.strategy, TopPSamplingStrategy):
+        options["temperature"] = params.strategy.temperature
+        options["top_p"] = params.strategy.top_p
+    elif isinstance(params.strategy, TopKSamplingStrategy):
+        options["top_k"] = params.strategy.top_k
+    else:
+        raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
+
+    return options
+
+
 def get_sampling_options(params: SamplingParams) -> dict:
     options = {}
     if params:
-        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
-            if getattr(params, attr):
-                options[attr] = getattr(params, attr)
+        options.update(get_sampling_strategy_options(params))
+        options["max_tokens"] = params.max_tokens

         if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
             options["repeat_penalty"] = params.repetition_penalty

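A sketch of what the new helpers return for a greedy strategy, inferred from the code above; the exact defaults on `SamplingParams` are an assumption:

```python
# Sketch: expected output of the new OpenAI-compat sampling helpers.
from llama_models.llama3.api.datatypes import GreedySamplingStrategy, SamplingParams

params = SamplingParams(strategy=GreedySamplingStrategy(), max_tokens=512)
get_sampling_strategy_options(params)  # {"temperature": 0.0}
get_sampling_options(params)           # {"temperature": 0.0, "max_tokens": 512}
```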
@@ -97,9 +97,11 @@ def agent_config(llama_stack_client):
         model=model_id,
         instructions="You are a helpful assistant",
         sampling_params={
-            "strategy": "greedy",
-            "temperature": 1.0,
-            "top_p": 0.9,
+            "strategy": {
+                "type": "greedy",
+                "temperature": 1.0,
+                "top_p": 0.9,
+            },
         },
         toolgroups=[],
         tool_choice="auto",