diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index fa527f1a0..921869b33 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -713,13 +713,15 @@ ], "source": [ "import os\n", + "\n", "from google.colab import userdata\n", "\n", - "os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n", + "os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n", "\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "\n", "client = LlamaStackAsLibraryClient(\"together\")\n", - "_ = client.initialize()" + "_ = client.initialize()\n" ] }, { @@ -769,6 +771,7 @@ ], "source": [ "from rich.pretty import pprint\n", + "\n", "print(\"Available models:\")\n", "for m in client.models.list():\n", " print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n", @@ -777,7 +780,7 @@ "print(\"Available shields (safety models):\")\n", "for s in client.shields.list():\n", " print(s.identifier)\n", - "print(\"----\")" + "print(\"----\")\n" ] }, { @@ -822,7 +825,7 @@ "source": [ "model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n", "\n", - "model_id" + "model_id\n" ] }, { @@ -863,11 +866,11 @@ " model_id=model_id,\n", " messages=[\n", " {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n", - " {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n", + " {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n", " ],\n", ")\n", "\n", - "print(response.completion_message.content)" + "print(response.completion_message.content)\n" ] }, { @@ -900,12 +903,13 @@ "source": [ "from termcolor import cprint\n", "\n", + "\n", "def chat_loop():\n", " conversation_history = []\n", " while True:\n", - " user_input = input('User> ')\n", - " if user_input.lower() in ['exit', 'quit', 'bye']:\n", - " cprint('Ending conversation. Goodbye!', 'yellow')\n", + " user_input = input(\"User> \")\n", + " if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n", + " cprint(\"Ending conversation. 
Goodbye!\", \"yellow\")\n", " break\n", "\n", " user_message = {\"role\": \"user\", \"content\": user_input}\n", @@ -915,14 +919,15 @@ " messages=conversation_history,\n", " model_id=model_id,\n", " )\n", - " cprint(f'> Response: {response.completion_message.content}', 'cyan')\n", + " cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n", "\n", " assistant_message = {\n", - " \"role\": \"assistant\", # was user\n", + " \"role\": \"assistant\", # was user\n", " \"content\": response.completion_message.content,\n", " }\n", " conversation_history.append(assistant_message)\n", "\n", + "\n", "chat_loop()\n" ] }, @@ -978,21 +983,18 @@ "source": [ "from llama_stack_client.lib.inference.event_logger import EventLogger\n", "\n", - "message = {\n", - " \"role\": \"user\",\n", - " \"content\": 'Write me a sonnet about llama'\n", - "}\n", - "print(f'User> {message[\"content\"]}', 'green')\n", + "message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n", + "print(f'User> {message[\"content\"]}', \"green\")\n", "\n", "response = client.inference.chat_completion(\n", " messages=[message],\n", " model_id=model_id,\n", - " stream=True, # <-----------\n", + " stream=True, # <-----------\n", ")\n", "\n", "# Print the tokens while they are received\n", "for log in EventLogger().log(response):\n", - " log.print()" + " log.print()\n" ] }, { @@ -1045,26 +1047,26 @@ "source": [ "from pydantic import BaseModel\n", "\n", + "\n", "class Output(BaseModel):\n", " name: str\n", " year_born: str\n", " year_retired: str\n", "\n", + "\n", "user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. \"\n", "response = client.inference.completion(\n", " model_id=model_id,\n", " content=user_input,\n", " stream=False,\n", - " sampling_params={\n", - " \"max_tokens\": 50,\n", - " },\n", + " sampling_params={\"strategy\": {\"type\": \"greedy\"}, \"max_tokens\": 50},\n", " response_format={\n", " \"type\": \"json_schema\",\n", " \"json_schema\": Output.model_json_schema(),\n", " },\n", ")\n", "\n", - "pprint(response)" + "pprint(response)\n" ] }, { @@ -1220,7 +1222,7 @@ " shield_id=available_shields[0],\n", " params={},\n", " )\n", - " pprint(response)" + " pprint(response)\n" ] }, { @@ -1489,8 +1491,8 @@ "source": [ "from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n", - "from llama_stack_client.types.agent_create_params import AgentConfig\n", "from llama_stack_client.types import Attachment\n", + "from llama_stack_client.types.agent_create_params import AgentConfig\n", "from termcolor import cprint\n", "\n", "urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n", @@ -1522,14 +1524,14 @@ " ),\n", "]\n", "for prompt, attachments in user_prompts:\n", - " cprint(f'User> {prompt}', 'green')\n", + " cprint(f\"User> {prompt}\", \"green\")\n", " response = rag_agent.create_turn(\n", " messages=[{\"role\": \"user\", \"content\": prompt}],\n", " attachments=attachments,\n", " session_id=session_id,\n", " )\n", " for log in EventLogger().log(response):\n", - " log.print()" + " log.print()\n" ] }, { @@ -1560,8 +1562,8 @@ "search_tool = {\n", " \"type\": \"brave_search\",\n", " \"engine\": \"tavily\",\n", - " \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n", - "}" + " \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n", + "}\n" ] }, { @@ -1608,7 +1610,7 @@ "\n", "session_id = 
agent.create_session(\"test-session\")\n", "for prompt in user_prompts:\n", - " cprint(f'User> {prompt}', 'green')\n", + " cprint(f\"User> {prompt}\", \"green\")\n", " response = agent.create_turn(\n", " messages=[\n", " {\n", @@ -1758,7 +1760,7 @@ " search_tool,\n", " {\n", " \"type\": \"code_interpreter\",\n", - " }\n", + " },\n", " ],\n", " tool_choice=\"required\",\n", " input_shields=[],\n", @@ -1788,7 +1790,7 @@ "]\n", "\n", "for prompt in user_prompts:\n", - " cprint(f'User> {prompt}', 'green')\n", + " cprint(f\"User> {prompt}\", \"green\")\n", " response = codex_agent.create_turn(\n", " messages=[\n", " {\n", @@ -1841,27 +1843,57 @@ } ], "source": [ - "import pandas as pd\n", "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", "\n", "# Read the CSV file\n", - "df = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\n", + "df = pd.read_csv(\"/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv\")\n", "\n", "# Extract the year and inflation rate from the CSV file\n", - "df['Year'] = pd.to_datetime(df['Year'], format='%Y')\n", - "df = df.rename(columns={'Jan': 'Jan Rate', 'Feb': 'Feb Rate', 'Mar': 'Mar Rate', 'Apr': 'Apr Rate', 'May': 'May Rate', 'Jun': 'Jun Rate', 'Jul': 'Jul Rate', 'Aug': 'Aug Rate', 'Sep': 'Sep Rate', 'Oct': 'Oct Rate', 'Nov': 'Nov Rate', 'Dec': 'Dec Rate'})\n", + "df[\"Year\"] = pd.to_datetime(df[\"Year\"], format=\"%Y\")\n", + "df = df.rename(\n", + " columns={\n", + " \"Jan\": \"Jan Rate\",\n", + " \"Feb\": \"Feb Rate\",\n", + " \"Mar\": \"Mar Rate\",\n", + " \"Apr\": \"Apr Rate\",\n", + " \"May\": \"May Rate\",\n", + " \"Jun\": \"Jun Rate\",\n", + " \"Jul\": \"Jul Rate\",\n", + " \"Aug\": \"Aug Rate\",\n", + " \"Sep\": \"Sep Rate\",\n", + " \"Oct\": \"Oct Rate\",\n", + " \"Nov\": \"Nov Rate\",\n", + " \"Dec\": \"Dec Rate\",\n", + " }\n", + ")\n", "\n", "# Calculate the average yearly inflation rate\n", - "df['Yearly Inflation'] = df[['Jan Rate', 'Feb Rate', 'Mar Rate', 'Apr Rate', 'May Rate', 'Jun Rate', 'Jul Rate', 'Aug Rate', 'Sep Rate', 'Oct Rate', 'Nov Rate', 'Dec Rate']].mean(axis=1)\n", + "df[\"Yearly Inflation\"] = df[\n", + " [\n", + " \"Jan Rate\",\n", + " \"Feb Rate\",\n", + " \"Mar Rate\",\n", + " \"Apr Rate\",\n", + " \"May Rate\",\n", + " \"Jun Rate\",\n", + " \"Jul Rate\",\n", + " \"Aug Rate\",\n", + " \"Sep Rate\",\n", + " \"Oct Rate\",\n", + " \"Nov Rate\",\n", + " \"Dec Rate\",\n", + " ]\n", + "].mean(axis=1)\n", "\n", "# Plot the average yearly inflation rate as a time series\n", "plt.figure(figsize=(10, 6))\n", - "plt.plot(df['Year'], df['Yearly Inflation'], marker='o')\n", - "plt.title('Average Yearly Inflation Rate')\n", - "plt.xlabel('Year')\n", - "plt.ylabel('Inflation Rate (%)')\n", + "plt.plot(df[\"Year\"], df[\"Yearly Inflation\"], marker=\"o\")\n", + "plt.title(\"Average Yearly Inflation Rate\")\n", + "plt.xlabel(\"Year\")\n", + "plt.ylabel(\"Inflation Rate (%)\")\n", "plt.grid(True)\n", - "plt.show()" + "plt.show()\n" ] }, { @@ -2035,6 +2067,8 @@ "source": [ "# disable logging for clean server logs\n", "import logging\n", + "\n", + "\n", "def remove_root_handlers():\n", " root_logger = logging.getLogger()\n", " for handler in root_logger.handlers[:]:\n", @@ -2042,7 +2076,7 @@ " print(f\"Removed handler {handler.__class__.__name__} from root logger\")\n", "\n", "\n", - "remove_root_handlers()" + "remove_root_handlers()\n" ] }, { @@ -2083,10 +2117,10 @@ } ], "source": [ + "from google.colab import userdata\n", "from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.event_logger import 
EventLogger\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n", - "from google.colab import userdata\n", "\n", "agent_config = AgentConfig(\n", " model=\"meta-llama/Llama-3.1-405B-Instruct\",\n", @@ -2096,7 +2130,7 @@ " {\n", " \"type\": \"brave_search\",\n", " \"engine\": \"tavily\",\n", - " \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n", + " \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n", " }\n", " ]\n", " ),\n", @@ -2125,7 +2159,7 @@ " )\n", "\n", " for log in EventLogger().log(response):\n", - " log.print()" + " log.print()\n" ] }, { @@ -2265,20 +2299,21 @@ "source": [ "print(f\"Getting traces for session_id={session_id}\")\n", "import json\n", + "\n", "from rich.pretty import pprint\n", "\n", "agent_logs = []\n", "\n", "for span in client.telemetry.query_spans(\n", " attribute_filters=[\n", - " {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n", + " {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n", " ],\n", - " attributes_to_return=[\"input\", \"output\"]\n", - " ):\n", - " if span.attributes[\"output\"] != \"no shields\":\n", - " agent_logs.append(span.attributes)\n", + " attributes_to_return=[\"input\", \"output\"],\n", + "):\n", + " if span.attributes[\"output\"] != \"no shields\":\n", + " agent_logs.append(span.attributes)\n", "\n", - "pprint(agent_logs)" + "pprint(agent_logs)\n" ] }, { @@ -2389,23 +2424,25 @@ "eval_rows = []\n", "\n", "for log in agent_logs:\n", - " last_msg = log['input'][-1]\n", - " if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n", - " eval_rows.append(\n", - " {\n", - " \"input_query\": last_msg,\n", - " \"generated_answer\": log[\"output\"],\n", - " # check if generated_answer uses tools brave_search\n", - " \"expected_answer\": \"brave_search\",\n", - " },\n", - " )\n", + " last_msg = log[\"input\"][-1]\n", + " if '\"role\":\"user\"' in last_msg:\n", + " eval_rows.append(\n", + " {\n", + " \"input_query\": last_msg,\n", + " \"generated_answer\": log[\"output\"],\n", + " # check if generated_answer uses tools brave_search\n", + " \"expected_answer\": \"brave_search\",\n", + " },\n", + " )\n", "\n", "pprint(eval_rows)\n", "scoring_params = {\n", " \"basic::subset_of\": None,\n", "}\n", - "scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n", - "pprint(scoring_response)" + "scoring_response = client.scoring.score(\n", + " input_rows=eval_rows, scoring_functions=scoring_params\n", + ")\n", + "pprint(scoring_response)\n" ] }, { @@ -2506,7 +2543,9 @@ "EXPECTED_RESPONSE: {expected_answer}\n", "\"\"\"\n", "\n", - "input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n", + "input_query = (\n", + " \"What are the top 5 topics that were explained? 
Only list succinct bullet points.\"\n", + ")\n", "generated_answer = \"\"\"\n", "Here are the top 5 topics that were explained in the documentation for Torchtune:\n", "\n", @@ -2537,7 +2576,7 @@ "}\n", "\n", "response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n", - "pprint(response)" + "pprint(response)\n" ] }, { diff --git a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb index 4810425d2..83891b7ac 100644 --- a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb +++ b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb @@ -618,11 +618,13 @@ ], "source": [ "import os\n", + "\n", "from google.colab import userdata\n", "\n", - "os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n", + "os.environ[\"TOGETHER_API_KEY\"] = userdata.get(\"TOGETHER_API_KEY\")\n", "\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "\n", "client = LlamaStackAsLibraryClient(\"together\")\n", "_ = client.initialize()\n", "\n", @@ -631,7 +633,7 @@ " model_id=\"meta-llama/Llama-3.1-405B-Instruct\",\n", " provider_model_id=\"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n", " provider_id=\"together\",\n", - ")" + ")\n" ] }, { @@ -668,7 +670,7 @@ "source": [ "name = \"llamastack/mmmu\"\n", "subset = \"Agriculture\"\n", - "split = \"dev\"" + "split = \"dev\"\n" ] }, { @@ -914,9 +916,10 @@ ], "source": [ "import datasets\n", + "\n", "ds = datasets.load_dataset(path=name, name=subset, split=split)\n", "ds = ds.select_columns([\"chat_completion_input\", \"input_query\", \"expected_answer\"])\n", - "eval_rows = ds.to_pandas().to_dict(orient=\"records\")" + "eval_rows = ds.to_pandas().to_dict(orient=\"records\")\n" ] }, { @@ -1014,8 +1017,8 @@ } ], "source": [ - "from tqdm import tqdm\n", "from rich.pretty import pprint\n", + "from tqdm import tqdm\n", "\n", "SYSTEM_PROMPT_TEMPLATE = \"\"\"\n", "You are an expert in {subject} whose job is to answer questions from the user using images.\n", @@ -1039,7 +1042,7 @@ "client.eval_tasks.register(\n", " eval_task_id=\"meta-reference::mmmu\",\n", " dataset_id=f\"mmmu-{subset}-{split}\",\n", - " scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"]\n", + " scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n", ")\n", "\n", "response = client.eval.evaluate_rows(\n", @@ -1052,16 +1055,17 @@ " \"type\": \"model\",\n", " \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n", " \"sampling_params\": {\n", - " \"temperature\": 0.0,\n", + " \"strategy\": {\n", + " \"type\": \"greedy\",\n", + " },\n", " \"max_tokens\": 4096,\n", - " \"top_p\": 0.9,\n", " \"repeat_penalty\": 1.0,\n", " },\n", - " \"system_message\": system_message\n", - " }\n", - " }\n", + " \"system_message\": system_message,\n", + " },\n", + " },\n", ")\n", - "pprint(response)" + "pprint(response)\n" ] }, { @@ -1098,8 +1102,8 @@ " \"input_query\": {\"type\": \"string\"},\n", " \"expected_answer\": {\"type\": \"string\"},\n", " \"chat_completion_input\": {\"type\": \"chat_completion_input\"},\n", - " }\n", - ")" + " },\n", + ")\n" ] }, { @@ -1113,7 +1117,7 @@ "eval_rows = client.datasetio.get_rows_paginated(\n", " dataset_id=simpleqa_dataset_id,\n", " rows_in_page=5,\n", - ")" + ")\n" ] }, { @@ -1209,7 +1213,7 @@ "client.eval_tasks.register(\n", " eval_task_id=\"meta-reference::simpleqa\",\n", " dataset_id=simpleqa_dataset_id,\n", - " scoring_functions=[\"llm-as-judge::405b-simpleqa\"]\n", + " scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n", ")\n", 
"\n", "response = client.eval.evaluate_rows(\n", @@ -1222,15 +1226,16 @@ " \"type\": \"model\",\n", " \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n", " \"sampling_params\": {\n", - " \"temperature\": 0.0,\n", + " \"strategy\": {\n", + " \"type\": \"greedy\",\n", + " },\n", " \"max_tokens\": 4096,\n", - " \"top_p\": 0.9,\n", " \"repeat_penalty\": 1.0,\n", " },\n", - " }\n", - " }\n", + " },\n", + " },\n", ")\n", - "pprint(response)" + "pprint(response)\n" ] }, { @@ -1347,23 +1352,19 @@ "agent_config = {\n", " \"model\": \"meta-llama/Llama-3.1-405B-Instruct\",\n", " \"instructions\": \"You are a helpful assistant\",\n", - " \"sampling_params\": {\n", - " \"strategy\": \"greedy\",\n", - " \"temperature\": 0.0,\n", - " \"top_p\": 0.95,\n", - " },\n", + " \"sampling_params\": {\"strategy\": {\"type\": \"greedy\"}},\n", " \"tools\": [\n", " {\n", " \"type\": \"brave_search\",\n", " \"engine\": \"tavily\",\n", - " \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n", + " \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\"),\n", " }\n", " ],\n", " \"tool_choice\": \"auto\",\n", " \"tool_prompt_format\": \"json\",\n", " \"input_shields\": [],\n", " \"output_shields\": [],\n", - " \"enable_session_persistence\": False\n", + " \"enable_session_persistence\": False,\n", "}\n", "\n", "response = client.eval.evaluate_rows(\n", @@ -1375,10 +1376,10 @@ " \"eval_candidate\": {\n", " \"type\": \"agent\",\n", " \"config\": agent_config,\n", - " }\n", - " }\n", + " },\n", + " },\n", ")\n", - "pprint(response)" + "pprint(response)\n" ] } ], diff --git a/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb b/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb index 7e6284628..472e800a6 100644 --- a/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb +++ b/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb @@ -1336,6 +1336,7 @@ ], "source": [ "from rich.pretty import pprint\n", + "\n", "print(\"Available models:\")\n", "for m in client.models.list():\n", " print(f\"{m.identifier} (provider's alias: {m.provider_resource_id}) \")\n", @@ -1344,7 +1345,7 @@ "print(\"Available shields (safety models):\")\n", "for s in client.shields.list():\n", " print(s.identifier)\n", - "print(\"----\")" + "print(\"----\")\n" ] }, { @@ -1389,7 +1390,7 @@ "source": [ "model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n", "\n", - "model_id" + "model_id\n" ] }, { @@ -1432,11 +1433,11 @@ " model_id=model_id,\n", " messages=[\n", " {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n", - " {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"}\n", + " {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n", " ],\n", ")\n", "\n", - "print(response.completion_message.content)" + "print(response.completion_message.content)\n" ] }, { @@ -1489,12 +1490,13 @@ "source": [ "from termcolor import cprint\n", "\n", + "\n", "def chat_loop():\n", " conversation_history = []\n", " while True:\n", - " user_input = input('User> ')\n", - " if user_input.lower() in ['exit', 'quit', 'bye']:\n", - " cprint('Ending conversation. Goodbye!', 'yellow')\n", + " user_input = input(\"User> \")\n", + " if user_input.lower() in [\"exit\", \"quit\", \"bye\"]:\n", + " cprint(\"Ending conversation. 
Goodbye!\", \"yellow\")\n", " break\n", "\n", " user_message = {\"role\": \"user\", \"content\": user_input}\n", @@ -1504,15 +1506,16 @@ " messages=conversation_history,\n", " model_id=model_id,\n", " )\n", - " cprint(f'> Response: {response.completion_message.content}', 'cyan')\n", + " cprint(f\"> Response: {response.completion_message.content}\", \"cyan\")\n", "\n", " assistant_message = {\n", - " \"role\": \"assistant\", # was user\n", + " \"role\": \"assistant\", # was user\n", " \"content\": response.completion_message.content,\n", " \"stop_reason\": response.completion_message.stop_reason,\n", " }\n", " conversation_history.append(assistant_message)\n", "\n", + "\n", "chat_loop()\n" ] }, @@ -1568,21 +1571,18 @@ "source": [ "from llama_stack_client.lib.inference.event_logger import EventLogger\n", "\n", - "message = {\n", - " \"role\": \"user\",\n", - " \"content\": 'Write me a sonnet about llama'\n", - "}\n", - "print(f'User> {message[\"content\"]}', 'green')\n", + "message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n", + "print(f'User> {message[\"content\"]}', \"green\")\n", "\n", "response = client.inference.chat_completion(\n", " messages=[message],\n", " model_id=model_id,\n", - " stream=True, # <-----------\n", + " stream=True, # <-----------\n", ")\n", "\n", "# Print the tokens while they are received\n", "for log in EventLogger().log(response):\n", - " log.print()" + " log.print()\n" ] }, { @@ -1648,17 +1648,22 @@ "source": [ "from pydantic import BaseModel\n", "\n", + "\n", "class Output(BaseModel):\n", " name: str\n", " year_born: str\n", " year_retired: str\n", "\n", + "\n", "user_input = \"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003. Extract this information into JSON for me. 
\"\n", "response = client.inference.completion(\n", " model_id=model_id,\n", " content=user_input,\n", " stream=False,\n", " sampling_params={\n", + " \"strategy\": {\n", + " \"type\": \"greedy\",\n", + " },\n", " \"max_tokens\": 50,\n", " },\n", " response_format={\n", @@ -1667,7 +1672,7 @@ " },\n", ")\n", "\n", - "pprint(response)" + "pprint(response)\n" ] }, { @@ -1823,7 +1828,7 @@ " shield_id=available_shields[0],\n", " params={},\n", " )\n", - " pprint(response)" + " pprint(response)\n" ] }, { @@ -2025,7 +2030,7 @@ "\n", "session_id = agent.create_session(\"test-session\")\n", "for prompt in user_prompts:\n", - " cprint(f'User> {prompt}', 'green')\n", + " cprint(f\"User> {prompt}\", \"green\")\n", " response = agent.create_turn(\n", " messages=[\n", " {\n", @@ -2451,8 +2456,8 @@ } ], "source": [ - "import pandas as pd\n", "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", "\n", "# Load data\n", "df = pd.read_csv(\"/tmp/tmpvzjigv7g/n2OzlTWhinflation.csv\")\n", @@ -2536,10 +2541,10 @@ } ], "source": [ + "from google.colab import userdata\n", "from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n", - "from google.colab import userdata\n", "\n", "agent_config = AgentConfig(\n", " model=\"meta-llama/Llama-3.1-405B-Instruct-FP8\",\n", @@ -2570,7 +2575,7 @@ " )\n", "\n", " for log in EventLogger().log(response):\n", - " log.print()" + " log.print()\n" ] }, { @@ -2790,20 +2795,21 @@ "source": [ "print(f\"Getting traces for session_id={session_id}\")\n", "import json\n", + "\n", "from rich.pretty import pprint\n", "\n", "agent_logs = []\n", "\n", "for span in client.telemetry.query_spans(\n", " attribute_filters=[\n", - " {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n", + " {\"key\": \"session_id\", \"op\": \"eq\", \"value\": session_id},\n", " ],\n", - " attributes_to_return=[\"input\", \"output\"]\n", - " ):\n", - " if span.attributes[\"output\"] != \"no shields\":\n", - " agent_logs.append(span.attributes)\n", + " attributes_to_return=[\"input\", \"output\"],\n", + "):\n", + " if span.attributes[\"output\"] != \"no shields\":\n", + " agent_logs.append(span.attributes)\n", "\n", - "pprint(agent_logs)" + "pprint(agent_logs)\n" ] }, { @@ -2914,23 +2920,25 @@ "eval_rows = []\n", "\n", "for log in agent_logs:\n", - " last_msg = log['input'][-1]\n", - " if \"\\\"role\\\":\\\"user\\\"\" in last_msg:\n", - " eval_rows.append(\n", - " {\n", - " \"input_query\": last_msg,\n", - " \"generated_answer\": log[\"output\"],\n", - " # check if generated_answer uses tools brave_search\n", - " \"expected_answer\": \"brave_search\",\n", - " },\n", - " )\n", + " last_msg = log[\"input\"][-1]\n", + " if '\"role\":\"user\"' in last_msg:\n", + " eval_rows.append(\n", + " {\n", + " \"input_query\": last_msg,\n", + " \"generated_answer\": log[\"output\"],\n", + " # check if generated_answer uses tools brave_search\n", + " \"expected_answer\": \"brave_search\",\n", + " },\n", + " )\n", "\n", "pprint(eval_rows)\n", "scoring_params = {\n", " \"basic::subset_of\": None,\n", "}\n", - "scoring_response = client.scoring.score(input_rows=eval_rows, scoring_functions=scoring_params)\n", - "pprint(scoring_response)" + "scoring_response = client.scoring.score(\n", + " input_rows=eval_rows, scoring_functions=scoring_params\n", + ")\n", + "pprint(scoring_response)\n" ] }, { @@ -3031,7 +3039,9 @@ "EXPECTED_RESPONSE: {expected_answer}\n", "\"\"\"\n", 
"\n", - "input_query = \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n", + "input_query = (\n", + " \"What are the top 5 topics that were explained? Only list succinct bullet points.\"\n", + ")\n", "generated_answer = \"\"\"\n", "Here are the top 5 topics that were explained in the documentation for Torchtune:\n", "\n", @@ -3062,7 +3072,7 @@ "}\n", "\n", "response = client.scoring.score(input_rows=rows, scoring_functions=scoring_params)\n", - "pprint(response)" + "pprint(response)\n" ] } ], diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 5ed8701a4..ad210a502 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3514,6 +3514,20 @@ "tool_calls" ] }, + "GreedySamplingStrategy": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "greedy", + "default": "greedy" + } + }, + "additionalProperties": false, + "required": [ + "type" + ] + }, "ImageContentItem": { "type": "object", "properties": { @@ -3581,20 +3595,17 @@ "type": "object", "properties": { "strategy": { - "$ref": "#/components/schemas/SamplingStrategy", - "default": "greedy" - }, - "temperature": { - "type": "number", - "default": 0.0 - }, - "top_p": { - "type": "number", - "default": 0.95 - }, - "top_k": { - "type": "integer", - "default": 0 + "oneOf": [ + { + "$ref": "#/components/schemas/GreedySamplingStrategy" + }, + { + "$ref": "#/components/schemas/TopPSamplingStrategy" + }, + { + "$ref": "#/components/schemas/TopKSamplingStrategy" + } + ] }, "max_tokens": { "type": "integer", @@ -3610,14 +3621,6 @@ "strategy" ] }, - "SamplingStrategy": { - "type": "string", - "enum": [ - "greedy", - "top_p", - "top_k" - ] - }, "StopReason": { "type": "string", "enum": [ @@ -3871,6 +3874,45 @@ "content" ] }, + "TopKSamplingStrategy": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "top_k", + "default": "top_k" + }, + "top_k": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "type", + "top_k" + ] + }, + "TopPSamplingStrategy": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "top_p", + "default": "top_p" + }, + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number", + "default": 0.95 + } + }, + "additionalProperties": false, + "required": [ + "type" + ] + }, "URL": { "type": "object", "properties": { @@ -8887,6 +8929,10 @@ "name": "GraphMemoryBankParams", "description": "" }, + { + "name": "GreedySamplingStrategy", + "description": "" + }, { "name": "HealthInfo", "description": "" @@ -9136,10 +9182,6 @@ "name": "SamplingParams", "description": "" }, - { - "name": "SamplingStrategy", - "description": "" - }, { "name": "SaveSpansToDatasetRequest", "description": "" @@ -9317,6 +9359,14 @@ { "name": "ToolRuntime" }, + { + "name": "TopKSamplingStrategy", + "description": "" + }, + { + "name": "TopPSamplingStrategy", + "description": "" + }, { "name": "Trace", "description": "" @@ -9456,6 +9506,7 @@ "GetSpanTreeRequest", "GraphMemoryBank", "GraphMemoryBankParams", + "GreedySamplingStrategy", "HealthInfo", "ImageContentItem", "InferenceStep", @@ -9513,7 +9564,6 @@ "RunShieldResponse", "SafetyViolation", "SamplingParams", - "SamplingStrategy", "SaveSpansToDatasetRequest", "ScoreBatchRequest", "ScoreBatchResponse", @@ -9553,6 +9603,8 @@ "ToolPromptFormat", "ToolResponse", "ToolResponseMessage", + "TopKSamplingStrategy", + "TopPSamplingStrategy", "Trace", "TrainingConfig", "Turn", 
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 2a573959f..8c885b7e5 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -937,6 +937,16 @@ components: required: - memory_bank_type type: object + GreedySamplingStrategy: + additionalProperties: false + properties: + type: + const: greedy + default: greedy + type: string + required: + - type + type: object HealthInfo: additionalProperties: false properties: @@ -2064,26 +2074,13 @@ components: default: 1.0 type: number strategy: - $ref: '#/components/schemas/SamplingStrategy' - default: greedy - temperature: - default: 0.0 - type: number - top_k: - default: 0 - type: integer - top_p: - default: 0.95 - type: number + oneOf: + - $ref: '#/components/schemas/GreedySamplingStrategy' + - $ref: '#/components/schemas/TopPSamplingStrategy' + - $ref: '#/components/schemas/TopKSamplingStrategy' required: - strategy type: object - SamplingStrategy: - enum: - - greedy - - top_p - - top_k - type: string SaveSpansToDatasetRequest: additionalProperties: false properties: @@ -2931,6 +2928,34 @@ components: - tool_name - content type: object + TopKSamplingStrategy: + additionalProperties: false + properties: + top_k: + type: integer + type: + const: top_k + default: top_k + type: string + required: + - type + - top_k + type: object + TopPSamplingStrategy: + additionalProperties: false + properties: + temperature: + type: number + top_p: + default: 0.95 + type: number + type: + const: top_p + default: top_p + type: string + required: + - type + type: object Trace: additionalProperties: false properties: @@ -5587,6 +5612,9 @@ tags: - description: name: GraphMemoryBankParams +- description: + name: GreedySamplingStrategy - description: name: HealthInfo - description: name: SamplingParams -- description: - name: SamplingStrategy - description: name: SaveSpansToDatasetRequest @@ -5874,6 +5899,12 @@ tags: /> name: ToolResponseMessage - name: ToolRuntime +- description: + name: TopKSamplingStrategy +- description: + name: TopPSamplingStrategy - description: name: Trace - description: @@ -5990,6 +6021,7 @@ x-tagGroups: - GetSpanTreeRequest - GraphMemoryBank - GraphMemoryBankParams + - GreedySamplingStrategy - HealthInfo - ImageContentItem - InferenceStep @@ -6047,7 +6079,6 @@ x-tagGroups: - RunShieldResponse - SafetyViolation - SamplingParams - - SamplingStrategy - SaveSpansToDatasetRequest - ScoreBatchRequest - ScoreBatchResponse @@ -6087,6 +6118,8 @@ x-tagGroups: - ToolPromptFormat - ToolResponse - ToolResponseMessage + - TopKSamplingStrategy + - TopPSamplingStrategy - Trace - TrainingConfig - Turn diff --git a/docs/source/benchmark_evaluations/index.md b/docs/source/benchmark_evaluations/index.md index 240555936..56852c89c 100644 --- a/docs/source/benchmark_evaluations/index.md +++ b/docs/source/benchmark_evaluations/index.md @@ -56,9 +56,10 @@ response = client.eval.evaluate_rows( "type": "model", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "sampling_params": { - "temperature": 0.0, + "strategy": { + "type": "greedy", + }, "max_tokens": 4096, - "top_p": 0.9, "repeat_penalty": 1.0, }, "system_message": system_message @@ -113,9 +114,10 @@ response = client.eval.evaluate_rows( "type": "model", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "sampling_params": { - "temperature": 0.0, + "strategy": { + "type": "greedy", + }, "max_tokens": 4096, - "top_p": 0.9, "repeat_penalty": 1.0, }, } @@ -134,9 +136,9 @@ agent_config = { "model": 
"meta-llama/Llama-3.1-405B-Instruct", "instructions": "You are a helpful assistant", "sampling_params": { - "strategy": "greedy", - "temperature": 0.0, - "top_p": 0.95, + "strategy": { + "type": "greedy", + }, }, "tools": [ { diff --git a/docs/source/building_applications/index.md b/docs/source/building_applications/index.md index acc19b515..61b7038cd 100644 --- a/docs/source/building_applications/index.md +++ b/docs/source/building_applications/index.md @@ -189,7 +189,11 @@ agent_config = AgentConfig( # Control the inference loop max_infer_iters=5, sampling_params={ - "temperature": 0.7, + "strategy": { + "type": "top_p", + "temperature": 0.7, + "top_p": 0.95 + }, "max_tokens": 2048 } ) diff --git a/docs/source/references/evals_reference/index.md b/docs/source/references/evals_reference/index.md index f93b56e64..c01fd69d8 100644 --- a/docs/source/references/evals_reference/index.md +++ b/docs/source/references/evals_reference/index.md @@ -92,9 +92,10 @@ response = client.eval.evaluate_rows( "type": "model", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "sampling_params": { - "temperature": 0.0, + "strategy": { + "type": "greedy", + }, "max_tokens": 4096, - "top_p": 0.9, "repeat_penalty": 1.0, }, "system_message": system_message @@ -149,9 +150,10 @@ response = client.eval.evaluate_rows( "type": "model", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "sampling_params": { - "temperature": 0.0, + "strategy": { + "type": "greedy", + }, "max_tokens": 4096, - "top_p": 0.9, "repeat_penalty": 1.0, }, } @@ -170,9 +172,9 @@ agent_config = { "model": "meta-llama/Llama-3.1-405B-Instruct", "instructions": "You are a helpful assistant", "sampling_params": { - "strategy": "greedy", - "temperature": 0.0, - "top_p": 0.95, + "strategy": { + "type": "greedy", + }, }, "tools": [ { @@ -318,10 +320,9 @@ The `EvalTaskConfig` are user specified config to define: "type": "model", "model": "Llama3.2-3B-Instruct", "sampling_params": { - "strategy": "greedy", - "temperature": 0, - "top_p": 0.95, - "top_k": 0, + "strategy": { + "type": "greedy", + }, "max_tokens": 0, "repetition_penalty": 1.0 } @@ -337,10 +338,9 @@ The `EvalTaskConfig` are user specified config to define: "type": "model", "model": "Llama3.1-405B-Instruct", "sampling_params": { - "strategy": "greedy", - "temperature": 0, - "top_p": 0.95, - "top_k": 0, + "strategy": { + "type": "greedy", + }, "max_tokens": 0, "repetition_penalty": 1.0 } diff --git a/docs/source/references/llama_cli_reference/index.md b/docs/source/references/llama_cli_reference/index.md index a0314644a..f7ac5fe36 100644 --- a/docs/source/references/llama_cli_reference/index.md +++ b/docs/source/references/llama_cli_reference/index.md @@ -214,7 +214,6 @@ llama model describe -m Llama3.2-3B-Instruct | | } | +-----------------------------+----------------------------------+ | Recommended sampling params | { | -| | "strategy": "top_p", | | | "temperature": 1.0, | | | "top_p": 0.9, | | | "top_k": 0 | diff --git a/docs/source/references/llama_stack_client_cli_reference.md b/docs/source/references/llama_stack_client_cli_reference.md index b35aa189d..c3abccfd9 100644 --- a/docs/source/references/llama_stack_client_cli_reference.md +++ b/docs/source/references/llama_stack_client_cli_reference.md @@ -200,10 +200,9 @@ Example eval_task_config.json: "type": "model", "model": "Llama3.1-405B-Instruct", "sampling_params": { - "strategy": "greedy", - "temperature": 0, - "top_p": 0.95, - "top_k": 0, + "strategy": { + "type": "greedy" + }, "max_tokens": 0, "repetition_penalty": 1.0 } diff 
--git a/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb b/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb index 4f0d2e887..4c278493b 100644 --- a/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb +++ b/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb @@ -26,27 +26,28 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import requests\n", - "import json\n", "import asyncio\n", - "import nest_asyncio\n", + "import json\n", + "import os\n", "from typing import Dict, List\n", + "\n", + "import nest_asyncio\n", + "import requests\n", "from dotenv import load_dotenv\n", "from llama_stack_client import LlamaStackClient\n", - "from llama_stack_client.lib.agents.custom_tool import CustomTool\n", - "from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n", - "from llama_stack_client.types import CompletionMessage\n", "from llama_stack_client.lib.agents.agent import Agent\n", + "from llama_stack_client.lib.agents.custom_tool import CustomTool\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n", + "from llama_stack_client.types import CompletionMessage\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n", + "from llama_stack_client.types.shared.tool_response_message import ToolResponseMessage\n", "\n", "# Allow asyncio to run in Jupyter Notebook\n", "nest_asyncio.apply()\n", "\n", - "HOST='localhost'\n", - "PORT=5001\n", - "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'" + "HOST = \"localhost\"\n", + "PORT = 5001\n", + "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n" ] }, { @@ -69,7 +70,7 @@ "outputs": [], "source": [ "load_dotenv()\n", - "BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']" + "BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n" ] }, { @@ -118,7 +119,7 @@ " cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}\n", " clean_response.append(cleaned)\n", "\n", - " return {\"query\": query, \"top_k\": clean_response}" + " return {\"query\": query, \"top_k\": clean_response}\n" ] }, { @@ -157,25 +158,29 @@ " for message in messages:\n", " if isinstance(message, CompletionMessage) and message.tool_calls:\n", " for tool_call in message.tool_calls:\n", - " if 'query' in tool_call.arguments:\n", - " query = tool_call.arguments['query']\n", + " if \"query\" in tool_call.arguments:\n", + " query = tool_call.arguments[\"query\"]\n", " call_id = tool_call.call_id\n", "\n", " if query:\n", " search_result = await self.run_impl(query)\n", - " return [ToolResponseMessage(\n", - " call_id=call_id,\n", - " role=\"ipython\",\n", - " content=self._format_response_for_agent(search_result),\n", - " tool_name=\"brave_search\"\n", - " )]\n", + " return [\n", + " ToolResponseMessage(\n", + " call_id=call_id,\n", + " role=\"ipython\",\n", + " content=self._format_response_for_agent(search_result),\n", + " tool_name=\"brave_search\",\n", + " )\n", + " ]\n", "\n", - " return [ToolResponseMessage(\n", - " call_id=\"no_call_id\",\n", - " role=\"ipython\",\n", - " content=\"No query provided.\",\n", - " tool_name=\"brave_search\"\n", - " )]\n", + " return [\n", + " ToolResponseMessage(\n", + " call_id=\"no_call_id\",\n", + " role=\"ipython\",\n", + " content=\"No query provided.\",\n", + " tool_name=\"brave_search\",\n", + " )\n", + " ]\n", "\n", " def _format_response_for_agent(self, search_result):\n", " parsed_result = json.loads(search_result)\n", @@ -186,7 +191,7 @@ " f\" URL: {result.get('url', 'No URL')}\\n\"\n", " f\" Description: {result.get('description', 'No 
Description')}\\n\\n\"\n", " )\n", - " return formatted_result" + " return formatted_result\n" ] }, { @@ -209,7 +214,7 @@ "async def execute_search(query: str):\n", " web_search_tool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n", " result = await web_search_tool.run_impl(query)\n", - " print(\"Search Results:\", result)" + " print(\"Search Results:\", result)\n" ] }, { @@ -236,7 +241,7 @@ ], "source": [ "query = \"Latest developments in quantum computing\"\n", - "asyncio.run(execute_search(query))" + "asyncio.run(execute_search(query))\n" ] }, { @@ -288,19 +293,17 @@ "\n", " # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n", " webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n", - " \n", + "\n", " # Define the agent configuration, including the model and tool setup\n", " agent_config = AgentConfig(\n", " model=MODEL_NAME,\n", " instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n", " sampling_params={\n", - " \"strategy\": \"greedy\",\n", - " \"temperature\": 1.0,\n", - " \"top_p\": 0.9,\n", + " \"strategy\": {\n", + " \"type\": \"greedy\",\n", + " },\n", " },\n", - " tools=[\n", - " webSearchTool.get_tool_definition()\n", - " ],\n", + " tools=[webSearchTool.get_tool_definition()],\n", " tool_choice=\"auto\",\n", " tool_prompt_format=\"python_list\",\n", " input_shields=input_shields,\n", @@ -329,8 +332,9 @@ " async for log in EventLogger().log(response):\n", " log.print()\n", "\n", + "\n", "# Run the function asynchronously in a Jupyter Notebook cell\n", - "await run_main(disable_safety=True)" + "await run_main(disable_safety=True)\n" ] } ], diff --git a/docs/zero_to_hero_guide/07_Agents101.ipynb b/docs/zero_to_hero_guide/07_Agents101.ipynb index 88b73b4cd..04178f3f6 100644 --- a/docs/zero_to_hero_guide/07_Agents101.ipynb +++ b/docs/zero_to_hero_guide/07_Agents101.ipynb @@ -50,8 +50,8 @@ "outputs": [], "source": [ "HOST = \"localhost\" # Replace with your host\n", - "PORT = 5001 # Replace with your port\n", - "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'" + "PORT = 5001 # Replace with your port\n", + "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n" ] }, { @@ -60,10 +60,12 @@ "metadata": {}, "outputs": [], "source": [ - "from dotenv import load_dotenv\n", "import os\n", + "\n", + "from dotenv import load_dotenv\n", + "\n", "load_dotenv()\n", - "BRAVE_SEARCH_API_KEY = os.environ['BRAVE_SEARCH_API_KEY']" + "BRAVE_SEARCH_API_KEY = os.environ[\"BRAVE_SEARCH_API_KEY\"]\n" ] }, { @@ -104,20 +106,22 @@ ], "source": [ "import os\n", + "\n", "from llama_stack_client import LlamaStackClient\n", "from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n", "\n", + "\n", "async def agent_example():\n", " client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n", " agent_config = AgentConfig(\n", " model=MODEL_NAME,\n", " instructions=\"You are a helpful assistant! 
If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n", " sampling_params={\n", - " \"strategy\": \"greedy\",\n", - " \"temperature\": 1.0,\n", - " \"top_p\": 0.9,\n", + " \"strategy\": {\n", + " \"type\": \"greedy\",\n", + " },\n", " },\n", " tools=[\n", " {\n", @@ -157,7 +161,7 @@ " log.print()\n", "\n", "\n", - "await agent_example()" + "await agent_example()\n" ] }, { diff --git a/docs/zero_to_hero_guide/README.md b/docs/zero_to_hero_guide/README.md index f96ae49ce..c4803a1d6 100644 --- a/docs/zero_to_hero_guide/README.md +++ b/docs/zero_to_hero_guide/README.md @@ -157,7 +157,15 @@ curl http://localhost:$LLAMA_STACK_PORT/alpha/inference/chat-completion {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Write me a 2-sentence poem about the moon"} ], - "sampling_params": {"temperature": 0.7, "seed": 42, "max_tokens": 512} + "sampling_params": { + "strategy": { + "type": "top_p", + "temperatrue": 0.7, + "top_p": 0.95, + }, + "seed": 42, + "max_tokens": 512 + } } EOF ``` diff --git a/docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb b/docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb index b21f3d64c..68e781018 100644 --- a/docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb +++ b/docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb @@ -83,8 +83,8 @@ }, "outputs": [], "source": [ - "LLAMA_STACK_API_TOGETHER_URL=\"https://llama-stack.together.ai\"\n", - "LLAMA31_8B_INSTRUCT = \"Llama3.1-8B-Instruct\"" + "LLAMA_STACK_API_TOGETHER_URL = \"https://llama-stack.together.ai\"\n", + "LLAMA31_8B_INSTRUCT = \"Llama3.1-8B-Instruct\"\n" ] }, { @@ -107,12 +107,13 @@ " AgentConfigToolSearchToolDefinition,\n", ")\n", "\n", + "\n", "# Helper function to create an agent with tools\n", "async def create_tool_agent(\n", " client: LlamaStackClient,\n", " tools: List[Dict],\n", " instructions: str = \"You are a helpful assistant\",\n", - " model: str = LLAMA31_8B_INSTRUCT\n", + " model: str = LLAMA31_8B_INSTRUCT,\n", ") -> Agent:\n", " \"\"\"Create an agent with specified tools.\"\"\"\n", " print(\"Using the following model: \", model)\n", @@ -120,9 +121,9 @@ " model=model,\n", " instructions=instructions,\n", " sampling_params={\n", - " \"strategy\": \"greedy\",\n", - " \"temperature\": 1.0,\n", - " \"top_p\": 0.9,\n", + " \"strategy\": {\n", + " \"type\": \"greedy\",\n", + " },\n", " },\n", " tools=tools,\n", " tool_choice=\"auto\",\n", @@ -130,7 +131,7 @@ " enable_session_persistence=True,\n", " )\n", "\n", - " return Agent(client, agent_config)" + " return Agent(client, agent_config)\n" ] }, { @@ -172,7 +173,8 @@ ], "source": [ "# comment this if you don't have a BRAVE_SEARCH_API_KEY\n", - "os.environ[\"BRAVE_SEARCH_API_KEY\"] = 'YOUR_BRAVE_SEARCH_API_KEY'\n", + "os.environ[\"BRAVE_SEARCH_API_KEY\"] = \"YOUR_BRAVE_SEARCH_API_KEY\"\n", + "\n", "\n", "async def create_search_agent(client: LlamaStackClient) -> Agent:\n", " \"\"\"Create an agent with Brave Search capability.\"\"\"\n", @@ -186,8 +188,8 @@ "\n", " return await create_tool_agent(\n", " client=client,\n", - " tools=[search_tool], # set this to [] if you don't have a BRAVE_SEARCH_API_KEY\n", - " model = LLAMA31_8B_INSTRUCT,\n", + " tools=[search_tool], # set this to [] if you don't have a BRAVE_SEARCH_API_KEY\n", + " model=LLAMA31_8B_INSTRUCT,\n", " instructions=\"\"\"\n", " You are a research assistant that can search the web.\n", " Always cite your 
sources with URLs when providing information.\n", @@ -198,9 +200,10 @@ "\n", " SOURCES:\n", " - [Source title](URL)\n", - " \"\"\"\n", + " \"\"\",\n", " )\n", "\n", + "\n", "# Example usage\n", "async def search_example():\n", " client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n", @@ -212,7 +215,7 @@ " # Example queries\n", " queries = [\n", " \"What are the latest developments in quantum computing?\",\n", - " #\"Who won the most recent Super Bowl?\",\n", + " # \"Who won the most recent Super Bowl?\",\n", " ]\n", "\n", " for query in queries:\n", @@ -227,8 +230,9 @@ " async for log in EventLogger().log(response):\n", " log.print()\n", "\n", + "\n", "# Run the example (in Jupyter, use asyncio.run())\n", - "await search_example()" + "await search_example()\n" ] }, { @@ -286,12 +290,16 @@ } ], "source": [ - "from typing import TypedDict, Optional, Dict, Any\n", - "from datetime import datetime\n", "import json\n", - "from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam\n", - "from llama_stack_client.types import CompletionMessage,ToolResponseMessage\n", + "from datetime import datetime\n", + "from typing import Any, Dict, Optional, TypedDict\n", + "\n", "from llama_stack_client.lib.agents.custom_tool import CustomTool\n", + "from llama_stack_client.types import CompletionMessage, ToolResponseMessage\n", + "from llama_stack_client.types.tool_param_definition_param import (\n", + " ToolParamDefinitionParam,\n", + ")\n", + "\n", "\n", "class WeatherTool(CustomTool):\n", " \"\"\"Example custom tool for weather information.\"\"\"\n", @@ -305,16 +313,15 @@ " def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:\n", " return {\n", " \"location\": ToolParamDefinitionParam(\n", - " param_type=\"str\",\n", - " description=\"City or location name\",\n", - " required=True\n", + " param_type=\"str\", description=\"City or location name\", required=True\n", " ),\n", " \"date\": ToolParamDefinitionParam(\n", " param_type=\"str\",\n", " description=\"Optional date (YYYY-MM-DD)\",\n", - " required=False\n", - " )\n", + " required=False,\n", + " ),\n", " }\n", + "\n", " async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:\n", " assert len(messages) == 1, \"Expected single message\"\n", "\n", @@ -337,20 +344,14 @@ " )\n", " return [message]\n", "\n", - " async def run_impl(self, location: str, date: Optional[str] = None) -> Dict[str, Any]:\n", + " async def run_impl(\n", + " self, location: str, date: Optional[str] = None\n", + " ) -> Dict[str, Any]:\n", " \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n", " # Mock implementation\n", " if date:\n", - " return {\n", - " \"temperature\": 90.1,\n", - " \"conditions\": \"sunny\",\n", - " \"humidity\": 40.0\n", - " }\n", - " return {\n", - " \"temperature\": 72.5,\n", - " \"conditions\": \"partly cloudy\",\n", - " \"humidity\": 65.0\n", - " }\n", + " return {\"temperature\": 90.1, \"conditions\": \"sunny\", \"humidity\": 40.0}\n", + " return {\"temperature\": 72.5, \"conditions\": \"partly cloudy\", \"humidity\": 65.0}\n", "\n", "\n", "async def create_weather_agent(client: LlamaStackClient) -> Agent:\n", @@ -358,38 +359,33 @@ "\n", " # Create the agent with the tool\n", " weather_tool = WeatherTool()\n", - " \n", + "\n", " agent_config = AgentConfig(\n", " model=LLAMA31_8B_INSTRUCT,\n", - " #model=model_name,\n", + " # model=model_name,\n", " instructions=\"\"\"\n", " You are a weather assistant that can provide weather information.\n", 
" Always specify the location clearly in your responses.\n", " Include both temperature and conditions in your summaries.\n", " \"\"\",\n", " sampling_params={\n", - " \"strategy\": \"greedy\",\n", - " \"temperature\": 1.0,\n", - " \"top_p\": 0.9,\n", + " \"strategy\": {\n", + " \"type\": \"greedy\",\n", + " },\n", " },\n", - " tools=[\n", - " weather_tool.get_tool_definition()\n", - " ],\n", + " tools=[weather_tool.get_tool_definition()],\n", " tool_choice=\"auto\",\n", " tool_prompt_format=\"json\",\n", " input_shields=[],\n", " output_shields=[],\n", - " enable_session_persistence=True\n", + " enable_session_persistence=True,\n", " )\n", "\n", - " agent = Agent(\n", - " client=client,\n", - " agent_config=agent_config,\n", - " custom_tools=[weather_tool]\n", - " )\n", + " agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n", "\n", " return agent\n", "\n", + "\n", "# Example usage\n", "async def weather_example():\n", " client = LlamaStackClient(base_url=LLAMA_STACK_API_TOGETHER_URL)\n", @@ -413,12 +409,14 @@ " async for log in EventLogger().log(response):\n", " log.print()\n", "\n", + "\n", "# For Jupyter notebooks\n", "import nest_asyncio\n", + "\n", "nest_asyncio.apply()\n", "\n", "# Run the example\n", - "await weather_example()" + "await weather_example()\n" ] }, { diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py index 70e72f7be..fc0190ca8 100644 --- a/llama_stack/cli/model/describe.py +++ b/llama_stack/cli/model/describe.py @@ -13,7 +13,6 @@ from termcolor import colored from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.table import print_table -from llama_stack.distribution.utils.serialize import EnumEncoder class ModelDescribe(Subcommand): @@ -72,7 +71,7 @@ class ModelDescribe(Subcommand): rows.append( ( "Recommended sampling params", - json.dumps(sampling_params, cls=EnumEncoder, indent=4), + json.dumps(sampling_params, indent=4), ) ) diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index 2cbc8d63e..46839e2f9 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -58,11 +58,6 @@ def define_eval_candidate_2(): # Sampling Parameters st.markdown("##### Sampling Parameters") - strategy = st.selectbox( - "Strategy", - ["greedy", "top_p", "top_k"], - index=0, - ) temperature = st.slider( "Temperature", min_value=0.0, @@ -95,13 +90,20 @@ def define_eval_candidate_2(): help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 
1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.", ) if candidate_type == "model": + if temperature > 0.0: + strategy = { + "type": "top_p", + "temperature": temperature, + "top_p": top_p, + } + else: + strategy = {"type": "greedy"} + eval_candidate = { "type": "model", "model": selected_model, "sampling_params": { "strategy": strategy, - "temperature": temperature, - "top_p": top_p, "max_tokens": max_tokens, "repetition_penalty": repetition_penalty, }, diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/distribution/ui/page/playground/chat.py index 0b8073756..5d91ec819 100644 --- a/llama_stack/distribution/ui/page/playground/chat.py +++ b/llama_stack/distribution/ui/page/playground/chat.py @@ -95,6 +95,15 @@ if prompt := st.chat_input("Example: What is Llama Stack?"): message_placeholder = st.empty() full_response = "" + if temperature > 0.0: + strategy = { + "type": "top_p", + "temperature": temperature, + "top_p": top_p, + } + else: + strategy = {"type": "greedy"} + response = llama_stack_api.client.inference.chat_completion( messages=[ {"role": "system", "content": system_prompt}, @@ -103,8 +112,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"): model_id=selected_model, stream=stream, sampling_params={ - "temperature": temperature, - "top_p": top_p, + "strategy": strategy, "max_tokens": max_tokens, "repetition_penalty": repetition_penalty, }, diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index 196c889ba..3a2ba1270 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -118,13 +118,20 @@ def rag_chat_page(): with st.chat_message(message["role"]): st.markdown(message["content"]) + if temperature > 0.0: + strategy = { + "type": "top_p", + "temperature": temperature, + "top_p": top_p, + } + else: + strategy = {"type": "greedy"} + agent_config = AgentConfig( model=selected_model, instructions=system_prompt, sampling_params={ - "strategy": "greedy", - "temperature": temperature, - "top_p": top_p, + "strategy": strategy, }, tools=[ { diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 1807e4ad5..a96409cab 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -23,6 +23,11 @@ from fairscale.nn.model_parallel.initialize import ( initialize_model_parallel, model_parallel_is_initialized, ) +from llama_models.datatypes import ( + GreedySamplingStrategy, + SamplingParams, + TopPSamplingStrategy, +) from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.chat_format import ChatFormat, LLMInput from llama_models.llama3.api.datatypes import Model @@ -363,11 +368,12 @@ class Llama: max_gen_len = self.model.params.max_seq_len - 1 model_input = self.formatter.encode_content(request.content) + temperature, top_p = _infer_sampling_params(sampling_params) yield from self.generate( model_input=model_input, max_gen_len=max_gen_len, - temperature=sampling_params.temperature, - top_p=sampling_params.top_p, + temperature=temperature, + top_p=top_p, logprobs=bool(request.logprobs), include_stop_token=True, logits_processor=get_logits_processor( @@ -390,14 +396,15 @@ class Llama: ): max_gen_len = self.model.params.max_seq_len - 1 + temperature, top_p = 
_infer_sampling_params(sampling_params) yield from self.generate( model_input=self.formatter.encode_dialog_prompt( request.messages, request.tool_prompt_format, ), max_gen_len=max_gen_len, - temperature=sampling_params.temperature, - top_p=sampling_params.top_p, + temperature=temperature, + top_p=top_p, logprobs=bool(request.logprobs), include_stop_token=True, logits_processor=get_logits_processor( @@ -492,3 +499,15 @@ def _build_regular_tokens_list( is_word_start_token = len(decoded_after_0) > len(decoded_regular) regular_tokens.append((token_idx, decoded_after_0, is_word_start_token)) return regular_tokens + + +def _infer_sampling_params(sampling_params: SamplingParams): + if isinstance(sampling_params.strategy, GreedySamplingStrategy): + temperature = 0.0 + top_p = 1.0 + elif isinstance(sampling_params.strategy, TopPSamplingStrategy): + temperature = sampling_params.strategy.temperature + top_p = sampling_params.strategy.top_p + else: + raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}") + return temperature, top_p diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 0f1045845..49dd8316e 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -36,6 +36,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.models import Model from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import ( + get_sampling_options, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, process_chat_completion_response, @@ -126,21 +127,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): if sampling_params is None: return VLLMSamplingParams(max_tokens=self.config.max_tokens) - # TODO convert what I saw in my first test ... 
but surely there's more to do here - kwargs = { - "temperature": sampling_params.temperature, - "max_tokens": self.config.max_tokens, - } - if sampling_params.top_k: - kwargs["top_k"] = sampling_params.top_k - if sampling_params.top_p: - kwargs["top_p"] = sampling_params.top_p - if sampling_params.max_tokens: - kwargs["max_tokens"] = sampling_params.max_tokens - if sampling_params.repetition_penalty > 0: - kwargs["repetition_penalty"] = sampling_params.repetition_penalty + options = get_sampling_options(sampling_params) + if "repeat_penalty" in options: + options["repetition_penalty"] = options["repeat_penalty"] + del options["repeat_penalty"] - return VLLMSamplingParams(**kwargs) + return VLLMSamplingParams(**options) async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 59f30024e..10b51e86b 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -34,6 +34,7 @@ from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( + get_sampling_strategy_options, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, process_chat_completion_response, @@ -166,16 +167,13 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) -> Dict: bedrock_model = request.model - inference_config = {} - param_mapping = { - "max_tokens": "max_gen_len", - "temperature": "temperature", - "top_p": "top_p", - } + sampling_params = request.sampling_params + options = get_sampling_strategy_options(sampling_params) - for k, v in param_mapping.items(): - if getattr(request.sampling_params, k): - inference_config[v] = getattr(request.sampling_params, k) + if sampling_params.max_tokens: + options["max_gen_len"] = sampling_params.max_tokens + if sampling_params.repetition_penalty > 0: + options["repetition_penalty"] = sampling_params.repetition_penalty prompt = await chat_completion_request_to_prompt( request, self.get_llama_model(request.model), self.formatter @@ -185,7 +183,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): "body": json.dumps( { "prompt": prompt, - **inference_config, + **options, } ), } diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index b78471787..0b6ce142c 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union from cerebras.cloud.sdk import AsyncCerebras from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.datatypes import TopKSamplingStrategy from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.common.content_types import InterleavedContent @@ -172,7 +173,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): async def _get_params( self, request: Union[ChatCompletionRequest, CompletionRequest] ) -> dict: - if request.sampling_params and request.sampling_params.top_k: + if request.sampling_params and isinstance( + request.sampling_params.strategy, TopKSamplingStrategy + ): raise ValueError("`top_k` not supported by Cerebras") prompt = "" diff --git 
diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index 11f684847..b614c90f4 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -48,6 +48,9 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_strategy_options,
+)
 
 
 def convert_chat_completion_request(
@@ -77,6 +80,7 @@
     if request.tool_prompt_format != ToolPromptFormat.json:
         warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
 
+    sampling_options = get_sampling_strategy_options(request.sampling_params)
     return CompletionCreateParams(
         model=request.model,
         messages=[_convert_message(message) for message in request.messages],
@@ -84,8 +88,8 @@
         frequency_penalty=None,
         stream=request.stream,
         max_tokens=request.sampling_params.max_tokens or None,
-        temperature=request.sampling_params.temperature,
-        top_p=request.sampling_params.top_p,
+        temperature=sampling_options.get("temperature", 1.0),
+        top_p=sampling_options.get("top_p", 1.0),
         tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
         tool_choice=request.tool_choice.value if request.tool_choice else None,
     )
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index 975812844..8db7f9197 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -8,6 +8,11 @@ import json
 import warnings
 from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
 
+from llama_models.datatypes import (
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
     StopReason,
@@ -263,19 +268,20 @@ def convert_chat_completion_request(
     if request.sampling_params.max_tokens:
         payload.update(max_tokens=request.sampling_params.max_tokens)
 
-    if request.sampling_params.strategy == "top_p":
+    strategy = request.sampling_params.strategy
+    if isinstance(strategy, TopPSamplingStrategy):
         nvext.update(top_k=-1)
-        payload.update(top_p=request.sampling_params.top_p)
-    elif request.sampling_params.strategy == "top_k":
-        if (
-            request.sampling_params.top_k != -1
-            and request.sampling_params.top_k < 1
-        ):
+        payload.update(top_p=strategy.top_p)
+        payload.update(temperature=strategy.temperature)
+    elif isinstance(strategy, TopKSamplingStrategy):
+        if strategy.top_k != -1 and strategy.top_k < 1:
             warnings.warn("top_k must be -1 or >= 1")
-        nvext.update(top_k=request.sampling_params.top_k)
-    elif request.sampling_params.strategy == "greedy":
+        nvext.update(top_k=strategy.top_k)
+    elif isinstance(strategy, GreedySamplingStrategy):
         nvext.update(top_k=-1)
-        payload.update(temperature=request.sampling_params.temperature)
+        payload.update(temperature=strategy.temperature)
+    else:
+        raise ValueError(f"Unsupported sampling strategy: {strategy}")
 
     return payload
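
For the Groq conversion above, a hedged sketch of the resulting defaults, mirroring the unit tests further down: a greedy strategy maps to temperature 0.0, and top_p falls back to 1.0 when the strategy carries none.

    from llama_models.datatypes import GreedySamplingStrategy
    from llama_stack.apis.inference import ChatCompletionRequest, UserMessage

    request = ChatCompletionRequest(
        model="Llama-3.2-3B",
        messages=[UserMessage(content="Hello World")],
    )
    request.sampling_params.strategy = GreedySamplingStrategy()

    converted = convert_chat_completion_request(request)
    # converted["temperature"] == 0.0  (greedy -> deterministic)
    # converted["top_p"] == 1.0        (default, since greedy supplies no top_p)
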
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 27fb90572..320096826 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -7,6 +7,7 @@
 import os
 
 import pytest
+from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
 from llama_models.llama3.api.datatypes import BuiltinTool
 
 from llama_stack.apis.agents import (
@@ -22,7 +23,8 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
+
+from llama_stack.apis.inference import CompletionMessage, UserMessage
 from llama_stack.apis.safety import ViolationLevel
 from llama_stack.providers.datatypes import Api
 
@@ -42,7 +44,9 @@ def common_params(inference_model):
         model=inference_model,
         instructions="You are a helpful assistant.",
         enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        sampling_params=SamplingParams(
+            strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95)
+        ),
         input_shields=[],
         output_shields=[],
         toolgroups=[],
diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
index f3f263cb1..0402a772c 100644
--- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py
+++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
@@ -21,6 +21,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
     Function,
 )
 from groq.types.shared.function_definition import FunctionDefinition
+from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
 from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -152,21 +153,30 @@ class TestConvertChatCompletionRequest:
 
         assert converted["max_tokens"] == 100
 
-    def test_includes_temperature(self):
+    def _dummy_chat_completion_request(self):
+        return ChatCompletionRequest(
+            model="Llama-3.2-3B",
+            messages=[UserMessage(content="Hello World")],
+        )
+
+    def test_includes_strategy(self):
         request = self._dummy_chat_completion_request()
-        request.sampling_params.temperature = 0.5
+        request.sampling_params.strategy = TopPSamplingStrategy(
+            temperature=0.5, top_p=0.95
+        )
 
         converted = convert_chat_completion_request(request)
 
         assert converted["temperature"] == 0.5
+        assert converted["top_p"] == 0.95
 
-    def test_includes_top_p(self):
+    def test_includes_greedy_strategy(self):
         request = self._dummy_chat_completion_request()
-        request.sampling_params.top_p = 0.95
+        request.sampling_params.strategy = GreedySamplingStrategy()
 
         converted = convert_chat_completion_request(request)
 
-        assert converted["top_p"] == 0.95
+        assert converted["temperature"] == 0.0
 
     def test_includes_tool_choice(self):
         request = self._dummy_chat_completion_request()
@@ -268,12 +278,6 @@ class TestConvertChatCompletionRequest:
             },
         ]
 
-    def _dummy_chat_completion_request(self):
-        return ChatCompletionRequest(
-            model="Llama-3.2-3B",
-            messages=[UserMessage(content="Hello World")],
-        )
-
 
 class TestConvertNonStreamChatCompletionResponse:
     def test_returns_response(self):
@@ -409,19 +413,19 @@ class TestConvertStreamChatCompletionResponse:
         iter = converted.__aiter__()
         chunk = await iter.__anext__()
         assert chunk.event.event_type == ChatCompletionResponseEventType.start
-        assert chunk.event.delta == "Hello "
+        assert chunk.event.delta.text == "Hello "
 
         chunk = await iter.__anext__()
         assert chunk.event.event_type == ChatCompletionResponseEventType.progress
-        assert chunk.event.delta == "World "
+        assert chunk.event.delta.text == "World "
 
         chunk = await iter.__anext__()
         assert chunk.event.event_type == ChatCompletionResponseEventType.progress
-        assert chunk.event.delta == " !"
+        assert chunk.event.delta.text == " !"
 
         chunk = await iter.__anext__()
         assert chunk.event.event_type == ChatCompletionResponseEventType.complete
-        assert chunk.event.delta == ""
+        assert chunk.event.delta.text == ""
         assert chunk.event.stop_reason == StopReason.end_of_turn
 
         with pytest.raises(StopAsyncIteration):
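
The stream assertions above (and the ToolCall check in the next file) change because chunk deltas are now structured objects rather than bare strings. A minimal consumer sketch under that assumption, using only the attribute names the updated tests read:

    # Hedged sketch: attribute names taken from the updated tests; the concrete
    # delta classes are not shown in this diff.
    async def print_stream(stream):
        async for chunk in stream:
            delta = chunk.event.delta
            if hasattr(delta, "text"):  # text chunk
                print(delta.text, end="")
            elif hasattr(delta, "content"):  # tool-call chunk
                print(f"[tool_call] {delta.content.tool_name}")
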
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 932ae36e6..037e99819 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.models import Model
+
 from .utils import group_chunks
 
 
@@ -476,7 +477,7 @@ class TestInference:
             last = grouped[ChatCompletionResponseEventType.progress][-1]
             # assert last.event.stop_reason == expected_stop_reason
             assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
-            assert last.event.delta.content.type == "tool_call"
+            assert isinstance(last.event.delta.content, ToolCall)
 
             call = last.event.delta.content
             assert call.tool_name == "get_weather"
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 4c46954cf..694212a02 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -8,7 +8,13 @@
 from typing import AsyncGenerator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import SamplingParams, StopReason
+from llama_models.llama3.api.datatypes import (
+    GreedySamplingStrategy,
+    SamplingParams,
+    StopReason,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
@@ -49,12 +55,27 @@ class OpenAICompatCompletionResponse(BaseModel):
     choices: List[OpenAICompatCompletionChoice]
 
 
+def get_sampling_strategy_options(params: SamplingParams) -> dict:
+    options = {}
+    if isinstance(params.strategy, GreedySamplingStrategy):
+        options["temperature"] = 0.0
+    elif isinstance(params.strategy, TopPSamplingStrategy):
+        options["temperature"] = params.strategy.temperature
+        options["top_p"] = params.strategy.top_p
+    elif isinstance(params.strategy, TopKSamplingStrategy):
+        options["top_k"] = params.strategy.top_k
+    else:
+        raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
+
+    return options
+
+
 def get_sampling_options(params: SamplingParams) -> dict:
     options = {}
     if params:
-        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
-            if getattr(params, attr):
-                options[attr] = getattr(params, attr)
+        options.update(get_sampling_strategy_options(params))
+        if params.max_tokens:
+            options["max_tokens"] = params.max_tokens
 
     if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
         options["repeat_penalty"] = params.repetition_penalty
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 0c16b6225..19a4064a0 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -97,9 +97,11 @@ def agent_config(llama_stack_client):
         model=model_id,
         instructions="You are a helpful assistant",
         sampling_params={
-            "strategy": "greedy",
-            "temperature": 1.0,
-            "top_p": 0.9,
+            "strategy": {
+                "type": "greedy",
+                "temperature": 1.0,
+                "top_p": 0.9,
+            },
         },
         toolgroups=[],
         tool_choice="auto",