Convert SamplingParams.strategy to a union (#767)

# What does this PR do?

Cleans up how we provide sampling params. Earlier, strategy was an enum
and all params (top_p, temperature, top_k) across all strategies were
grouped. We now have a strategy union object with each strategy (greedy,
top_p, top_k) having its corresponding params.
Earlier, 
```
class SamplingParams: 
    strategy: enum ()
    top_p, temperature, top_k and other params
```
However, the `strategy` field was not being used in any providers making
it confusing to know the exact sampling behavior purely based on the
params since you could pass temperature, top_p, top_k and how the
provider would interpret those would not be clear.

Hence we introduced -- a union where the strategy and relevant params
are all clubbed together to avoid this confusion.

Have updated all providers, tests, notebooks, readme and otehr places
where sampling params was being used to use the new format.
   

## Test Plan
`pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py`
// inference on ollama, fireworks and together 
`with-proxy pytest -v -s -k "ollama"
--inference-model="meta-llama/Llama-3.1-8B-Instruct"
llama_stack/providers/tests/inference/test_text_inference.py `
// agents on fireworks 
`pytest -v -s -k 'fireworks and create_agent'
--inference-model="meta-llama/Llama-3.1-8B-Instruct"
llama_stack/providers/tests/agents/test_agents.py
--safety-shield="meta-llama/Llama-Guard-3-8B"`

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Ran pre-commit to handle lint / formatting issues.
- [X] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [X] Updated relevant documentation.
- [X] Wrote necessary unit or integration tests.

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
Hardik Shah 2025-01-15 05:38:51 -08:00 committed by GitHub
parent 300e6e2702
commit a51c8b4efc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 611 additions and 388 deletions

View file

@ -713,13 +713,15 @@
], ],
"source": [ "source": [
"import os\n", "import os\n",
"\n",
"from google.colab import userdata\n", "from google.colab import userdata\n",
"\n", "\n",
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n", "os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
"\n", "\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"\n",
"client = LlamaStackAsLibraryClient(\"together\")\n", "client = LlamaStackAsLibraryClient(\"together\")\n",
"_ = client.initialize()" "_ = client.initialize()\n"
] ]
}, },
{ {
@ -769,6 +771,7 @@
], ],
"source": [ "source": [
"from rich.pretty import pprint\n", "from rich.pretty import pprint\n",
"\n",
"print(\"Available models:\")\n", "print(\"Available models:\")\n",
"for m in client.models.list():\n", "for m in client.models.list():\n",
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n", " print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
@ -777,7 +780,7 @@
"print(\"Available shields (safety models):\")\n", "print(\"Available shields (safety models):\")\n",
"for s in client.shields.list():\n", "for s in client.shields.list():\n",
" print(s.identifier)\n", " print(s.identifier)\n",
"print(\"----\")" "print(\"----\")\n"
] ]
}, },
{ {
@ -822,7 +825,7 @@
"source": [ "source": [
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n", "model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
"\n", "\n",
"model_id" "model_id\n"
] ]
}, },
{ {
@ -863,11 +866,11 @@
" model_id=model_id,\n", " model_id=model_id,\n",
" messages=[\n", " messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n", " {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n", " {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
" ],\n", " ],\n",
")\n", ")\n",
"\n", "\n",
"print(response.completion_message.content)" "print(response.completion_message.content)\n"
] ]
}, },
{ {
@ -900,12 +903,13 @@
"source": [ "source": [
"from termcolor import cprint\n", "from termcolor import cprint\n",
"\n", "\n",
"\n",
"def chat_loop():\n", "def chat_loop():\n",
" conversation_history = []\n", " conversation_history = []\n",
" while True:\n", " while True:\n",
" user_input = input('User> ')\n", " user_input = input(\"User> \")\n",
" if user_input.lower() in ['exit', 'quit', 'bye']:\n", " if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
" cprint('Ending conversation. Goodbye!', 'yellow')\n", " cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
" break\n", " break\n",
"\n", "\n",
" user_message = {\"role\": \"user\", \"content\": user_input}\n", " user_message = {\"role\": \"user\", \"content\": user_input}\n",
@ -915,14 +919,15 @@
" messages=conversation_history,\n", " messages=conversation_history,\n",
" model_id=model_id,\n", " model_id=model_id,\n",
" )\n", " )\n",
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n", " cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
"\n", "\n",
" assistant_message = {\n", " assistant_message = {\n",
" \"role\": \"assistant\", # was user\n", " \"role\": \"assistant\", # was user\n",
" \"content\": response.completion_message.content,\n", " \"content\": response.completion_message.content,\n",
" }\n", " }\n",
" conversation_history.append(assistant_message)\n", " conversation_history.append(assistant_message)\n",
"\n", "\n",
"\n",
"chat_loop()\n" "chat_loop()\n"
] ]
}, },
@ -978,21 +983,18 @@
"source": [ "source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n", "from llama_stack_client.lib.inference.event_logger import EventLogger\n",
"\n", "\n",
"message = {\n", "message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
" \"role\": \"user\",\n", "print(f'User> {message[\"content\"]}', \"green\")\n",
" \"content\": 'Write me a sonnet about llama'\n",
"}\n",
"print(f'User> {message[\"content\"]}', 'green')\n",
"\n", "\n",
"response = client.inference.chat_completion(\n", "response = client.inference.chat_completion(\n",
" messages=[message],\n", " messages=[message],\n",
" model_id=model_id,\n", " model_id=model_id,\n",
" stream=True, # <-----------\n", " stream=True, # <-----------\n",
")\n", ")\n",
"\n", "\n",
"# Print the tokens while they are received\n", "# Print the tokens while they are received\n",
"for log in EventLogger().log(response):\n", "for log in EventLogger().log(response):\n",
" log.print()" " log.print()\n"
] ]
}, },
{ {
@ -1045,26 +1047,26 @@
"source": [ "source": [
"from pydantic import BaseModel\n", "from pydantic import BaseModel\n",
"\n", "\n",
"\n",
"class Output(BaseModel):\n", "class Output(BaseModel):\n",
" name: str\n", " name: str\n",
" year_born: str\n", " year_born: str\n",
" year_retired: str\n", " year_retired: str\n",
"\n", "\n",
"\n",
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n", "user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
"response = client.inference.completion(\n", "response = client.inference.completion(\n",
" model_id=model_id,\n", " model_id=model_id,\n",
" content=user_input,\n", " content=user_input,\n",
" stream=False,\n", " stream=False,\n",
" sampling_params={\n", " sampling_params={\"strategy\": {\"type\": \"greedy\"}, \"max_tokens\": 50},\n",
" \"max_tokens\": 50,\n",
" },\n",
" response_format={\n", " response_format={\n",
" \"type\": \"json_schema\",\n", " \"type\": \"json_schema\",\n",
" \"json_schema\": Output.model_json_schema(),\n", " \"json_schema\": Output.model_json_schema(),\n",
" },\n", " },\n",
")\n", ")\n",
"\n", "\n",
"pprint(response)" "pprint(response)\n"
] ]
}, },
{ {
@ -1220,7 +1222,7 @@
" shield_id=available_shields[0],\n", " shield_id=available_shields[0],\n",
" params={},\n", " params={},\n",
" )\n", " )\n",
" pprint(response)" " pprint(response)\n"
] ]
}, },
{ {
@ -1489,8 +1491,8 @@
"source": [ "source": [
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.types import Attachment\n", "from llama_stack_client.types import Attachment\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n", "from termcolor import cprint\n",
"\n", "\n",
"urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n", "urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
@ -1522,14 +1524,14 @@
" ),\n", " ),\n",
"]\n", "]\n",
"for prompt, attachments in user_prompts:\n", "for prompt, attachments in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n", " cprint(f\"User> {prompt}\", \"green\")\n",
" response = rag_agent.create_turn(\n", " response = rag_agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": prompt}],\n", " messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" attachments=attachments,\n", " attachments=attachments,\n",
" session_id=session_id,\n", " session_id=session_id,\n",
" )\n", " )\n",
" for log in EventLogger().log(response):\n", " for log in EventLogger().log(response):\n",
" log.print()" " log.print()\n"
] ]
}, },
{ {
@ -1560,8 +1562,8 @@
"search_tool = {\n", "search_tool = {\n",
" \"type\": \"brave_search\",\n", " \"type\": \"brave_search\",\n",
" \"engine\": \"tavily\",\n", " \"engine\": \"tavily\",\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n", " \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
"}" "}\n"
] ]
}, },
{ {
@ -1608,7 +1610,7 @@
"\n", "\n",
"session_id = agent.create_session(\"test-session\")\n", "session_id = agent.create_session(\"test-session\")\n",
"for prompt in user_prompts:\n", "for prompt in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n", " cprint(f\"User> {prompt}\", \"green\")\n",
" response = agent.create_turn(\n", " response = agent.create_turn(\n",
" messages=[\n", " messages=[\n",
" {\n", " {\n",
@ -1758,7 +1760,7 @@
" search_tool,\n", " search_tool,\n",
" {\n", " {\n",
" \"type\": \"code_interpreter\",\n", " \"type\": \"code_interpreter\",\n",
" }\n", " },\n",
" ],\n", " ],\n",
" tool_choice=\"required\",\n", " tool_choice=\"required\",\n",
" input_shields=[],\n", " input_shields=[],\n",
@ -1788,7 +1790,7 @@
"]\n", "]\n",
"\n", "\n",
"for prompt in user_prompts:\n", "for prompt in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n", " cprint(f\"User> {prompt}\", \"green\")\n",
" response = codex_agent.create_turn(\n", " response = codex_agent.create_turn(\n",
" messages=[\n", " messages=[\n",
" {\n", " {\n",
@ -1841,27 +1843,57 @@
} }
], ],
"source": [ "source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n", "\n",
"# Read the CSV file\n", "# Read the CSV file\n",
"df = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\n", "df = pd.read_csv(\"/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv\")\n",
"\n", "\n",
"# Extract the year and inflation rate from the CSV file\n", "# Extract the year and inflation rate from the CSV file\n",
"df['Year'] = pd.to_datetime(df['Year'], format='%Y')\n", "df[\"Year\"] = pd.to_datetime(df[\"Year\"], format=\"%Y\")\n",
"df = df.rename(columns={'Jan': 'Jan Rate', 'Feb': 'Feb Rate', 'Mar': 'Mar Rate', 'Apr': 'Apr Rate', 'May': 'May Rate', 'Jun': 'Jun Rate', 'Jul': 'Jul Rate', 'Aug': 'Aug Rate', 'Sep': 'Sep Rate', 'Oct': 'Oct Rate', 'Nov': 'Nov Rate', 'Dec': 'Dec Rate'})\n", "df = df.rename(\n",
" columns={\n",
" \"Jan\": \"Jan Rate\",\n",
" \"Feb\": \"Feb Rate\",\n",
" \"Mar\": \"Mar Rate\",\n",
" \"Apr\": \"Apr Rate\",\n",
" \"May\": \"May Rate\",\n",
" \"Jun\": \"Jun Rate\",\n",
" \"Jul\": \"Jul Rate\",\n",
" \"Aug\": \"Aug Rate\",\n",
" \"Sep\": \"Sep Rate\",\n",
" \"Oct\": \"Oct Rate\",\n",
" \"Nov\": \"Nov Rate\",\n",
" \"Dec\": \"Dec Rate\",\n",
" }\n",
")\n",
"\n", "\n",
"# Calculate the average yearly inflation rate\n", "# Calculate the average yearly inflation rate\n",
"df['Yearly Inflation'] = df[['Jan Rate', 'Feb Rate', 'Mar Rate', 'Apr Rate', 'May Rate', 'Jun Rate', 'Jul Rate', 'Aug Rate', 'Sep Rate', 'Oct Rate', 'Nov Rate', 'Dec Rate']].mean(axis=1)\n", "df[\"Yearly Inflation\"] = df[\n",
" [\n",
" \"Jan Rate\",\n",
" \"Feb Rate\",\n",
" \"Mar Rate\",\n",
" \"Apr Rate\",\n",
" \"May Rate\",\n",
" \"Jun Rate\",\n",
" \"Jul Rate\",\n",
" \"Aug Rate\",\n",
" \"Sep Rate\",\n",
" \"Oct Rate\",\n",
" \"Nov Rate\",\n",
" \"Dec Rate\",\n",
" ]\n",
"].mean(axis=1)\n",
"\n", "\n",
"# Plot the average yearly inflation rate as a time series\n", "# Plot the average yearly inflation rate as a time series\n",
"plt.figure(figsize=(10, 6))\n", "plt.figure(figsize=(10, 6))\n",
"plt.plot(df['Year'], df['Yearly Inflation'], marker='o')\n", "plt.plot(df[\"Year\"], df[\"Yearly Inflation\"], marker=\"o\")\n",
"plt.title('Average Yearly Inflation Rate')\n", "plt.title(\"Average Yearly Inflation Rate\")\n",
"plt.xlabel('Year')\n", "plt.xlabel(\"Year\")\n",
"plt.ylabel('Inflation Rate (%)')\n", "plt.ylabel(\"Inflation Rate (%)\")\n",
"plt.grid(True)\n", "plt.grid(True)\n",
"plt.show()" "plt.show()\n"
] ]
}, },
{ {
@ -2035,6 +2067,8 @@
"source": [ "source": [
"# disable logging for clean server logs\n", "# disable logging for clean server logs\n",
"import logging\n", "import logging\n",
"\n",
"\n",
"def remove_root_handlers():\n", "def remove_root_handlers():\n",
" root_logger = logging.getLogger()\n", " root_logger = logging.getLogger()\n",
" for handler in root_logger.handlers[:]:\n", " for handler in root_logger.handlers[:]:\n",
@ -2042,7 +2076,7 @@
" print(f\"Removed handler {handler.__class__.__name__} from root logger\")\n", " print(f\"Removed handler {handler.__class__.__name__} from root logger\")\n",
"\n", "\n",
"\n", "\n",
"remove_root_handlers()" "remove_root_handlers()\n"
] ]
}, },
{ {
@ -2083,10 +2117,10 @@
} }
], ],
"source": [ "source": [
"from google.colab import userdata\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from google.colab import userdata\n",
"\n", "\n",
"agent_config = AgentConfig(\n", "agent_config = AgentConfig(\n",
" model=\"meta-llama/Llama-3.1-405B-Instruct\",\n", " model=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
@ -2096,7 +2130,7 @@
" {\n", " {\n",
" \"type\": \"brave_search\",\n", " \"type\": \"brave_search\",\n",
" \"engine\": \"tavily\",\n", " \"engine\": \"tavily\",\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n", " \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
" }\n", " }\n",
" ]\n", " ]\n",
" ),\n", " ),\n",
@ -2125,7 +2159,7 @@
" )\n", " )\n",
"\n", "\n",
" for log in EventLogger().log(response):\n", " for log in EventLogger().log(response):\n",
" log.print()" " log.print()\n"
] ]
}, },
{ {
@ -2265,20 +2299,21 @@
"source": [ "source": [
"print(f\"Getting traces for session_id={session_id}\")\n", "print(f\"Getting traces for session_id={session_id}\")\n",
"import json\n", "import json\n",
"\n",
"from rich.pretty import pprint\n", "from rich.pretty import pprint\n",
"\n", "\n",
"agent_logs = []\n", "agent_logs = []\n",
"\n", "\n",
"for span in client.telemetry.query_spans(\n", "for span in client.telemetry.query_spans(\n",
" attribute_filters=[\n", " attribute_filters=[\n",
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n", " {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
" ],\n", " ],\n",
" attributes_to_return=[\"input\", \"output\"]\n", " attributes_to_return=[\"input\", \"output\"],\n",
" ):\n", "):\n",
" if span.attributes[\"output\"] != \"no shields\":\n", " if span.attributes[\"output\"] != \"no shields\":\n",
" agent_logs.append(span.attributes)\n", " agent_logs.append(span.attributes)\n",
"\n", "\n",
"pprint(agent_logs)" "pprint(agent_logs)\n"
] ]
}, },
{ {
@ -2389,23 +2424,25 @@
"eval_rows = []\n", "eval_rows = []\n",
"\n", "\n",
"for log in agent_logs:\n", "for log in agent_logs:\n",
" last_msg = log['input'][-1]\n", " last_msg = log[\"input\"][-1]\n",
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n", " if '\"role\":\"user\"' in last_msg:\n",
" eval_rows.append(\n", " eval_rows.append(\n",
" {\n", " {\n",
" \"input_query\": last_msg,\n", " \"input_query\": last_msg,\n",
" \"generated_answer\": log[\"output\"],\n", " \"generated_answer\": log[\"output\"],\n",
" # check if generated_answer uses tools brave_search\n", " # check if generated_answer uses tools brave_search\n",
" \"expected_answer\": \"brave_search\",\n", " \"expected_answer\": \"brave_search\",\n",
" },\n", " },\n",
" )\n", " )\n",
"\n", "\n",
"pprint(eval_rows)\n", "pprint(eval_rows)\n",
"scoring_params = {\n", "scoring_params = {\n",
" \"basic::subset_of\": None,\n", " \"basic::subset_of\": None,\n",
"}\n", "}\n",
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n", "scoring_response = client.scoring.score(\n",
"pprint(scoring_response)" " input_rows=eval_rows, scoring_functions=scoring_params\n",
")\n",
"pprint(scoring_response)\n"
] ]
}, },
{ {
@ -2506,7 +2543,9 @@
"EXPECTED_RESPONSE: {expected_answer}\n", "EXPECTED_RESPONSE: {expected_answer}\n",
"\"\"\"\n", "\"\"\"\n",
"\n", "\n",
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n", "input_query = (\n",
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
")\n",
"generated_answer = \"\"\"\n", "generated_answer = \"\"\"\n",
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n", "Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
"\n", "\n",
@ -2537,7 +2576,7 @@
"}\n", "}\n",
"\n", "\n",
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n", "response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
"pprint(response)" "pprint(response)\n"
] ]
}, },
{ {

View file

@ -618,11 +618,13 @@
], ],
"source": [ "source": [
"import os\n", "import os\n",
"\n",
"from google.colab import userdata\n", "from google.colab import userdata\n",
"\n", "\n",
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n", "os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
"\n", "\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"\n",
"client = LlamaStackAsLibraryClient(\"together\")\n", "client = LlamaStackAsLibraryClient(\"together\")\n",
"_ = client.initialize()\n", "_ = client.initialize()\n",
"\n", "\n",
@ -631,7 +633,7 @@
" model_id=\"meta-llama/Llama-3.1-405B-Instruct\",\n", " model_id=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
" provider_model_id=\"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n", " provider_model_id=\"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n",
" provider_id=\"together\",\n", " provider_id=\"together\",\n",
")" ")\n"
] ]
}, },
{ {
@ -668,7 +670,7 @@
"source": [ "source": [
"name = \"llamastack/mmmu\"\n", "name = \"llamastack/mmmu\"\n",
"subset = \"Agriculture\"\n", "subset = \"Agriculture\"\n",
"split = \"dev\"" "split = \"dev\"\n"
] ]
}, },
{ {
@ -914,9 +916,10 @@
], ],
"source": [ "source": [
"import datasets\n", "import datasets\n",
"\n",
"ds = datasets.load_dataset(path=name, name=subset, split=split)\n", "ds = datasets.load_dataset(path=name, name=subset, split=split)\n",
"ds = ds.select_columns([\"chat_completion_input\", \"input_query\", \"expected_answer\"])\n", "ds = ds.select_columns([\"chat_completion_input\", \"input_query\", \"expected_answer\"])\n",
"eval_rows = ds.to_pandas().to_dict(orient=\"records\")" "eval_rows = ds.to_pandas().to_dict(orient=\"records\")\n"
] ]
}, },
{ {
@ -1014,8 +1017,8 @@
} }
], ],
"source": [ "source": [
"from tqdm import tqdm\n",
"from rich.pretty import pprint\n", "from rich.pretty import pprint\n",
"from tqdm import tqdm\n",
"\n", "\n",
"SYSTEM_PROMPT_TEMPLATE = \"\"\"\n", "SYSTEM_PROMPT_TEMPLATE = \"\"\"\n",
"You are an expert in {subject} whose job is to answer questions from the user using images.\n", "You are an expert in {subject} whose job is to answer questions from the user using images.\n",
@ -1039,7 +1042,7 @@
"client.eval_tasks.register(\n", "client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::mmmu\",\n", " eval_task_id=\"meta-reference::mmmu\",\n",
" dataset_id=f\"mmmu-{subset}-{split}\",\n", " dataset_id=f\"mmmu-{subset}-{split}\",\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"]\n", " scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
")\n", ")\n",
"\n", "\n",
"response = client.eval.evaluate_rows(\n", "response = client.eval.evaluate_rows(\n",
@ -1052,16 +1055,17 @@
" \"type\": \"model\",\n", " \"type\": \"model\",\n",
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n", " \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
" \"temperature\": 0.0,\n", " \"strategy\": {\n",
" \"type\": \"greedy\",\n",
" },\n",
" \"max_tokens\": 4096,\n", " \"max_tokens\": 4096,\n",
" \"top_p\": 0.9,\n",
" \"repeat_penalty\": 1.0,\n", " \"repeat_penalty\": 1.0,\n",
" },\n", " },\n",
" \"system_message\": system_message\n", " \"system_message\": system_message,\n",
" }\n", " },\n",
" }\n", " },\n",
")\n", ")\n",
"pprint(response)" "pprint(response)\n"
] ]
}, },
{ {
@ -1098,8 +1102,8 @@
" \"input_query\": {\"type\": \"string\"},\n", " \"input_query\": {\"type\": \"string\"},\n",
" \"expected_answer\": {\"type\": \"string\"},\n", " \"expected_answer\": {\"type\": \"string\"},\n",
" \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n", " \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n",
" }\n", " },\n",
")" ")\n"
] ]
}, },
{ {
@ -1113,7 +1117,7 @@
"eval_rows = client.datasetio.get_rows_paginated(\n", "eval_rows = client.datasetio.get_rows_paginated(\n",
" dataset_id=simpleqa_dataset_id,\n", " dataset_id=simpleqa_dataset_id,\n",
" rows_in_page=5,\n", " rows_in_page=5,\n",
")" ")\n"
] ]
}, },
{ {
@ -1209,7 +1213,7 @@
"client.eval_tasks.register(\n", "client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::simpleqa\",\n", " eval_task_id=\"meta-reference::simpleqa\",\n",
" dataset_id=simpleqa_dataset_id,\n", " dataset_id=simpleqa_dataset_id,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"]\n", " scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
")\n", ")\n",
"\n", "\n",
"response = client.eval.evaluate_rows(\n", "response = client.eval.evaluate_rows(\n",
@ -1222,15 +1226,16 @@
" \"type\": \"model\",\n", " \"type\": \"model\",\n",
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n", " \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
" \"temperature\": 0.0,\n", " \"strategy\": {\n",
" \"type\": \"greedy\",\n",
" },\n",
" \"max_tokens\": 4096,\n", " \"max_tokens\": 4096,\n",
" \"top_p\": 0.9,\n",
" \"repeat_penalty\": 1.0,\n", " \"repeat_penalty\": 1.0,\n",
" },\n", " },\n",
" }\n", " },\n",
" }\n", " },\n",
")\n", ")\n",
"pprint(response)" "pprint(response)\n"
] ]
}, },
{ {
@ -1347,23 +1352,19 @@
"agent_config = {\n", "agent_config = {\n",
" \"model\": \"meta-llama/Llama-3.1-405B-Instruct\",\n", " \"model\": \"meta-llama/Llama-3.1-405B-Instruct\",\n",
" \"instructions\": \"You are a helpful assistant\",\n", " \"instructions\": \"You are a helpful assistant\",\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\"strategy\": {\"type\": \"greedy\"}},\n",
" \"strategy\": \"greedy\",\n",
" \"temperature\": 0.0,\n",
" \"top_p\": 0.95,\n",
" },\n",
" \"tools\": [\n", " \"tools\": [\n",
" {\n", " {\n",
" \"type\": \"brave_search\",\n", " \"type\": \"brave_search\",\n",
" \"engine\": \"tavily\",\n", " \"engine\": \"tavily\",\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n", " \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
" }\n", " }\n",
" ],\n", " ],\n",
" \"tool_choice\": \"auto\",\n", " \"tool_choice\": \"auto\",\n",
" \"tool_prompt_format\": \"json\",\n", " \"tool_prompt_format\": \"json\",\n",
" \"input_shields\": [],\n", " \"input_shields\": [],\n",
" \"output_shields\": [],\n", " \"output_shields\": [],\n",
" \"enable_session_persistence\": False\n", " \"enable_session_persistence\": False,\n",
"}\n", "}\n",
"\n", "\n",
"response = client.eval.evaluate_rows(\n", "response = client.eval.evaluate_rows(\n",
@ -1375,10 +1376,10 @@
" \"eval_candidate\": {\n", " \"eval_candidate\": {\n",
" \"type\": \"agent\",\n", " \"type\": \"agent\",\n",
" \"config\": agent_config,\n", " \"config\": agent_config,\n",
" }\n", " },\n",
" }\n", " },\n",
")\n", ")\n",
"pprint(response)" "pprint(response)\n"
] ]
} }
], ],

View file

@ -1336,6 +1336,7 @@
], ],
"source": [ "source": [
"from rich.pretty import pprint\n", "from rich.pretty import pprint\n",
"\n",
"print(\"Available models:\")\n", "print(\"Available models:\")\n",
"for m in client.models.list():\n", "for m in client.models.list():\n",
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n", " print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
@ -1344,7 +1345,7 @@
"print(\"Available shields (safety models):\")\n", "print(\"Available shields (safety models):\")\n",
"for s in client.shields.list():\n", "for s in client.shields.list():\n",
" print(s.identifier)\n", " print(s.identifier)\n",
"print(\"----\")" "print(\"----\")\n"
] ]
}, },
{ {
@ -1389,7 +1390,7 @@
"source": [ "source": [
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n", "model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
"\n", "\n",
"model_id" "model_id\n"
] ]
}, },
{ {
@ -1432,11 +1433,11 @@
" model_id=model_id,\n", " model_id=model_id,\n",
" messages=[\n", " messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n", " {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n", " {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
" ],\n", " ],\n",
")\n", ")\n",
"\n", "\n",
"print(response.completion_message.content)" "print(response.completion_message.content)\n"
] ]
}, },
{ {
@ -1489,12 +1490,13 @@
"source": [ "source": [
"from termcolor import cprint\n", "from termcolor import cprint\n",
"\n", "\n",
"\n",
"def chat_loop():\n", "def chat_loop():\n",
" conversation_history = []\n", " conversation_history = []\n",
" while True:\n", " while True:\n",
" user_input = input('User> ')\n", " user_input = input(\"User> \")\n",
" if user_input.lower() in ['exit', 'quit', 'bye']:\n", " if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
" cprint('Ending conversation. Goodbye!', 'yellow')\n", " cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
" break\n", " break\n",
"\n", "\n",
" user_message = {\"role\": \"user\", \"content\": user_input}\n", " user_message = {\"role\": \"user\", \"content\": user_input}\n",
@ -1504,15 +1506,16 @@
" messages=conversation_history,\n", " messages=conversation_history,\n",
" model_id=model_id,\n", " model_id=model_id,\n",
" )\n", " )\n",
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n", " cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
"\n", "\n",
" assistant_message = {\n", " assistant_message = {\n",
" \"role\": \"assistant\", # was user\n", " \"role\": \"assistant\", # was user\n",
" \"content\": response.completion_message.content,\n", " \"content\": response.completion_message.content,\n",
" \"stop_reason\": response.completion_message.stop_reason,\n", " \"stop_reason\": response.completion_message.stop_reason,\n",
" }\n", " }\n",
" conversation_history.append(assistant_message)\n", " conversation_history.append(assistant_message)\n",
"\n", "\n",
"\n",
"chat_loop()\n" "chat_loop()\n"
] ]
}, },
@ -1568,21 +1571,18 @@
"source": [ "source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n", "from llama_stack_client.lib.inference.event_logger import EventLogger\n",
"\n", "\n",
"message = {\n", "message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
" \"role\": \"user\",\n", "print(f'User> {message[\"content\"]}', \"green\")\n",
" \"content\": 'Write me a sonnet about llama'\n",
"}\n",
"print(f'User> {message[\"content\"]}', 'green')\n",
"\n", "\n",
"response = client.inference.chat_completion(\n", "response = client.inference.chat_completion(\n",
" messages=[message],\n", " messages=[message],\n",
" model_id=model_id,\n", " model_id=model_id,\n",
" stream=True, # <-----------\n", " stream=True, # <-----------\n",
")\n", ")\n",
"\n", "\n",
"# Print the tokens while they are received\n", "# Print the tokens while they are received\n",
"for log in EventLogger().log(response):\n", "for log in EventLogger().log(response):\n",
" log.print()" " log.print()\n"
] ]
}, },
{ {
@ -1648,17 +1648,22 @@
"source": [ "source": [
"from pydantic import BaseModel\n", "from pydantic import BaseModel\n",
"\n", "\n",
"\n",
"class Output(BaseModel):\n", "class Output(BaseModel):\n",
" name: str\n", " name: str\n",
" year_born: str\n", " year_born: str\n",
" year_retired: str\n", " year_retired: str\n",
"\n", "\n",
"\n",
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n", "user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
"response = client.inference.completion(\n", "response = client.inference.completion(\n",
" model_id=model_id,\n", " model_id=model_id,\n",
" content=user_input,\n", " content=user_input,\n",
" stream=False,\n", " stream=False,\n",
" sampling_params={\n", " sampling_params={\n",
" \"strategy\": {\n",
" \"type\": \"greedy\",\n",
" },\n",
" \"max_tokens\": 50,\n", " \"max_tokens\": 50,\n",
" },\n", " },\n",
" response_format={\n", " response_format={\n",
@ -1667,7 +1672,7 @@
" },\n", " },\n",
")\n", ")\n",
"\n", "\n",
"pprint(response)" "pprint(response)\n"
] ]
}, },
{ {
@ -1823,7 +1828,7 @@
" shield_id=available_shields[0],\n", " shield_id=available_shields[0],\n",
" params={},\n", " params={},\n",
" )\n", " )\n",
" pprint(response)" " pprint(response)\n"
] ]
}, },
{ {
@ -2025,7 +2030,7 @@
"\n", "\n",
"session_id = agent.create_session(\"test-session\")\n", "session_id = agent.create_session(\"test-session\")\n",
"for prompt in user_prompts:\n", "for prompt in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n", " cprint(f\"User> {prompt}\", \"green\")\n",
" response = agent.create_turn(\n", " response = agent.create_turn(\n",
" messages=[\n", " messages=[\n",
" {\n", " {\n",
@ -2451,8 +2456,8 @@
} }
], ],
"source": [ "source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n", "\n",
"# Load data\n", "# Load data\n",
"df = pd.read_csv(\"/tmp/tmpvzjigv7g/n2OzlTWhinflation.csv\")\n", "df = pd.read_csv(\"/tmp/tmpvzjigv7g/n2OzlTWhinflation.csv\")\n",
@ -2536,10 +2541,10 @@
} }
], ],
"source": [ "source": [
"from google.colab import userdata\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from google.colab import userdata\n",
"\n", "\n",
"agent_config = AgentConfig(\n", "agent_config = AgentConfig(\n",
" model=\"meta-llama/Llama-3.1-405B-Instruct-FP8\",\n", " model=\"meta-llama/Llama-3.1-405B-Instruct-FP8\",\n",
@ -2570,7 +2575,7 @@
" )\n", " )\n",
"\n", "\n",
" for log in EventLogger().log(response):\n", " for log in EventLogger().log(response):\n",
" log.print()" " log.print()\n"
] ]
}, },
{ {
@ -2790,20 +2795,21 @@
"source": [ "source": [
"print(f\"Getting traces for session_id={session_id}\")\n", "print(f\"Getting traces for session_id={session_id}\")\n",
"import json\n", "import json\n",
"\n",
"from rich.pretty import pprint\n", "from rich.pretty import pprint\n",
"\n", "\n",
"agent_logs = []\n", "agent_logs = []\n",
"\n", "\n",
"for span in client.telemetry.query_spans(\n", "for span in client.telemetry.query_spans(\n",
" attribute_filters=[\n", " attribute_filters=[\n",
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n", " {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
" ],\n", " ],\n",
" attributes_to_return=[\"input\", \"output\"]\n", " attributes_to_return=[\"input\", \"output\"],\n",
" ):\n", "):\n",
" if span.attributes[\"output\"] != \"no shields\":\n", " if span.attributes[\"output\"] != \"no shields\":\n",
" agent_logs.append(span.attributes)\n", " agent_logs.append(span.attributes)\n",
"\n", "\n",
"pprint(agent_logs)" "pprint(agent_logs)\n"
] ]
}, },
{ {
@ -2914,23 +2920,25 @@
"eval_rows = []\n", "eval_rows = []\n",
"\n", "\n",
"for log in agent_logs:\n", "for log in agent_logs:\n",
" last_msg = log['input'][-1]\n", " last_msg = log[\"input\"][-1]\n",
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n", " if '\"role\":\"user\"' in last_msg:\n",
" eval_rows.append(\n", " eval_rows.append(\n",
" {\n", " {\n",
" \"input_query\": last_msg,\n", " \"input_query\": last_msg,\n",
" \"generated_answer\": log[\"output\"],\n", " \"generated_answer\": log[\"output\"],\n",
" # check if generated_answer uses tools brave_search\n", " # check if generated_answer uses tools brave_search\n",
" \"expected_answer\": \"brave_search\",\n", " \"expected_answer\": \"brave_search\",\n",
" },\n", " },\n",
" )\n", " )\n",
"\n", "\n",
"pprint(eval_rows)\n", "pprint(eval_rows)\n",
"scoring_params = {\n", "scoring_params = {\n",
" \"basic::subset_of\": None,\n", " \"basic::subset_of\": None,\n",
"}\n", "}\n",
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n", "scoring_response = client.scoring.score(\n",
"pprint(scoring_response)" " input_rows=eval_rows, scoring_functions=scoring_params\n",
")\n",
"pprint(scoring_response)\n"
] ]
}, },
{ {
@ -3031,7 +3039,9 @@
"EXPECTED_RESPONSE: {expected_answer}\n", "EXPECTED_RESPONSE: {expected_answer}\n",
"\"\"\"\n", "\"\"\"\n",
"\n", "\n",
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n", "input_query = (\n",
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
")\n",
"generated_answer = \"\"\"\n", "generated_answer = \"\"\"\n",
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n", "Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
"\n", "\n",
@ -3062,7 +3072,7 @@
"}\n", "}\n",
"\n", "\n",
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n", "response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
"pprint(response)" "pprint(response)\n"
] ]
} }
], ],

View file

@ -3514,6 +3514,20 @@
"tool_calls" "tool_calls"
] ]
}, },
"GreedySamplingStrategy": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "greedy",
"default": "greedy"
}
},
"additionalProperties": false,
"required": [
"type"
]
},
"ImageContentItem": { "ImageContentItem": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -3581,20 +3595,17 @@
"type": "object", "type": "object",
"properties": { "properties": {
"strategy": { "strategy": {
"$ref": "#/components/schemas/SamplingStrategy", "oneOf": [
"default": "greedy" {
}, "$ref": "#/components/schemas/GreedySamplingStrategy"
"temperature": { },
"type": "number", {
"default": 0.0 "$ref": "#/components/schemas/TopPSamplingStrategy"
}, },
"top_p": { {
"type": "number", "$ref": "#/components/schemas/TopKSamplingStrategy"
"default": 0.95 }
}, ]
"top_k": {
"type": "integer",
"default": 0
}, },
"max_tokens": { "max_tokens": {
"type": "integer", "type": "integer",
@ -3610,14 +3621,6 @@
"strategy" "strategy"
] ]
}, },
"SamplingStrategy": {
"type": "string",
"enum": [
"greedy",
"top_p",
"top_k"
]
},
"StopReason": { "StopReason": {
"type": "string", "type": "string",
"enum": [ "enum": [
@ -3871,6 +3874,45 @@
"content" "content"
] ]
}, },
"TopKSamplingStrategy": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "top_k",
"default": "top_k"
},
"top_k": {
"type": "integer"
}
},
"additionalProperties": false,
"required": [
"type",
"top_k"
]
},
"TopPSamplingStrategy": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "top_p",
"default": "top_p"
},
"temperature": {
"type": "number"
},
"top_p": {
"type": "number",
"default": 0.95
}
},
"additionalProperties": false,
"required": [
"type"
]
},
"URL": { "URL": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -8887,6 +8929,10 @@
"name": "GraphMemoryBankParams", "name": "GraphMemoryBankParams",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/GraphMemoryBankParams\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/GraphMemoryBankParams\" />"
}, },
{
"name": "GreedySamplingStrategy",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/GreedySamplingStrategy\" />"
},
{ {
"name": "HealthInfo", "name": "HealthInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />"
@ -9136,10 +9182,6 @@
"name": "SamplingParams", "name": "SamplingParams",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingParams\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingParams\" />"
}, },
{
"name": "SamplingStrategy",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SamplingStrategy\" />"
},
{ {
"name": "SaveSpansToDatasetRequest", "name": "SaveSpansToDatasetRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/SaveSpansToDatasetRequest\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/SaveSpansToDatasetRequest\" />"
@ -9317,6 +9359,14 @@
{ {
"name": "ToolRuntime" "name": "ToolRuntime"
}, },
{
"name": "TopKSamplingStrategy",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/TopKSamplingStrategy\" />"
},
{
"name": "TopPSamplingStrategy",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/TopPSamplingStrategy\" />"
},
{ {
"name": "Trace", "name": "Trace",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />"
@ -9456,6 +9506,7 @@
"GetSpanTreeRequest", "GetSpanTreeRequest",
"GraphMemoryBank", "GraphMemoryBank",
"GraphMemoryBankParams", "GraphMemoryBankParams",
"GreedySamplingStrategy",
"HealthInfo", "HealthInfo",
"ImageContentItem", "ImageContentItem",
"InferenceStep", "InferenceStep",
@ -9513,7 +9564,6 @@
"RunShieldResponse", "RunShieldResponse",
"SafetyViolation", "SafetyViolation",
"SamplingParams", "SamplingParams",
"SamplingStrategy",
"SaveSpansToDatasetRequest", "SaveSpansToDatasetRequest",
"ScoreBatchRequest", "ScoreBatchRequest",
"ScoreBatchResponse", "ScoreBatchResponse",
@ -9553,6 +9603,8 @@
"ToolPromptFormat", "ToolPromptFormat",
"ToolResponse", "ToolResponse",
"ToolResponseMessage", "ToolResponseMessage",
"TopKSamplingStrategy",
"TopPSamplingStrategy",
"Trace", "Trace",
"TrainingConfig", "TrainingConfig",
"Turn", "Turn",

View file

@ -937,6 +937,16 @@ components:
required: required:
- memory_bank_type - memory_bank_type
type: object type: object
GreedySamplingStrategy:
additionalProperties: false
properties:
type:
const: greedy
default: greedy
type: string
required:
- type
type: object
HealthInfo: HealthInfo:
additionalProperties: false additionalProperties: false
properties: properties:
@ -2064,26 +2074,13 @@ components:
default: 1.0 default: 1.0
type: number type: number
strategy: strategy:
$ref: '#/components/schemas/SamplingStrategy' oneOf:
default: greedy - $ref: '#/components/schemas/GreedySamplingStrategy'
temperature: - $ref: '#/components/schemas/TopPSamplingStrategy'
default: 0.0 - $ref: '#/components/schemas/TopKSamplingStrategy'
type: number
top_k:
default: 0
type: integer
top_p:
default: 0.95
type: number
required: required:
- strategy - strategy
type: object type: object
SamplingStrategy:
enum:
- greedy
- top_p
- top_k
type: string
SaveSpansToDatasetRequest: SaveSpansToDatasetRequest:
additionalProperties: false additionalProperties: false
properties: properties:
@ -2931,6 +2928,34 @@ components:
- tool_name - tool_name
- content - content
type: object type: object
TopKSamplingStrategy:
additionalProperties: false
properties:
top_k:
type: integer
type:
const: top_k
default: top_k
type: string
required:
- type
- top_k
type: object
TopPSamplingStrategy:
additionalProperties: false
properties:
temperature:
type: number
top_p:
default: 0.95
type: number
type:
const: top_p
default: top_p
type: string
required:
- type
type: object
Trace: Trace:
additionalProperties: false additionalProperties: false
properties: properties:
@ -5587,6 +5612,9 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/GraphMemoryBankParams" - description: <SchemaDefinition schemaRef="#/components/schemas/GraphMemoryBankParams"
/> />
name: GraphMemoryBankParams name: GraphMemoryBankParams
- description: <SchemaDefinition schemaRef="#/components/schemas/GreedySamplingStrategy"
/>
name: GreedySamplingStrategy
- description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" /> - description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" />
name: HealthInfo name: HealthInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/ImageContentItem" - description: <SchemaDefinition schemaRef="#/components/schemas/ImageContentItem"
@ -5753,9 +5781,6 @@ tags:
name: SafetyViolation name: SafetyViolation
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" /> - description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" />
name: SamplingParams name: SamplingParams
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingStrategy"
/>
name: SamplingStrategy
- description: <SchemaDefinition schemaRef="#/components/schemas/SaveSpansToDatasetRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/SaveSpansToDatasetRequest"
/> />
name: SaveSpansToDatasetRequest name: SaveSpansToDatasetRequest
@ -5874,6 +5899,12 @@ tags:
/> />
name: ToolResponseMessage name: ToolResponseMessage
- name: ToolRuntime - name: ToolRuntime
- description: <SchemaDefinition schemaRef="#/components/schemas/TopKSamplingStrategy"
/>
name: TopKSamplingStrategy
- description: <SchemaDefinition schemaRef="#/components/schemas/TopPSamplingStrategy"
/>
name: TopPSamplingStrategy
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" /> - description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
name: Trace name: Trace
- description: <SchemaDefinition schemaRef="#/components/schemas/TrainingConfig" /> - description: <SchemaDefinition schemaRef="#/components/schemas/TrainingConfig" />
@ -5990,6 +6021,7 @@ x-tagGroups:
- GetSpanTreeRequest - GetSpanTreeRequest
- GraphMemoryBank - GraphMemoryBank
- GraphMemoryBankParams - GraphMemoryBankParams
- GreedySamplingStrategy
- HealthInfo - HealthInfo
- ImageContentItem - ImageContentItem
- InferenceStep - InferenceStep
@ -6047,7 +6079,6 @@ x-tagGroups:
- RunShieldResponse - RunShieldResponse
- SafetyViolation - SafetyViolation
- SamplingParams - SamplingParams
- SamplingStrategy
- SaveSpansToDatasetRequest - SaveSpansToDatasetRequest
- ScoreBatchRequest - ScoreBatchRequest
- ScoreBatchResponse - ScoreBatchResponse
@ -6087,6 +6118,8 @@ x-tagGroups:
- ToolPromptFormat - ToolPromptFormat
- ToolResponse - ToolResponse
- ToolResponseMessage - ToolResponseMessage
- TopKSamplingStrategy
- TopPSamplingStrategy
- Trace - Trace
- TrainingConfig - TrainingConfig
- Turn - Turn

View file

@ -56,9 +56,10 @@ response = client.eval.evaluate_rows(
"type": "model", "type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": { "sampling_params": {
"temperature": 0.0, "strategy": {
"type": "greedy",
},
"max_tokens": 4096, "max_tokens": 4096,
"top_p": 0.9,
"repeat_penalty": 1.0, "repeat_penalty": 1.0,
}, },
"system_message": system_message "system_message": system_message
@ -113,9 +114,10 @@ response = client.eval.evaluate_rows(
"type": "model", "type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": { "sampling_params": {
"temperature": 0.0, "strategy": {
"type": "greedy",
},
"max_tokens": 4096, "max_tokens": 4096,
"top_p": 0.9,
"repeat_penalty": 1.0, "repeat_penalty": 1.0,
}, },
} }
@ -134,9 +136,9 @@ agent_config = {
"model": "meta-llama/Llama-3.1-405B-Instruct", "model": "meta-llama/Llama-3.1-405B-Instruct",
"instructions": "You are a helpful assistant", "instructions": "You are a helpful assistant",
"sampling_params": { "sampling_params": {
"strategy": "greedy", "strategy": {
"temperature": 0.0, "type": "greedy",
"top_p": 0.95, },
}, },
"tools": [ "tools": [
{ {

View file

@ -189,7 +189,11 @@ agent_config = AgentConfig(
# Control the inference loop # Control the inference loop
max_infer_iters=5, max_infer_iters=5,
sampling_params={ sampling_params={
"temperature": 0.7, "strategy": {
"type": "top_p",
"temperature": 0.7,
"top_p": 0.95
},
"max_tokens": 2048 "max_tokens": 2048
} }
) )

View file

@ -92,9 +92,10 @@ response = client.eval.evaluate_rows(
"type": "model", "type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": { "sampling_params": {
"temperature": 0.0, "strategy": {
"type": "greedy",
},
"max_tokens": 4096, "max_tokens": 4096,
"top_p": 0.9,
"repeat_penalty": 1.0, "repeat_penalty": 1.0,
}, },
"system_message": system_message "system_message": system_message
@ -149,9 +150,10 @@ response = client.eval.evaluate_rows(
"type": "model", "type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": { "sampling_params": {
"temperature": 0.0, "strategy": {
"type": "greedy",
},
"max_tokens": 4096, "max_tokens": 4096,
"top_p": 0.9,
"repeat_penalty": 1.0, "repeat_penalty": 1.0,
}, },
} }
@ -170,9 +172,9 @@ agent_config = {
"model": "meta-llama/Llama-3.1-405B-Instruct", "model": "meta-llama/Llama-3.1-405B-Instruct",
"instructions": "You are a helpful assistant", "instructions": "You are a helpful assistant",
"sampling_params": { "sampling_params": {
"strategy": "greedy", "strategy": {
"temperature": 0.0, "type": "greedy",
"top_p": 0.95, },
}, },
"tools": [ "tools": [
{ {
@ -318,10 +320,9 @@ The `EvalTaskConfig` are user specified config to define:
"type": "model", "type": "model",
"model": "Llama3.2-3B-Instruct", "model": "Llama3.2-3B-Instruct",
"sampling_params": { "sampling_params": {
"strategy": "greedy", "strategy": {
"temperature": 0, "type": "greedy",
"top_p": 0.95, },
"top_k": 0,
"max_tokens": 0, "max_tokens": 0,
"repetition_penalty": 1.0 "repetition_penalty": 1.0
} }
@ -337,10 +338,9 @@ The `EvalTaskConfig` are user specified config to define:
"type": "model", "type": "model",
"model": "Llama3.1-405B-Instruct", "model": "Llama3.1-405B-Instruct",
"sampling_params": { "sampling_params": {
"strategy": "greedy", "strategy": {
"temperature": 0, "type": "greedy",
"top_p": 0.95, },
"top_k": 0,
"max_tokens": 0, "max_tokens": 0,
"repetition_penalty": 1.0 "repetition_penalty": 1.0
} }

View file

@ -214,7 +214,6 @@ llama model describe -m Llama3.2-3B-Instruct
| | } | | | } |
+-----------------------------+----------------------------------+ +-----------------------------+----------------------------------+
| Recommended sampling params | { | | Recommended sampling params | { |
| | "strategy": "top_p", |
| | "temperature": 1.0, | | | "temperature": 1.0, |
| | "top_p": 0.9, | | | "top_p": 0.9, |
| | "top_k": 0 | | | "top_k": 0 |

View file

@ -200,10 +200,9 @@ Example eval_task_config.json:
"type": "model", "type": "model",
"model": "Llama3.1-405B-Instruct", "model": "Llama3.1-405B-Instruct",
"sampling_params": { "sampling_params": {
"strategy": "greedy", "strategy": {
"temperature": 0, "type": "greedy"
"top_p": 0.95, },
"top_k": 0,
"max_tokens": 0, "max_tokens": 0,
"repetition_penalty": 1.0 "repetition_penalty": 1.0
} }

View file

@ -26,27 +26,28 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n",
"import requests\n",
"import json\n",
"import asyncio\n", "import asyncio\n",
"import nest_asyncio\n", "import json\n",
"import os\n",
"from typing import Dict, List\n", "from typing import Dict, List\n",
"\n",
"import nest_asyncio\n",
"import requests\n",
"from dotenv import load_dotenv\n", "from dotenv import load_dotenv\n",
"from llama_stack_client import LlamaStackClient\n", "from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
"from llama_stack_client.types import CompletionMessage\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types import CompletionMessage\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n",
"\n", "\n",
"# Allow asyncio to run in Jupyter Notebook\n", "# Allow asyncio to run in Jupyter Notebook\n",
"nest_asyncio.apply()\n", "nest_asyncio.apply()\n",
"\n", "\n",
"HOST='localhost'\n", "HOST = \"localhost\"\n",
"PORT=5001\n", "PORT = 5001\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'" "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
] ]
}, },
{ {
@ -69,7 +70,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"load_dotenv()\n", "load_dotenv()\n",
"BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']" "BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
] ]
}, },
{ {
@ -118,7 +119,7 @@
" cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n", " cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n",
" clean_response.append(cleaned)\n", " clean_response.append(cleaned)\n",
"\n", "\n",
" return {\"query\": query, \"top_k\": clean_response}" " return {\"query\": query, \"top_k\": clean_response}\n"
] ]
}, },
{ {
@ -157,25 +158,29 @@
" for message in messages:\n", " for message in messages:\n",
" if isinstance(message, CompletionMessage) and message.tool_calls:\n", " if isinstance(message, CompletionMessage) and message.tool_calls:\n",
" for tool_call in message.tool_calls:\n", " for tool_call in message.tool_calls:\n",
" if 'query' in tool_call.arguments:\n", " if \"query\" in tool_call.arguments:\n",
" query = tool_call.arguments['query']\n", " query = tool_call.arguments[\"query\"]\n",
" call_id = tool_call.call_id\n", " call_id = tool_call.call_id\n",
"\n", "\n",
" if query:\n", " if query:\n",
" search_result = await self.run_impl(query)\n", " search_result = await self.run_impl(query)\n",
" return [ToolResponseMessage(\n", " return [\n",
" call_id=call_id,\n", " ToolResponseMessage(\n",
" role=\"ipython\",\n", " call_id=call_id,\n",
" content=self._format_response_for_agent(search_result),\n", " role=\"ipython\",\n",
" tool_name=\"brave_search\"\n", " content=self._format_response_for_agent(search_result),\n",
" )]\n", " tool_name=\"brave_search\",\n",
" )\n",
" ]\n",
"\n", "\n",
" return [ToolResponseMessage(\n", " return [\n",
" call_id=\"no_call_id\",\n", " ToolResponseMessage(\n",
" role=\"ipython\",\n", " call_id=\"no_call_id\",\n",
" content=\"No query provided.\",\n", " role=\"ipython\",\n",
" tool_name=\"brave_search\"\n", " content=\"No query provided.\",\n",
" )]\n", " tool_name=\"brave_search\",\n",
" )\n",
" ]\n",
"\n", "\n",
" def _format_response_for_agent(self, search_result):\n", " def _format_response_for_agent(self, search_result):\n",
" parsed_result = json.loads(search_result)\n", " parsed_result = json.loads(search_result)\n",
@ -186,7 +191,7 @@
" f\" URL: {result.get('url', 'No URL')}\\n\"\n", " f\" URL: {result.get('url', 'No URL')}\\n\"\n",
" f\" Description: {result.get('description', 'No Description')}\\n\\n\"\n", " f\" Description: {result.get('description', 'No Description')}\\n\\n\"\n",
" )\n", " )\n",
" return formatted_result" " return formatted_result\n"
] ]
}, },
{ {
@ -209,7 +214,7 @@
"async def execute_search(query: str):\n", "async def execute_search(query: str):\n",
" web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n", " web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
" result = await web_search_tool.run_impl(query)\n", " result = await web_search_tool.run_impl(query)\n",
" print(\"Search Results:\", result)" " print(\"Search Results:\", result)\n"
] ]
}, },
{ {
@ -236,7 +241,7 @@
], ],
"source": [ "source": [
"query = \"Latest developments in quantum computing\"\n", "query = \"Latest developments in quantum computing\"\n",
"asyncio.run(execute_search(query))" "asyncio.run(execute_search(query))\n"
] ]
}, },
{ {
@ -288,19 +293,17 @@
"\n", "\n",
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n", " # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n", " webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
" \n", "\n",
" # Define the agent configuration, including the model and tool setup\n", " # Define the agent configuration, including the model and tool setup\n",
" agent_config = AgentConfig(\n", " agent_config = AgentConfig(\n",
" model=MODEL_NAME,\n", " model=MODEL_NAME,\n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n", " instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n", " sampling_params={\n",
" \"strategy\": \"greedy\",\n", " \"strategy\": {\n",
" \"temperature\": 1.0,\n", " \"type\": \"greedy\",\n",
" \"top_p\": 0.9,\n", " },\n",
" },\n", " },\n",
" tools=[\n", " tools=[webSearchTool.get_tool_definition()],\n",
" webSearchTool.get_tool_definition()\n",
" ],\n",
" tool_choice=\"auto\",\n", " tool_choice=\"auto\",\n",
" tool_prompt_format=\"python_list\",\n", " tool_prompt_format=\"python_list\",\n",
" input_shields=input_shields,\n", " input_shields=input_shields,\n",
@ -329,8 +332,9 @@
" async for log in EventLogger().log(response):\n", " async for log in EventLogger().log(response):\n",
" log.print()\n", " log.print()\n",
"\n", "\n",
"\n",
"# Run the function asynchronously in a Jupyter Notebook cell\n", "# Run the function asynchronously in a Jupyter Notebook cell\n",
"await run_main(disable_safety=True)" "await run_main(disable_safety=True)\n"
] ]
} }
], ],

View file

@ -50,8 +50,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 5001 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'" "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
] ]
}, },
{ {
@ -60,10 +60,12 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from dotenv import load_dotenv\n",
"import os\n", "import os\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()\n", "load_dotenv()\n",
"BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']" "BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n"
] ]
}, },
{ {
@ -104,20 +106,22 @@
], ],
"source": [ "source": [
"import os\n", "import os\n",
"\n",
"from llama_stack_client import LlamaStackClient\n", "from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n", "\n",
"\n",
"async def agent_example():\n", "async def agent_example():\n",
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n", " client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
" agent_config = AgentConfig(\n", " agent_config = AgentConfig(\n",
" model=MODEL_NAME,\n", " model=MODEL_NAME,\n",
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n", " instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
" sampling_params={\n", " sampling_params={\n",
" \"strategy\": \"greedy\",\n", " \"strategy\": {\n",
" \"temperature\": 1.0,\n", " \"type\": \"greedy\",\n",
" \"top_p\": 0.9,\n", " },\n",
" },\n", " },\n",
" tools=[\n", " tools=[\n",
" {\n", " {\n",
@ -157,7 +161,7 @@
" log.print()\n", " log.print()\n",
"\n", "\n",
"\n", "\n",
"await agent_example()" "await agent_example()\n"
] ]
}, },
{ {

View file

@ -157,7 +157,15 @@ curl http://localhost:$LLAMA_STACK_PORT/alpha/inference/chat-completion
{"role": "system", "content": "You are a helpful assistant."}, {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write me a 2-sentence poem about the moon"} {"role": "user", "content": "Write me a 2-sentence poem about the moon"}
], ],
"sampling_params": {"temperature": 0.7, "seed": 42, "max_tokens": 512} "sampling_params": {
"strategy": {
"type": "top_p",
"temperatrue": 0.7,
"top_p": 0.95,
},
"seed": 42,
"max_tokens": 512
}
} }
EOF EOF
``` ```

View file

@ -83,8 +83,8 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"LLAMA_STACK_API_TOGETHER_URL=\"https://llama-stack.together.ai\"\n", "LLAMA_STACK_API_TOGETHER_URL = \"https://llama-stack.together.ai\"\n",
"LLAMA31_8B_INSTRUCT = \"Llama3.1-8B-Instruct\"" "LLAMA31_8B_INSTRUCT = \"Llama3.1-8B-Instruct\"\n"
] ]
}, },
{ {
@ -107,12 +107,13 @@
" AgentConfigToolSearchToolDefinition,\n", " AgentConfigToolSearchToolDefinition,\n",
")\n", ")\n",
"\n", "\n",
"\n",
"# Helper function to create an agent with tools\n", "# Helper function to create an agent with tools\n",
"async def create_tool_agent(\n", "async def create_tool_agent(\n",
" client: LlamaStackClient,\n", " client: LlamaStackClient,\n",
" tools: List[Dict],\n", " tools: List[Dict],\n",
" instructions: str = \"You are a helpful assistant\",\n", " instructions: str = \"You are a helpful assistant\",\n",
" model: str = LLAMA31_8B_INSTRUCT\n", " model: str = LLAMA31_8B_INSTRUCT,\n",
") -> Agent:\n", ") -> Agent:\n",
" \"\"\"Create an agent with specified tools.\"\"\"\n", " \"\"\"Create an agent with specified tools.\"\"\"\n",
" print(\"Using the following model: \", model)\n", " print(\"Using the following model: \", model)\n",
@ -120,9 +121,9 @@
" model=model,\n", " model=model,\n",
" instructions=instructions,\n", " instructions=instructions,\n",
" sampling_params={\n", " sampling_params={\n",
" \"strategy\": \"greedy\",\n", " \"strategy\": {\n",
" \"temperature\": 1.0,\n", " \"type\": \"greedy\",\n",
" \"top_p\": 0.9,\n", " },\n",
" },\n", " },\n",
" tools=tools,\n", " tools=tools,\n",
" tool_choice=\"auto\",\n", " tool_choice=\"auto\",\n",
@ -130,7 +131,7 @@
" enable_session_persistence=True,\n", " enable_session_persistence=True,\n",
" )\n", " )\n",
"\n", "\n",
" return Agent(client, agent_config)" " return Agent(client, agent_config)\n"
] ]
}, },
{ {
@ -172,7 +173,8 @@
], ],
"source": [ "source": [
"# comment this if you don't have a BRAVE_SEARCH_API_KEY\n", "# comment this if you don't have a BRAVE_SEARCH_API_KEY\n",
"os.environ[\"BRAVE_SEARCH_API_KEY\"] = 'YOUR_BRAVE_SEARCH_API_KEY'\n", "os.environ[\"BRAVE_SEARCH_API_KEY\"] = \"YOUR_BRAVE_SEARCH_API_KEY\"\n",
"\n",
"\n", "\n",
"async def create_search_agent(client: LlamaStackClient) -> Agent:\n", "async def create_search_agent(client: LlamaStackClient) -> Agent:\n",
" \"\"\"Create an agent with Brave Search capability.\"\"\"\n", " \"\"\"Create an agent with Brave Search capability.\"\"\"\n",
@ -186,8 +188,8 @@
"\n", "\n",
" return await create_tool_agent(\n", " return await create_tool_agent(\n",
" client=client,\n", " client=client,\n",
" tools=[search_tool], # set this to [] if you don't have a BRAVE_SEARCH_API_KEY\n", " tools=[search_tool], # set this to [] if you don't have a BRAVE_SEARCH_API_KEY\n",
" model = LLAMA31_8B_INSTRUCT,\n", " model=LLAMA31_8B_INSTRUCT,\n",
" instructions=\"\"\"\n", " instructions=\"\"\"\n",
" You are a research assistant that can search the web.\n", " You are a research assistant that can search the web.\n",
" Always cite your sources with URLs when providing information.\n", " Always cite your sources with URLs when providing information.\n",
@ -198,9 +200,10 @@
"\n", "\n",
" SOURCES:\n", " SOURCES:\n",
" - [Source title](URL)\n", " - [Source title](URL)\n",
" \"\"\"\n", " \"\"\",\n",
" )\n", " )\n",
"\n", "\n",
"\n",
"# Example usage\n", "# Example usage\n",
"async def search_example():\n", "async def search_example():\n",
" client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n", " client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n",
@ -212,7 +215,7 @@
" # Example queries\n", " # Example queries\n",
" queries = [\n", " queries = [\n",
" \"What are the latest developments in quantum computing?\",\n", " \"What are the latest developments in quantum computing?\",\n",
" #\"Who won the most recent Super Bowl?\",\n", " # \"Who won the most recent Super Bowl?\",\n",
" ]\n", " ]\n",
"\n", "\n",
" for query in queries:\n", " for query in queries:\n",
@ -227,8 +230,9 @@
" async for log in EventLogger().log(response):\n", " async for log in EventLogger().log(response):\n",
" log.print()\n", " log.print()\n",
"\n", "\n",
"\n",
"# Run the example (in Jupyter, use asyncio.run())\n", "# Run the example (in Jupyter, use asyncio.run())\n",
"await search_example()" "await search_example()\n"
] ]
}, },
{ {
@ -286,12 +290,16 @@
} }
], ],
"source": [ "source": [
"from typing import TypedDict, Optional, Dict, Any\n",
"from datetime import datetime\n",
"import json\n", "import json\n",
"from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam\n", "from datetime import datetime\n",
"from llama_stack_client.types import CompletionMessage,ToolResponseMessage\n", "from typing import Any, Dict, Optional, TypedDict\n",
"\n",
"from llama_stack_client.lib.agents.custom_tool import CustomTool\n", "from llama_stack_client.lib.agents.custom_tool import CustomTool\n",
"from llama_stack_client.types import CompletionMessage, ToolResponseMessage\n",
"from llama_stack_client.types.tool_param_definition_param import (\n",
" ToolParamDefinitionParam,\n",
")\n",
"\n",
"\n", "\n",
"class WeatherTool(CustomTool):\n", "class WeatherTool(CustomTool):\n",
" \"\"\"Example custom tool for weather information.\"\"\"\n", " \"\"\"Example custom tool for weather information.\"\"\"\n",
@ -305,16 +313,15 @@
" def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n", " def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n",
" return {\n", " return {\n",
" \"location\": ToolParamDefinitionParam(\n", " \"location\": ToolParamDefinitionParam(\n",
" param_type=\"str\",\n", " param_type=\"str\", description=\"City or location name\", required=True\n",
" description=\"City or location name\",\n",
" required=True\n",
" ),\n", " ),\n",
" \"date\": ToolParamDefinitionParam(\n", " \"date\": ToolParamDefinitionParam(\n",
" param_type=\"str\",\n", " param_type=\"str\",\n",
" description=\"Optional date (YYYY-MM-DD)\",\n", " description=\"Optional date (YYYY-MM-DD)\",\n",
" required=False\n", " required=False,\n",
" )\n", " ),\n",
" }\n", " }\n",
"\n",
" async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:\n", " async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:\n",
" assert len(messages) == 1, \"Expected single message\"\n", " assert len(messages) == 1, \"Expected single message\"\n",
"\n", "\n",
@ -337,20 +344,14 @@
" )\n", " )\n",
" return [message]\n", " return [message]\n",
"\n", "\n",
" async def run_impl(self, location: str, date: Optional[str] = None) -> Dict[str, Any]:\n", " async def run_impl(\n",
" self, location: str, date: Optional[str] = None\n",
" ) -> Dict[str, Any]:\n",
" \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n", " \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n",
" # Mock implementation\n", " # Mock implementation\n",
" if date:\n", " if date:\n",
" return {\n", " return {\"temperature\": 90.1, \"conditions\": \"sunny\", \"humidity\": 40.0}\n",
" \"temperature\": 90.1,\n", " return {\"temperature\": 72.5, \"conditions\": \"partly cloudy\", \"humidity\": 65.0}\n",
" \"conditions\": \"sunny\",\n",
" \"humidity\": 40.0\n",
" }\n",
" return {\n",
" \"temperature\": 72.5,\n",
" \"conditions\": \"partly cloudy\",\n",
" \"humidity\": 65.0\n",
" }\n",
"\n", "\n",
"\n", "\n",
"async def create_weather_agent(client: LlamaStackClient) -> Agent:\n", "async def create_weather_agent(client: LlamaStackClient) -> Agent:\n",
@ -358,38 +359,33 @@
"\n", "\n",
" # Create the agent with the tool\n", " # Create the agent with the tool\n",
" weather_tool = WeatherTool()\n", " weather_tool = WeatherTool()\n",
" \n", "\n",
" agent_config = AgentConfig(\n", " agent_config = AgentConfig(\n",
" model=LLAMA31_8B_INSTRUCT,\n", " model=LLAMA31_8B_INSTRUCT,\n",
" #model=model_name,\n", " # model=model_name,\n",
" instructions=\"\"\"\n", " instructions=\"\"\"\n",
" You are a weather assistant that can provide weather information.\n", " You are a weather assistant that can provide weather information.\n",
" Always specify the location clearly in your responses.\n", " Always specify the location clearly in your responses.\n",
" Include both temperature and conditions in your summaries.\n", " Include both temperature and conditions in your summaries.\n",
" \"\"\",\n", " \"\"\",\n",
" sampling_params={\n", " sampling_params={\n",
" \"strategy\": \"greedy\",\n", " \"strategy\": {\n",
" \"temperature\": 1.0,\n", " \"type\": \"greedy\",\n",
" \"top_p\": 0.9,\n", " },\n",
" },\n", " },\n",
" tools=[\n", " tools=[weather_tool.get_tool_definition()],\n",
" weather_tool.get_tool_definition()\n",
" ],\n",
" tool_choice=\"auto\",\n", " tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n", " tool_prompt_format=\"json\",\n",
" input_shields=[],\n", " input_shields=[],\n",
" output_shields=[],\n", " output_shields=[],\n",
" enable_session_persistence=True\n", " enable_session_persistence=True,\n",
" )\n", " )\n",
"\n", "\n",
" agent = Agent(\n", " agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
" client=client,\n",
" agent_config=agent_config,\n",
" custom_tools=[weather_tool]\n",
" )\n",
"\n", "\n",
" return agent\n", " return agent\n",
"\n", "\n",
"\n",
"# Example usage\n", "# Example usage\n",
"async def weather_example():\n", "async def weather_example():\n",
" client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n", " client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n",
@ -413,12 +409,14 @@
" async for log in EventLogger().log(response):\n", " async for log in EventLogger().log(response):\n",
" log.print()\n", " log.print()\n",
"\n", "\n",
"\n",
"# For Jupyter notebooks\n", "# For Jupyter notebooks\n",
"import nest_asyncio\n", "import nest_asyncio\n",
"\n",
"nest_asyncio.apply()\n", "nest_asyncio.apply()\n",
"\n", "\n",
"# Run the example\n", "# Run the example\n",
"await weather_example()" "await weather_example()\n"
] ]
}, },
{ {

View file

@ -13,7 +13,6 @@ from termcolor import colored
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table from llama_stack.cli.table import print_table
from llama_stack.distribution.utils.serialize import EnumEncoder
class ModelDescribe(Subcommand): class ModelDescribe(Subcommand):
@ -72,7 +71,7 @@ class ModelDescribe(Subcommand):
rows.append( rows.append(
( (
"Recommended sampling params", "Recommended sampling params",
json.dumps(sampling_params, cls=EnumEncoder, indent=4), json.dumps(sampling_params, indent=4),
) )
) )

View file

@ -58,11 +58,6 @@ def define_eval_candidate_2():
# Sampling Parameters # Sampling Parameters
st.markdown("##### Sampling Parameters") st.markdown("##### Sampling Parameters")
strategy = st.selectbox(
"Strategy",
["greedy", "top_p", "top_k"],
index=0,
)
temperature = st.slider( temperature = st.slider(
"Temperature", "Temperature",
min_value=0.0, min_value=0.0,
@ -95,13 +90,20 @@ def define_eval_candidate_2():
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.", help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
) )
if candidate_type == "model": if candidate_type == "model":
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
eval_candidate = { eval_candidate = {
"type": "model", "type": "model",
"model": selected_model, "model": selected_model,
"sampling_params": { "sampling_params": {
"strategy": strategy, "strategy": strategy,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"repetition_penalty": repetition_penalty, "repetition_penalty": repetition_penalty,
}, },

View file

@ -95,6 +95,15 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
message_placeholder = st.empty() message_placeholder = st.empty()
full_response = "" full_response = ""
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
response = llama_stack_api.client.inference.chat_completion( response = llama_stack_api.client.inference.chat_completion(
messages=[ messages=[
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
@ -103,8 +112,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
model_id=selected_model, model_id=selected_model,
stream=stream, stream=stream,
sampling_params={ sampling_params={
"temperature": temperature, "strategy": strategy,
"top_p": top_p,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"repetition_penalty": repetition_penalty, "repetition_penalty": repetition_penalty,
}, },

View file

@ -118,13 +118,20 @@ def rag_chat_page():
with st.chat_message(message["role"]): with st.chat_message(message["role"]):
st.markdown(message["content"]) st.markdown(message["content"])
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
agent_config = AgentConfig( agent_config = AgentConfig(
model=selected_model, model=selected_model,
instructions=system_prompt, instructions=system_prompt,
sampling_params={ sampling_params={
"strategy": "greedy", "strategy": strategy,
"temperature": temperature,
"top_p": top_p,
}, },
tools=[ tools=[
{ {

View file

@ -23,6 +23,11 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel, initialize_model_parallel,
model_parallel_is_initialized, model_parallel_is_initialized,
) )
from llama_models.datatypes import (
GreedySamplingStrategy,
SamplingParams,
TopPSamplingStrategy,
)
from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import Model from llama_models.llama3.api.datatypes import Model
@ -363,11 +368,12 @@ class Llama:
max_gen_len = self.model.params.max_seq_len - 1 max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(request.content) model_input = self.formatter.encode_content(request.content)
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.generate( yield from self.generate(
model_input=model_input, model_input=model_input,
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=sampling_params.temperature, temperature=temperature,
top_p=sampling_params.top_p, top_p=top_p,
logprobs=bool(request.logprobs), logprobs=bool(request.logprobs),
include_stop_token=True, include_stop_token=True,
logits_processor=get_logits_processor( logits_processor=get_logits_processor(
@ -390,14 +396,15 @@ class Llama:
): ):
max_gen_len = self.model.params.max_seq_len - 1 max_gen_len = self.model.params.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.generate( yield from self.generate(
model_input=self.formatter.encode_dialog_prompt( model_input=self.formatter.encode_dialog_prompt(
request.messages, request.messages,
request.tool_prompt_format, request.tool_prompt_format,
), ),
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=sampling_params.temperature, temperature=temperature,
top_p=sampling_params.top_p, top_p=top_p,
logprobs=bool(request.logprobs), logprobs=bool(request.logprobs),
include_stop_token=True, include_stop_token=True,
logits_processor=get_logits_processor( logits_processor=get_logits_processor(
@ -492,3 +499,15 @@ def _build_regular_tokens_list(
is_word_start_token = len(decoded_after_0) > len(decoded_regular) is_word_start_token = len(decoded_after_0) > len(decoded_regular)
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token)) regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
return regular_tokens return regular_tokens
def _infer_sampling_params(sampling_params: SamplingParams):
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
temperature = 0.0
top_p = 1.0
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
temperature = sampling_params.strategy.temperature
top_p = sampling_params.strategy.top_p
else:
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
return temperature, top_p

View file

@ -36,6 +36,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice, OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse, OpenAICompatCompletionResponse,
process_chat_completion_response, process_chat_completion_response,
@ -126,21 +127,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
if sampling_params is None: if sampling_params is None:
return VLLMSamplingParams(max_tokens=self.config.max_tokens) return VLLMSamplingParams(max_tokens=self.config.max_tokens)
# TODO convert what I saw in my first test ... but surely there's more to do here options = get_sampling_options(sampling_params)
kwargs = { if "repeat_penalty" in options:
"temperature": sampling_params.temperature, options["repetition_penalty"] = options["repeat_penalty"]
"max_tokens": self.config.max_tokens, del options["repeat_penalty"]
}
if sampling_params.top_k:
kwargs["top_k"] = sampling_params.top_k
if sampling_params.top_p:
kwargs["top_p"] = sampling_params.top_p
if sampling_params.max_tokens:
kwargs["max_tokens"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
return VLLMSamplingParams(**kwargs) return VLLMSamplingParams(**options)
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass

View file

@ -34,6 +34,7 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
) )
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
OpenAICompatCompletionChoice, OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse, OpenAICompatCompletionResponse,
process_chat_completion_response, process_chat_completion_response,
@ -166,16 +167,13 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
) -> Dict: ) -> Dict:
bedrock_model = request.model bedrock_model = request.model
inference_config = {} sampling_params = request.sampling_params
param_mapping = { options = get_sampling_strategy_options(sampling_params)
"max_tokens": "max_gen_len",
"temperature": "temperature",
"top_p": "top_p",
}
for k, v in param_mapping.items(): if sampling_params.max_tokens:
if getattr(request.sampling_params, k): options["max_gen_len"] = sampling_params.max_tokens
inference_config[v] = getattr(request.sampling_params, k) if sampling_params.repetition_penalty > 0:
options["repetition_penalty"] = sampling_params.repetition_penalty
prompt = await chat_completion_request_to_prompt( prompt = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter request, self.get_llama_model(request.model), self.formatter
@ -185,7 +183,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
"body": json.dumps( "body": json.dumps(
{ {
"prompt": prompt, "prompt": prompt,
**inference_config, **options,
} }
), ),
} }

View file

@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras from cerebras.cloud.sdk import AsyncCerebras
from llama_models.datatypes import CoreModelId from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import TopKSamplingStrategy
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
@ -172,7 +173,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def _get_params( async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest] self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict: ) -> dict:
if request.sampling_params and request.sampling_params.top_k: if request.sampling_params and isinstance(
request.sampling_params.strategy, TopKSamplingStrategy
):
raise ValueError("`top_k` not supported by Cerebras") raise ValueError("`top_k` not supported by Cerebras")
prompt = "" prompt = ""

View file

@ -48,6 +48,9 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
)
def convert_chat_completion_request( def convert_chat_completion_request(
@ -77,6 +80,7 @@ def convert_chat_completion_request(
if request.tool_prompt_format != ToolPromptFormat.json: if request.tool_prompt_format != ToolPromptFormat.json:
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.") warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
sampling_options = get_sampling_strategy_options(request.sampling_params)
return CompletionCreateParams( return CompletionCreateParams(
model=request.model, model=request.model,
messages=[_convert_message(message) for message in request.messages], messages=[_convert_message(message) for message in request.messages],
@ -84,8 +88,8 @@ def convert_chat_completion_request(
frequency_penalty=None, frequency_penalty=None,
stream=request.stream, stream=request.stream,
max_tokens=request.sampling_params.max_tokens or None, max_tokens=request.sampling_params.max_tokens or None,
temperature=request.sampling_params.temperature, temperature=sampling_options.get("temperature", 1.0),
top_p=request.sampling_params.top_p, top_p=sampling_options.get("top_p", 1.0),
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []], tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
tool_choice=request.tool_choice.value if request.tool_choice else None, tool_choice=request.tool_choice.value if request.tool_choice else None,
) )

View file

@ -8,6 +8,11 @@ import json
import warnings import warnings
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from llama_models.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_models.llama3.api.datatypes import ( from llama_models.llama3.api.datatypes import (
BuiltinTool, BuiltinTool,
StopReason, StopReason,
@ -263,19 +268,20 @@ def convert_chat_completion_request(
if request.sampling_params.max_tokens: if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens) payload.update(max_tokens=request.sampling_params.max_tokens)
if request.sampling_params.strategy == "top_p": strategy = request.sampling_params.strategy
if isinstance(strategy, TopPSamplingStrategy):
nvext.update(top_k=-1) nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p) payload.update(top_p=strategy.top_p)
elif request.sampling_params.strategy == "top_k": payload.update(temperature=strategy.temperature)
if ( elif isinstance(strategy, TopKSamplingStrategy):
request.sampling_params.top_k != -1 if strategy.top_k != -1 and strategy.top_k < 1:
and request.sampling_params.top_k < 1
):
warnings.warn("top_k must be -1 or >= 1") warnings.warn("top_k must be -1 or >= 1")
nvext.update(top_k=request.sampling_params.top_k) nvext.update(top_k=strategy.top_k)
elif request.sampling_params.strategy == "greedy": elif isinstance(strategy, GreedySamplingStrategy):
nvext.update(top_k=-1) nvext.update(top_k=-1)
payload.update(temperature=request.sampling_params.temperature) payload.update(temperature=strategy.temperature)
else:
raise ValueError(f"Unsupported sampling strategy: {strategy}")
return payload return payload

View file

@ -7,6 +7,7 @@
import os import os
import pytest import pytest
from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
from llama_models.llama3.api.datatypes import BuiltinTool from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
@ -22,7 +23,8 @@ from llama_stack.apis.agents import (
ToolExecutionStep, ToolExecutionStep,
Turn, Turn,
) )
from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.safety import ViolationLevel from llama_stack.apis.safety import ViolationLevel
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -42,7 +44,9 @@ def common_params(inference_model):
model=inference_model, model=inference_model,
instructions="You are a helpful assistant.", instructions="You are a helpful assistant.",
enable_session_persistence=True, enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95), sampling_params=SamplingParams(
strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95)
),
input_shields=[], input_shields=[],
output_shields=[], output_shields=[],
toolgroups=[], toolgroups=[],

View file

@ -21,6 +21,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
Function, Function,
) )
from groq.types.shared.function_definition import FunctionDefinition from groq.types.shared.function_definition import FunctionDefinition
from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
@ -152,21 +153,30 @@ class TestConvertChatCompletionRequest:
assert converted["max_tokens"] == 100 assert converted["max_tokens"] == 100
def test_includes_temperature(self): def _dummy_chat_completion_request(self):
return ChatCompletionRequest(
model="Llama-3.2-3B",
messages=[UserMessage(content="Hello World")],
)
def test_includes_stratgy(self):
request = self._dummy_chat_completion_request() request = self._dummy_chat_completion_request()
request.sampling_params.temperature = 0.5 request.sampling_params.strategy = TopPSamplingStrategy(
temperature=0.5, top_p=0.95
)
converted = convert_chat_completion_request(request) converted = convert_chat_completion_request(request)
assert converted["temperature"] == 0.5 assert converted["temperature"] == 0.5
assert converted["top_p"] == 0.95
def test_includes_top_p(self): def test_includes_greedy_strategy(self):
request = self._dummy_chat_completion_request() request = self._dummy_chat_completion_request()
request.sampling_params.top_p = 0.95 request.sampling_params.strategy = GreedySamplingStrategy()
converted = convert_chat_completion_request(request) converted = convert_chat_completion_request(request)
assert converted["top_p"] == 0.95 assert converted["temperature"] == 0.0
def test_includes_tool_choice(self): def test_includes_tool_choice(self):
request = self._dummy_chat_completion_request() request = self._dummy_chat_completion_request()
@ -268,12 +278,6 @@ class TestConvertChatCompletionRequest:
}, },
] ]
def _dummy_chat_completion_request(self):
return ChatCompletionRequest(
model="Llama-3.2-3B",
messages=[UserMessage(content="Hello World")],
)
class TestConvertNonStreamChatCompletionResponse: class TestConvertNonStreamChatCompletionResponse:
def test_returns_response(self): def test_returns_response(self):
@ -409,19 +413,19 @@ class TestConvertStreamChatCompletionResponse:
iter = converted.__aiter__() iter = converted.__aiter__()
chunk = await iter.__anext__() chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.start assert chunk.event.event_type == ChatCompletionResponseEventType.start
assert chunk.event.delta == "Hello " assert chunk.event.delta.text == "Hello "
chunk = await iter.__anext__() chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.progress assert chunk.event.event_type == ChatCompletionResponseEventType.progress
assert chunk.event.delta == "World " assert chunk.event.delta.text == "World "
chunk = await iter.__anext__() chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.progress assert chunk.event.event_type == ChatCompletionResponseEventType.progress
assert chunk.event.delta == " !" assert chunk.event.delta.text == " !"
chunk = await iter.__anext__() chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.complete assert chunk.event.event_type == ChatCompletionResponseEventType.complete
assert chunk.event.delta == "" assert chunk.event.delta.text == ""
assert chunk.event.stop_reason == StopReason.end_of_turn assert chunk.event.stop_reason == StopReason.end_of_turn
with pytest.raises(StopAsyncIteration): with pytest.raises(StopAsyncIteration):

View file

@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
UserMessage, UserMessage,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from .utils import group_chunks from .utils import group_chunks
@ -476,7 +477,7 @@ class TestInference:
last = grouped[ChatCompletionResponseEventType.progress][-1] last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason # assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
assert last.event.delta.content.type == "tool_call" assert isinstance(last.event.delta.content, ToolCall)
call = last.event.delta.content call = last.event.delta.content
assert call.tool_name == "get_weather" assert call.tool_name == "get_weather"

View file

@ -8,7 +8,13 @@ from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import SamplingParams, StopReason from llama_models.llama3.api.datatypes import (
GreedySamplingStrategy,
SamplingParams,
StopReason,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
@ -49,12 +55,27 @@ class OpenAICompatCompletionResponse(BaseModel):
choices: List[OpenAICompatCompletionChoice] choices: List[OpenAICompatCompletionChoice]
def get_sampling_strategy_options(params: SamplingParams) -> dict:
options = {}
if isinstance(params.strategy, GreedySamplingStrategy):
options["temperature"] = 0.0
elif isinstance(params.strategy, TopPSamplingStrategy):
options["temperature"] = params.strategy.temperature
options["top_p"] = params.strategy.top_p
elif isinstance(params.strategy, TopKSamplingStrategy):
options["top_k"] = params.strategy.top_k
else:
raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
return options
def get_sampling_options(params: SamplingParams) -> dict: def get_sampling_options(params: SamplingParams) -> dict:
options = {} options = {}
if params: if params:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}: options.update(get_sampling_strategy_options(params))
if getattr(params, attr): if params.max_tokens:
options[attr] = getattr(params, attr) options["max_tokens"] = params.max_tokens
if params.repetition_penalty is not None and params.repetition_penalty != 1.0: if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
options["repeat_penalty"] = params.repetition_penalty options["repeat_penalty"] = params.repetition_penalty

View file

@ -97,9 +97,11 @@ def agent_config(llama_stack_client):
model=model_id, model=model_id,
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
sampling_params={ sampling_params={
"strategy": "greedy", "strategy": {
"temperature": 1.0, "type": "greedy",
"top_p": 0.9, "temperature": 1.0,
"top_p": 0.9,
},
}, },
toolgroups=[], toolgroups=[],
tool_choice="auto", tool_choice="auto",