Convert SamplingParams.strategy to a union (#767)

# What does this PR do?

Cleans up how we provide sampling params. Previously, `strategy` was an enum, and the params for all strategies (`top_p`, `temperature`, `top_k`) were lumped together in one flat structure. We now have a strategy union object in which each strategy (greedy, top_p, top_k) carries its corresponding params.
Earlier:
```
class SamplingParams:
    strategy: SamplingStrategy  # enum; ignored by providers in practice
    # params for every strategy, flattened together
    temperature: Optional[float]
    top_p: Optional[float]
    top_k: Optional[int]
```
However, the `strategy` field was not actually used by any provider, so the sampling behavior was hard to determine from the params alone: you could pass `temperature`, `top_p`, and `top_k` all at once, and it was unclear how a given provider would interpret that combination.

Hence we introduced a union in which each strategy and its relevant params are grouped together, removing this ambiguity.
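
Concretely, the new schema looks roughly like the sketch below. This is a minimal illustration using a Pydantic discriminated union; the class and field names follow the PR description and the notebook diff further down, and may differ in small ways from the merged code.

```
from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field


class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 0.95


class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int


# "type" is the discriminator: it tells providers exactly which strategy,
# and therefore which params, apply.
SamplingStrategy = Annotated[
    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
    Field(discriminator="type"),
]


class SamplingParams(BaseModel):
    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
    # strategy-independent params stay at the top level
    max_tokens: Optional[int] = 0
```

With this shape, `{"type": "top_k", "top_k": 40}` parses unambiguously into `TopKSamplingStrategy`, and there is no longer a flat `temperature` field that a greedy run would silently ignore.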

Updated all providers, tests, notebooks, the README, and the other places where sampling params were used to the new format.
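
For callers, the migration looks like the sketch below (`client`, `model_id`, and `user_input` are assumed to be set up as in the notebook; the greedy `completion` call mirrors the notebook diff further down, while the `top_p` values are purely illustrative):

```
# Before: flat params; how a provider combined these was undefined
sampling_params = {"temperature": 0.7, "top_p": 0.95, "max_tokens": 50}

# After: the strategy names itself and carries only its own params
sampling_params = {
    "strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.95},
    "max_tokens": 50,
}

# e.g. greedy decoding, as in the updated notebook below
response = client.inference.completion(
    model_id=model_id,
    content=user_input,
    sampling_params={"strategy": {"type": "greedy"}, "max_tokens": 50},
)
```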
   

## Test Plan
`pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py`
// inference on ollama, fireworks and together
`with-proxy pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/inference/test_text_inference.py`
// agents on fireworks
`pytest -v -s -k 'fireworks and create_agent' --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/agents/test_agents.py --safety-shield="meta-llama/Llama-Guard-3-8B"`

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Ran pre-commit to handle lint / formatting issues.
- [X] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [X] Updated relevant documentation.
- [X] Wrote necessary unit or integration tests.

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>

@@ -713,13 +713,15 @@
],
"source": [
"import os\n",
"\n",
"from google.colab import userdata\n",
"\n",
"os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n",
"os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n",
"\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"\n",
"client = LlamaStackAsLibraryClient(\"together\")\n",
"_ = client.initialize()"
"_ = client.initialize()\n"
]
},
{
@@ -769,6 +771,7 @@
],
"source": [
"from rich.pretty import pprint\n",
"\n",
"print(\"Available models:\")\n",
"for m in client.models.list():\n",
" print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n",
@@ -777,7 +780,7 @@
"print(\"Available shields (safety models):\")\n",
"for s in client.shields.list():\n",
" print(s.identifier)\n",
"print(\"----\")"
"print(\"----\")\n"
]
},
{
@@ -822,7 +825,7 @@
"source": [
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
"\n",
"model_id"
"model_id\n"
]
},
{
@@ -863,11 +866,11 @@
" model_id=model_id,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n",
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
" ],\n",
")\n",
"\n",
"print(response.completion_message.content)"
"print(response.completion_message.content)\n"
]
},
{
@@ -900,12 +903,13 @@
"source": [
"from termcolor import cprint\n",
"\n",
"\n",
"def chat_loop():\n",
" conversation_history = []\n",
" while True:\n",
" user_input = input('User> ')\n",
" if user_input.lower() in ['exit', 'quit', 'bye']:\n",
" cprint('Ending conversation. Goodbye!', 'yellow')\n",
" user_input = input(\"User> \")\n",
" if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n",
" cprint(\"Ending conversation. Goodbye!\", \"yellow\")\n",
" break\n",
"\n",
" user_message = {\"role\": \"user\", \"content\": user_input}\n",
@@ -915,14 +919,15 @@
" messages=conversation_history,\n",
" model_id=model_id,\n",
" )\n",
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
" cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n",
"\n",
" assistant_message = {\n",
" \"role\": \"assistant\", # was user\n",
" \"role\": \"assistant\", # was user\n",
" \"content\": response.completion_message.content,\n",
" }\n",
" conversation_history.append(assistant_message)\n",
"\n",
"\n",
"chat_loop()\n"
]
},
@@ -978,21 +983,18 @@
"source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
"\n",
"message = {\n",
" \"role\": \"user\",\n",
" \"content\": 'Write me a sonnet about llama'\n",
"}\n",
"print(f'User> {message[\"content\"]}', 'green')\n",
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
"print(f'User> {message[\"content\"]}', \"green\")\n",
"\n",
"response = client.inference.chat_completion(\n",
" messages=[message],\n",
" model_id=model_id,\n",
" stream=True, # <-----------\n",
" stream=True, # <-----------\n",
")\n",
"\n",
"# Print the tokens while they are received\n",
"for log in EventLogger().log(response):\n",
" log.print()"
" log.print()\n"
]
},
{
@@ -1045,26 +1047,26 @@
"source": [
"from pydantic import BaseModel\n",
"\n",
"\n",
"class Output(BaseModel):\n",
" name: str\n",
" year_born: str\n",
" year_retired: str\n",
"\n",
"\n",
"user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n",
"response = client.inference.completion(\n",
" model_id=model_id,\n",
" content=user_input,\n",
" stream=False,\n",
" sampling_params={\n",
" \"max_tokens\": 50,\n",
" },\n",
" sampling_params={\"strategy\": {\"type\": \"greedy\"}, \"max_tokens\": 50},\n",
" response_format={\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": Output.model_json_schema(),\n",
" },\n",
")\n",
"\n",
"pprint(response)"
"pprint(response)\n"
]
},
{
@@ -1220,7 +1222,7 @@
" shield_id=available_shields[0],\n",
" params={},\n",
" )\n",
" pprint(response)"
" pprint(response)\n"
]
},
{
@@ -1489,8 +1491,8 @@
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.types import Attachment\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n",
"\n",
"urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
@@ -1522,14 +1524,14 @@
" ),\n",
"]\n",
"for prompt, attachments in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n",
" cprint(f\"User> {prompt}\", \"green\")\n",
" response = rag_agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" attachments=attachments,\n",
" session_id=session_id,\n",
" )\n",
" for log in EventLogger().log(response):\n",
" log.print()"
" log.print()\n"
]
},
{
@@ -1560,8 +1562,8 @@
"search_tool = {\n",
" \"type\": \"brave_search\",\n",
" \"engine\": \"tavily\",\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
"}"
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
"}\n"
]
},
{
@@ -1608,7 +1610,7 @@
"\n",
"session_id = agent.create_session(\"test-session\")\n",
"for prompt in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n",
" cprint(f\"User> {prompt}\", \"green\")\n",
" response = agent.create_turn(\n",
" messages=[\n",
" {\n",
@@ -1758,7 +1760,7 @@
" search_tool,\n",
" {\n",
" \"type\": \"code_interpreter\",\n",
" }\n",
" },\n",
" ],\n",
" tool_choice=\"required\",\n",
" input_shields=[],\n",
@@ -1788,7 +1790,7 @@
"]\n",
"\n",
"for prompt in user_prompts:\n",
" cprint(f'User> {prompt}', 'green')\n",
" cprint(f\"User> {prompt}\", \"green\")\n",
" response = codex_agent.create_turn(\n",
" messages=[\n",
" {\n",
@@ -1841,27 +1843,57 @@
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Read the CSV file\n",
"df = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\n",
"df = pd.read_csv(\"/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv\")\n",
"\n",
"# Extract the year and inflation rate from the CSV file\n",
"df['Year'] = pd.to_datetime(df['Year'], format='%Y')\n",
"df = df.rename(columns={'Jan': 'Jan Rate', 'Feb': 'Feb Rate', 'Mar': 'Mar Rate', 'Apr': 'Apr Rate', 'May': 'May Rate', 'Jun': 'Jun Rate', 'Jul': 'Jul Rate', 'Aug': 'Aug Rate', 'Sep': 'Sep Rate', 'Oct': 'Oct Rate', 'Nov': 'Nov Rate', 'Dec': 'Dec Rate'})\n",
"df[\"Year\"] = pd.to_datetime(df[\"Year\"], format=\"%Y\")\n",
"df = df.rename(\n",
" columns={\n",
" \"Jan\": \"Jan Rate\",\n",
" \"Feb\": \"Feb Rate\",\n",
" \"Mar\": \"Mar Rate\",\n",
" \"Apr\": \"Apr Rate\",\n",
" \"May\": \"May Rate\",\n",
" \"Jun\": \"Jun Rate\",\n",
" \"Jul\": \"Jul Rate\",\n",
" \"Aug\": \"Aug Rate\",\n",
" \"Sep\": \"Sep Rate\",\n",
" \"Oct\": \"Oct Rate\",\n",
" \"Nov\": \"Nov Rate\",\n",
" \"Dec\": \"Dec Rate\",\n",
" }\n",
")\n",
"\n",
"# Calculate the average yearly inflation rate\n",
"df['Yearly Inflation'] = df[['Jan Rate', 'Feb Rate', 'Mar Rate', 'Apr Rate', 'May Rate', 'Jun Rate', 'Jul Rate', 'Aug Rate', 'Sep Rate', 'Oct Rate', 'Nov Rate', 'Dec Rate']].mean(axis=1)\n",
"df[\"Yearly Inflation\"] = df[\n",
" [\n",
" \"Jan Rate\",\n",
" \"Feb Rate\",\n",
" \"Mar Rate\",\n",
" \"Apr Rate\",\n",
" \"May Rate\",\n",
" \"Jun Rate\",\n",
" \"Jul Rate\",\n",
" \"Aug Rate\",\n",
" \"Sep Rate\",\n",
" \"Oct Rate\",\n",
" \"Nov Rate\",\n",
" \"Dec Rate\",\n",
" ]\n",
"].mean(axis=1)\n",
"\n",
"# Plot the average yearly inflation rate as a time series\n",
"plt.figure(figsize=(10, 6))\n",
"plt.plot(df['Year'], df['Yearly Inflation'], marker='o')\n",
"plt.title('Average Yearly Inflation Rate')\n",
"plt.xlabel('Year')\n",
"plt.ylabel('Inflation Rate (%)')\n",
"plt.plot(df[\"Year\"], df[\"Yearly Inflation\"], marker=\"o\")\n",
"plt.title(\"Average Yearly Inflation Rate\")\n",
"plt.xlabel(\"Year\")\n",
"plt.ylabel(\"Inflation Rate (%)\")\n",
"plt.grid(True)\n",
"plt.show()"
"plt.show()\n"
]
},
{
@@ -2035,6 +2067,8 @@
"source": [
"# disable logging for clean server logs\n",
"import logging\n",
"\n",
"\n",
"def remove_root_handlers():\n",
" root_logger = logging.getLogger()\n",
" for handler in root_logger.handlers[:]:\n",
@@ -2042,7 +2076,7 @@
" print(f\"Removed handler {handler.__class__.__name__} from root logger\")\n",
"\n",
"\n",
"remove_root_handlers()"
"remove_root_handlers()\n"
]
},
{
@@ -2083,10 +2117,10 @@
}
],
"source": [
"from google.colab import userdata\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from google.colab import userdata\n",
"\n",
"agent_config = AgentConfig(\n",
" model=\"meta-llama/Llama-3.1-405B-Instruct\",\n",
@@ -2096,7 +2130,7 @@
" {\n",
" \"type\": \"brave_search\",\n",
" \"engine\": \"tavily\",\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
" \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n",
" }\n",
" ]\n",
" ),\n",
@@ -2125,7 +2159,7 @@
" )\n",
"\n",
" for log in EventLogger().log(response):\n",
" log.print()"
" log.print()\n"
]
},
{
@@ -2265,20 +2299,21 @@
"source": [
"print(f\"Getting traces for session_id={session_id}\")\n",
"import json\n",
"\n",
"from rich.pretty import pprint\n",
"\n",
"agent_logs = []\n",
"\n",
"for span in client.telemetry.query_spans(\n",
" attribute_filters=[\n",
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
" {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n",
" ],\n",
" attributes_to_return=[\"input\", \"output\"]\n",
" ):\n",
" if span.attributes[\"output\"] != \"no shields\":\n",
" agent_logs.append(span.attributes)\n",
" attributes_to_return=[\"input\", \"output\"],\n",
"):\n",
" if span.attributes[\"output\"] != \"no shields\":\n",
" agent_logs.append(span.attributes)\n",
"\n",
"pprint(agent_logs)"
"pprint(agent_logs)\n"
]
},
{
@@ -2389,23 +2424,25 @@
"eval_rows = []\n",
"\n",
"for log in agent_logs:\n",
" last_msg = log['input'][-1]\n",
" if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n",
" eval_rows.append(\n",
" {\n",
" \"input_query\": last_msg,\n",
" \"generated_answer\": log[\"output\"],\n",
" # check if generated_answer uses tools brave_search\n",
" \"expected_answer\": \"brave_search\",\n",
" },\n",
" )\n",
" last_msg = log[\"input\"][-1]\n",
" if '\"role\":\"user\"' in last_msg:\n",
" eval_rows.append(\n",
" {\n",
" \"input_query\": last_msg,\n",
" \"generated_answer\": log[\"output\"],\n",
" # check if generated_answer uses tools brave_search\n",
" \"expected_answer\": \"brave_search\",\n",
" },\n",
" )\n",
"\n",
"pprint(eval_rows)\n",
"scoring_params = {\n",
" \"basic::subset_of\": None,\n",
"}\n",
"scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n",
"pprint(scoring_response)"
"scoring_response = client.scoring.score(\n",
" input_rows=eval_rows, scoring_functions=scoring_params\n",
")\n",
"pprint(scoring_response)\n"
]
},
{
@@ -2506,7 +2543,9 @@
"EXPECTED_RESPONSE: {expected_answer}\n",
"\"\"\"\n",
"\n",
"input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
"input_query = (\n",
" \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n",
")\n",
"generated_answer = \"\"\"\n",
"Here are the top 5 topics that were explained in the documentation for Torchtune:\n",
"\n",
@@ -2537,7 +2576,7 @@
"}\n",
"\n",
"response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n",
"pprint(response)"
"pprint(response)\n"
]
},
{