Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
Jorge Piedrahita Ortiz 2025-03-24 08:02:40 -05:00 committed by GitHub
commit 8783dd8162
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
190 changed files with 8649 additions and 3304 deletions

View file

@ -14,6 +14,10 @@ on:
- 'requirements.txt'
- '.github/workflows/integration-tests.yml' # This workflow
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
test-matrix:
runs-on: ubuntu-latest
@ -52,6 +56,7 @@ jobs:
# always test against the latest version of the client
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
uv pip install -e .
llama stack build --template ollama --image-type venv
- name: Wait for Ollama to start
run: |

View file

@ -5,6 +5,10 @@ on:
push:
branches: [main]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
pre-commit:
runs-on: ubuntu-latest

View file

@ -18,6 +18,10 @@ on:
- 'llama_stack/distribution/*.sh'
- '.github/workflows/providers-build.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
generate-matrix:
runs-on: ubuntu-latest

View file

@ -8,6 +8,10 @@ on:
- reopened
- synchronize
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read

View file

@ -15,6 +15,10 @@ on:
- '.github/workflows/unit-tests.yml' # This workflow
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
unit-tests:
runs-on: ubuntu-latest

View file

@ -22,6 +22,10 @@ on:
- 'pyproject.toml'
- '.github/workflows/update-readthedocs.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
update-readthedocs:
runs-on: ubuntu-latest

View file

@ -89,10 +89,11 @@ repos:
name: API Spec Codegen
additional_dependencies:
- uv==0.6.2
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null 2>&1'
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
language: python
pass_filenames: false
require_serial: true
files: ^llama_stack/apis/|^docs/openapi_generator/
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

View file

@ -135,9 +135,11 @@ uv sync
## Coding Style
* Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter, same goes for docstrings.
* Prefer comments to clarify surprising behavior and/or relationships between parts of the code rather than explain what the next line of code does.
* Catching exceptions, prefer using a specific exception type rather than a broad catch-all like `Exception`.
* Error messages should be prefixed with "Failed to ..."
* 4 spaces for indentation rather than tabs
* 80 character line length
* ...
## Common Tasks
@ -166,7 +168,7 @@ If you have made changes to a provider's configuration in any form (introducing
If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.
```bash
cd llama-stack/docs
cd docs
uv sync --extra docs
# This rebuilds the documentation pages.

View file

@ -7,10 +7,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -23,6 +25,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -41,10 +44,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
@ -56,6 +61,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -75,10 +81,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -91,6 +99,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -112,11 +121,13 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"nltk",
"numpy",
@ -128,6 +139,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -147,10 +159,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
@ -164,6 +178,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -184,11 +199,13 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -201,6 +218,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -219,10 +237,12 @@
"blobfile",
"chardet",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"nltk",
@ -235,6 +255,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -253,11 +274,13 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -270,6 +293,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -288,11 +312,13 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -305,6 +331,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -325,11 +352,13 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
@ -343,6 +372,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -365,12 +395,14 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fbgemm-gpu",
"fire",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
@ -384,6 +416,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -403,10 +436,12 @@
"aiosqlite",
"blobfile",
"chardet",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
@ -418,6 +453,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -436,10 +472,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -453,6 +491,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -470,9 +509,11 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
@ -486,6 +527,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -505,10 +547,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -521,6 +565,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -540,10 +585,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -556,6 +603,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -608,11 +656,13 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -625,6 +675,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -644,10 +695,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -660,6 +713,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -680,10 +734,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
@ -696,6 +752,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",

View file

@ -51,14 +51,14 @@ services:
- ~/local/llama-stack/:/app/llama-stack-source
- ./run${SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
- "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
environment:
- INFERENCE_MODEL=${INFERENCE_MODEL}
- SAFETY_MODEL=${SAFETY_MODEL:-}
- OLLAMA_URL=http://ollama:11434
entrypoint: >
python -m llama_stack.distribution.server.server /root/my-run.yaml \
--port ${LLAMA_STACK_PORT:-5001}
--port ${LLAMA_STACK_PORT:-8321}
deploy:
restart_policy:
condition: on-failure

Binary file not shown.

View file

@ -84,9 +84,9 @@ services:
- SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
- SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
- "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
# Hack: wait for vLLM server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 5001"
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 8321"
deploy:
restart_policy:
condition: on-failure

View file

@ -83,7 +83,7 @@ services:
- ~/.llama:/root/.llama
- ./run${TGI_SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
- "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
# Hack: wait for TGI server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
restart_policy:

View file

@ -2183,7 +2183,7 @@
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/JobStatus"
"$ref": "#/components/schemas/Job"
}
}
}
@ -2650,7 +2650,7 @@
}
},
"tags": [
"Inspect"
"Providers"
],
"description": "",
"parameters": []
@ -6268,6 +6268,7 @@
"type": "string",
"enum": [
"average",
"weighted_average",
"median",
"categorical_count",
"accuracy"
@ -7647,7 +7648,13 @@
"title": "PostTrainingJobArtifactsResponse",
"description": "Artifacts of a finetuning job."
},
"JobStatus": {
"PostTrainingJobStatusResponse": {
"type": "object",
"properties": {
"job_uuid": {
"type": "string"
},
"status": {
"type": "string",
"enum": [
"completed",
@ -7657,15 +7664,6 @@
],
"title": "JobStatus"
},
"PostTrainingJobStatusResponse": {
"type": "object",
"properties": {
"job_uuid": {
"type": "string"
},
"status": {
"$ref": "#/components/schemas/JobStatus"
},
"scheduled_at": {
"type": "string",
"format": "date-time"
@ -8068,9 +8066,6 @@
}
},
"additionalProperties": false,
"required": [
"content"
],
"title": "ToolInvocationResult"
},
"IterrowsResponse": {
@ -8117,6 +8112,30 @@
"title": "IterrowsResponse",
"description": "A paginated list of rows from a dataset."
},
"Job": {
"type": "object",
"properties": {
"job_id": {
"type": "string"
},
"status": {
"type": "string",
"enum": [
"completed",
"in_progress",
"failed",
"scheduled"
],
"title": "JobStatus"
}
},
"additionalProperties": false,
"required": [
"job_id",
"status"
],
"title": "Job"
},
"ListAgentSessionsResponse": {
"type": "object",
"properties": {
@ -9641,19 +9660,6 @@
],
"title": "RunEvalRequest"
},
"Job": {
"type": "object",
"properties": {
"job_id": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"job_id"
],
"title": "Job"
},
"RunShieldRequest": {
"type": "object",
"properties": {
@ -9862,6 +9868,23 @@
],
"title": "ScoreBatchResponse"
},
"AlgorithmConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"LoRA": "#/components/schemas/LoraFinetuningConfig",
"QAT": "#/components/schemas/QATFinetuningConfig"
}
}
},
"LoraFinetuningConfig": {
"type": "object",
"properties": {
@ -9997,14 +10020,7 @@
"type": "string"
},
"algorithm_config": {
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
]
"$ref": "#/components/schemas/AlgorithmConfig"
}
},
"additionalProperties": false,

View file

@ -1491,7 +1491,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/JobStatus'
$ref: '#/components/schemas/Job'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -1814,7 +1814,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inspect
- Providers
description: ''
parameters: []
/v1/inspect/routes:
@ -4389,6 +4389,7 @@ components:
type: string
enum:
- average
- weighted_average
- median
- categorical_count
- accuracy
@ -5276,7 +5277,12 @@ components:
- checkpoints
title: PostTrainingJobArtifactsResponse
description: Artifacts of a finetuning job.
JobStatus:
PostTrainingJobStatusResponse:
type: object
properties:
job_uuid:
type: string
status:
type: string
enum:
- completed
@ -5284,13 +5290,6 @@ components:
- failed
- scheduled
title: JobStatus
PostTrainingJobStatusResponse:
type: object
properties:
job_uuid:
type: string
status:
$ref: '#/components/schemas/JobStatus'
scheduled_at:
type: string
format: date-time
@ -5528,8 +5527,6 @@ components:
- type: array
- type: object
additionalProperties: false
required:
- content
title: ToolInvocationResult
IterrowsResponse:
type: object
@ -5557,6 +5554,24 @@ components:
- data
title: IterrowsResponse
description: A paginated list of rows from a dataset.
Job:
type: object
properties:
job_id:
type: string
status:
type: string
enum:
- completed
- in_progress
- failed
- scheduled
title: JobStatus
additionalProperties: false
required:
- job_id
- status
title: Job
ListAgentSessionsResponse:
type: object
properties:
@ -6551,15 +6566,6 @@ components:
required:
- benchmark_config
title: RunEvalRequest
Job:
type: object
properties:
job_id:
type: string
additionalProperties: false
required:
- job_id
title: Job
RunShieldRequest:
type: object
properties:
@ -6688,6 +6694,15 @@ components:
required:
- results
title: ScoreBatchResponse
AlgorithmConfig:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
discriminator:
propertyName: type
mapping:
LoRA: '#/components/schemas/LoraFinetuningConfig'
QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig:
type: object
properties:
@ -6771,9 +6786,7 @@ components:
checkpoint_dir:
type: string
algorithm_config:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
$ref: '#/components/schemas/AlgorithmConfig'
additionalProperties: false
required:
- job_uuid

View file

@ -4,6 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import time
def pytest_collection_modifyitems(items):
for item in items:
item.name = item.name.replace(' ', '_')
def pytest_runtest_teardown(item):
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
if interval_seconds:
time.sleep(float(interval_seconds))
def pytest_configure(config):
config.option.tbstyle = "short"
config.option.disable_warnings = True

View file

@ -123,6 +123,8 @@
"outputs": [],
"source": [
"# NBVAL_SKIP\n",
"!pip uninstall pandas numpy -y\n",
"!pip install pandas numpy\n",
"# This will build all the dependencies you will need\n",
"!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv"
]
@ -1203,7 +1205,7 @@
}
],
"source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
"from llama_stack_client import InferenceEventLogger\n",
"\n",
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
"print(f'User> {message[\"content\"]}', \"green\")\n",
@ -1215,7 +1217,7 @@
")\n",
"\n",
"# Print the tokens while they are received\n",
"for log in EventLogger().log(response):\n",
"for log in InferenceEventLogger().log(response):\n",
" log.print()\n"
]
},
@ -1632,8 +1634,7 @@
}
],
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client import Agent, AgentEventLogger\n",
"from termcolor import cprint\n",
"\n",
"agent = Agent(\n",
@ -1659,7 +1660,7 @@
" ],\n",
" session_id=session_id,\n",
" )\n",
" for log in EventLogger().log(response):\n",
" for log in AgentEventLogger().log(response):\n",
" log.print()\n"
]
},
@ -1808,14 +1809,12 @@
],
"source": [
"import uuid\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client import Agent, AgentEventLogger, RAGDocument\n",
"from termcolor import cprint\n",
"from llama_stack_client.types import Document\n",
"\n",
"urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n",
"documents = [\n",
" Document(\n",
" RAGDocument(\n",
" document_id=f\"num-{i}\",\n",
" content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
" mime_type=\"text/plain\",\n",
@ -1858,7 +1857,7 @@
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" session_id=session_id,\n",
" )\n",
" for log in EventLogger().log(response):\n",
" for log in AgentEventLogger().log(response):\n",
" log.print()"
]
},
@ -1969,7 +1968,7 @@
}
],
"source": [
"from llama_stack_client.types.agents.turn_create_params import Document\n",
"from llama_stack_client import Document\n",
"\n",
"codex_agent = Agent(\n",
" client, \n",
@ -2013,7 +2012,7 @@
" # for chunk in response:\n",
" # print(chunk)\n",
"\n",
" for log in EventLogger().log(response):\n",
" for log in AgentEventLogger().log(response):\n",
" log.print()\n"
]
},
@ -2891,8 +2890,7 @@
],
"source": [
"# NBVAL_SKIP\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client import Agent, AgentEventLogger\n",
"from termcolor import cprint\n",
"\n",
"agent = Agent(\n",
@ -2918,7 +2916,7 @@
" ],\n",
" session_id=session_id,\n",
" )\n",
" for log in EventLogger().log(response):\n",
" for log in AgentEventLogger().log(response):\n",
" log.print()\n"
]
},
@ -2993,8 +2991,7 @@
}
],
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client import Agent, AgentEventLogger\n",
"\n",
"agent = Agent(\n",
" client, \n",
@ -3021,7 +3018,7 @@
" session_id=session_id,\n",
" )\n",
"\n",
" for log in EventLogger().log(response):\n",
" for log in AgentEventLogger().log(response):\n",
" log.print()\n"
]
},
@ -4355,7 +4352,7 @@
" session_id=session_id,\n",
")\n",
"\n",
"for log in EventLogger().log(response):\n",
"for log in AgentEventLogger().log(response):\n",
" log.print()\n",
" "
]

View file

@ -47,9 +47,8 @@
"metadata": {},
"outputs": [],
"source": [
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client import LlamaStackClient, Agent\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from rich.pretty import pprint\n",
"import json\n",
"import uuid\n",

View file

@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@ -34,10 +34,8 @@
}
],
"source": [
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client import LlamaStackClient, Agent\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from rich.pretty import pprint\n",
"import json\n",
"import uuid\n",
@ -70,7 +68,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@ -1397,6 +1395,349 @@
"pprint(session_response.turns)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.1 Improved RAG with Long Context\n",
"\n",
"- Instead of performing reteival tool, we send documents as attachments to the agent and let it use the entire document context. \n",
"- Note how that the model is able to understand the entire context from documentation and answers the question with better factuality with improved retrieval. "
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> What precision formats does torchtune support?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m What precision formats does torchtune support?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> Torchtune supports two precision formats: `fp32` <span style=\"font-weight: bold\">(</span>full-precision<span style=\"font-weight: bold\">)</span> and `bfloat16` <span style=\"font-weight: bold\">(</span>half-precision<span style=\"font-weight: bold\">)</span>. \n",
"The `bfloat16` format uses <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> bytes per model parameter, which is half the memory of `fp32`, and also improves \n",
"training speed.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m Torchtune supports two precision formats: `fp32` \u001b[1m(\u001b[0mfull-precision\u001b[1m)\u001b[0m and `bfloat16` \u001b[1m(\u001b[0mhalf-precision\u001b[1m)\u001b[0m. \n",
"The `bfloat16` format uses \u001b[1;36m2\u001b[0m bytes per model parameter, which is half the memory of `fp32`, and also improves \n",
"training speed.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> What does DoRA stand for in torchtune?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m What does DoRA stand for in torchtune?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> DoRA stands for Weight-Decomposed Low-Rank Adaptation. It is a variant of LoRA <span style=\"font-weight: bold\">(</span>Low-Rank Adaptation<span style=\"font-weight: bold\">)</span> \n",
"that further decomposes the pre-trained weights into two components: magnitude and direction. The magnitude \n",
"component is a scalar vector that adjusts the scale, while the direction component corresponds to the original LoRA\n",
"decomposition and updates the orientation of weights. DoRA adds a small overhead to LoRA training due to the \n",
"addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low \n",
"ranks.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m DoRA stands for Weight-Decomposed Low-Rank Adaptation. It is a variant of LoRA \u001b[1m(\u001b[0mLow-Rank Adaptation\u001b[1m)\u001b[0m \n",
"that further decomposes the pre-trained weights into two components: magnitude and direction. The magnitude \n",
"component is a scalar vector that adjusts the scale, while the direction component corresponds to the original LoRA\n",
"decomposition and updates the orientation of weights. DoRA adds a small overhead to LoRA training due to the \n",
"addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low \n",
"ranks.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> How does the CPUOffloadOptimizer reduce GPU memory usage?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m How does the CPUOffloadOptimizer reduce GPU memory usage?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> The CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the \n",
"CPU, and performing optimizer steps on the CPU. This can significantly reduce GPU memory usage at the cost of CPU \n",
"RAM and training speed.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m The CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the \n",
"CPU, and performing optimizer steps on the CPU. This can significantly reduce GPU memory usage at the cost of CPU \n",
"RAM and training speed.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> How do I ensure only LoRA parameters are trainable when fine-tuning?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m How do I ensure only LoRA parameters are trainable when fine-tuning?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> To ensure only LoRA parameters are trainable when fine-tuning, you can use the `set_trainable_params`\n",
"function from `torchtune.modules.peft.peft_utils` to set the `requires_grad` attribute of the LoRA parameters to \n",
"`<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-style: italic\">True</span>` and the `requires_grad` attribute of the other parameters to `<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>`.\n",
"\n",
"Here is an example:\n",
"```python\n",
"from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n",
"\n",
"# Get the LoRA parameters\n",
"lora_params = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">get_adapter_params</span><span style=\"font-weight: bold\">(</span>model<span style=\"font-weight: bold\">)</span>\n",
"\n",
"# Set the LoRA parameters to trainable and the other parameters to non-trainable\n",
"<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">set_trainable_params</span><span style=\"font-weight: bold\">(</span>model, lora_params<span style=\"font-weight: bold\">)</span>\n",
"```\n",
"This will ensure that only the LoRA parameters are updated during fine-tuning, while the other parameters remain \n",
"frozen.\n",
"\n",
"Alternatively, you can also use the `lora_finetune` recipe in torchtune, which automatically sets the LoRA \n",
"parameters to trainable and the other parameters to non-trainable. You can run the recipe using the following \n",
"command:\n",
"```bash\n",
"tune run lora_finetune --config llama2/7B_lora\n",
"```\n",
"This will fine-tune the LoRA parameters of the Llama2 model using the default settings. You can modify the config \n",
"file to change the hyperparameters or the model architecture.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m To ensure only LoRA parameters are trainable when fine-tuning, you can use the `set_trainable_params`\n",
"function from `torchtune.modules.peft.peft_utils` to set the `requires_grad` attribute of the LoRA parameters to \n",
"`\u001b[3;92mTrue\u001b[0m` and the `requires_grad` attribute of the other parameters to `\u001b[3;91mFalse\u001b[0m`.\n",
"\n",
"Here is an example:\n",
"```python\n",
"from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n",
"\n",
"# Get the LoRA parameters\n",
"lora_params = \u001b[1;35mget_adapter_params\u001b[0m\u001b[1m(\u001b[0mmodel\u001b[1m)\u001b[0m\n",
"\n",
"# Set the LoRA parameters to trainable and the other parameters to non-trainable\n",
"\u001b[1;35mset_trainable_params\u001b[0m\u001b[1m(\u001b[0mmodel, lora_params\u001b[1m)\u001b[0m\n",
"```\n",
"This will ensure that only the LoRA parameters are updated during fine-tuning, while the other parameters remain \n",
"frozen.\n",
"\n",
"Alternatively, you can also use the `lora_finetune` recipe in torchtune, which automatically sets the LoRA \n",
"parameters to trainable and the other parameters to non-trainable. You can run the recipe using the following \n",
"command:\n",
"```bash\n",
"tune run lora_finetune --config llama2/7B_lora\n",
"```\n",
"This will fine-tune the LoRA parameters of the Llama2 model using the default settings. You can modify the config \n",
"file to change the hyperparameters or the model architecture.\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"urls = [\n",
" \"memory_optimizations.rst\",\n",
" \"chat.rst\",\n",
" \"llama3.rst\",\n",
" \"datasets.rst\",\n",
" \"qat_finetune.rst\",\n",
" \"lora_finetune.rst\",\n",
"]\n",
"\n",
"attachments = [\n",
" {\n",
" \"content\": {\n",
" \"uri\": f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
" },\n",
" \"mime_type\": \"text/plain\",\n",
" }\n",
"\n",
" for i, url in enumerate(urls)\n",
"]\n",
"\n",
"rag_attachment_agent = Agent(\n",
" client,\n",
" model=MODEL_ID,\n",
" instructions=\"You are a helpful assistant that can answer questions about the Torchtune project. Use context from attached documentation for Torchtune to answer questions.\",\n",
")\n",
"\n",
"for example in examples:\n",
" session_id = rag_attachment_agent.create_session(session_name=f\"rag_attachment_session_{uuid.uuid4()}\")\n",
" response = rag_attachment_agent.create_turn(\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": example[\"input_query\"]\n",
" }\n",
" ],\n",
" session_id=session_id,\n",
" documents=attachments,\n",
" stream=False\n",
" )\n",
" rich.print(f\"[bold cyan]Question:[/bold cyan] {example['input_query']}\")\n",
" rich.print(f\"[bold yellow]Agent Answer:[/bold yellow] {response.output_message.content}\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringScoreResponse</span><span style=\"font-weight: bold\">(</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">results</span>=<span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'braintrust::factuality'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringResult</span><span style=\"font-weight: bold\">(</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">aggregated_results</span>=<span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'average'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'average'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span><span style=\"font-weight: bold\">}}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">score_rows</span>=<span style=\"font-weight: bold\">[</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. Both the expert and the submitted answers mention that Torchtune supports two precision formats: `fp32` (full-precision) and `bfloat16` (half-precision).\\n2. The expert answer specifies that `fp32` uses 4 bytes per model and optimizer parameter, while `bfloat16` uses 2 bytes per model and optimizer parameter.\\n3. The submitted answer also mentions that `bfloat16` uses 2 bytes per model parameter, which is consistent with the expert answer.\\n4. The submitted answer adds that `bfloat16` improves training speed, which is additional information not present in the expert answer.\\n5. There is no conflict between the submitted answer and the expert answer; the submitted answer simply provides more information.\\n\\nBased on this analysis, the submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. The expert answer provides the definition of DoRA as \"Weight-Decomposed Low-Rank Adaptation.\"\\n2. The submitted answer also states that DoRA stands for \"Weight-Decomposed Low-Rank Adaptation,\" which matches the expert answer.\\n3. The submitted answer includes additional information about DoRA, explaining that it is a variant of LoRA and describing how it decomposes pre-trained weights into magnitude and direction components.\\n4. The submitted answer further explains the role of the magnitude component and the direction component, and mentions the performance improvement and overhead associated with DoRA.\\n5. The additional details in the submitted answer do not contradict the expert answer; instead, they expand upon it.\\n6. Therefore, the submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. The expert answer states that the CPUOffloadOptimizer reduces GPU memory usage by keeping optimizer states on CPU and performing optimizer steps on CPU. It also mentions the optional offloading of gradients to CPU with the parameter offload_gradients=True.\\n\\n2. The submitted answer states that the CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the CPU, and performing optimizer steps on the CPU. It adds that this can significantly reduce GPU memory usage at the cost of CPU RAM and training speed.\\n\\n3. Comparing both answers:\\n - Both answers agree on offloading optimizer states to the CPU and performing optimizer steps on the CPU.\\n - Both mention the offloading of gradients to the CPU, but the expert answer specifies it as optional with a parameter, while the submission does not specify this detail.\\n - The submission adds additional information about the trade-off involving CPU RAM and training speed, which is not mentioned in the expert answer.\\n\\n4. The submitted answer includes all the details from the expert answer and adds more information about the trade-offs, making it a superset of the expert answer.\\n\\nTherefore, the correct choice is (B) The submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"1. **Expert Answer Analysis**: The expert answer provides a method to ensure only LoRA parameters are trainable by using torchtune's utility functions. It mentions fetching LoRA parameters with `get_adapter_params(lora_model)` and setting them as trainable with `set_trainable_params(lora_model, lora_params)`. It also notes that the LoRA recipe handles this automatically.\\n\\n2. **Submitted Answer Analysis**: The submitted answer provides a similar method using `set_trainable_params` to set the `requires_grad` attribute of LoRA parameters to `True` and other parameters to `False`. It includes a code example demonstrating this process. Additionally, it mentions using the `lora_finetune` recipe in torchtune, which automatically sets the LoRA parameters to trainable.\\n\\n3. **Comparison**: The submitted answer includes all the details from the expert answer regarding the use of `get_adapter_params` and `set_trainable_params`. It also provides additional information about setting the `requires_grad` attribute and using the `lora_finetune` recipe, which is not mentioned in the expert answer.\\n\\n4. **Conclusion**: The submitted answer is a superset of the expert answer as it contains all the information from the expert answer and additional details. There is no conflict between the two answers, and the additional information in the submission is consistent with the expert's explanation.\\n\\nTherefore, the correct choice is (B) The submitted answer is a superset of the expert answer and is fully consistent with it.\"</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"font-weight: bold\">]</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">)</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mScoringScoreResponse\u001b[0m\u001b[1m(\u001b[0m\n",
"\u001b[2;32m│ \u001b[0m\u001b[33mresults\u001b[0m=\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[32m'braintrust::factuality'\u001b[0m: \u001b[1;35mScoringResult\u001b[0m\u001b[1m(\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33maggregated_results\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'average'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'average'\u001b[0m: \u001b[1;36m0.6\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mscore_rows\u001b[0m=\u001b[1m[\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. Both the expert and the submitted answers mention that Torchtune supports two precision formats: `fp32` \u001b[0m\u001b[32m(\u001b[0m\u001b[32mfull-precision\u001b[0m\u001b[32m)\u001b[0m\u001b[32m and `bfloat16` \u001b[0m\u001b[32m(\u001b[0m\u001b[32mhalf-precision\u001b[0m\u001b[32m)\u001b[0m\u001b[32m.\\n2. The expert answer specifies that `fp32` uses 4 bytes per model and optimizer parameter, while `bfloat16` uses 2 bytes per model and optimizer parameter.\\n3. The submitted answer also mentions that `bfloat16` uses 2 bytes per model parameter, which is consistent with the expert answer.\\n4. The submitted answer adds that `bfloat16` improves training speed, which is additional information not present in the expert answer.\\n5. There is no conflict between the submitted answer and the expert answer; the submitted answer simply provides more information.\\n\\nBased on this analysis, the submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. The expert answer provides the definition of DoRA as \"Weight-Decomposed Low-Rank Adaptation.\"\\n2. The submitted answer also states that DoRA stands for \"Weight-Decomposed Low-Rank Adaptation,\" which matches the expert answer.\\n3. The submitted answer includes additional information about DoRA, explaining that it is a variant of LoRA and describing how it decomposes pre-trained weights into magnitude and direction components.\\n4. The submitted answer further explains the role of the magnitude component and the direction component, and mentions the performance improvement and overhead associated with DoRA.\\n5. The additional details in the submitted answer do not contradict the expert answer; instead, they expand upon it.\\n6. Therefore, the submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. The expert answer states that the CPUOffloadOptimizer reduces GPU memory usage by keeping optimizer states on CPU and performing optimizer steps on CPU. It also mentions the optional offloading of gradients to CPU with the parameter \u001b[0m\u001b[32moffload_gradients\u001b[0m\u001b[32m=\u001b[0m\u001b[32mTrue\u001b[0m\u001b[32m.\\n\\n2. The submitted answer states that the CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the CPU, and performing optimizer steps on the CPU. It adds that this can significantly reduce GPU memory usage at the cost of CPU RAM and training speed.\\n\\n3. Comparing both answers:\\n - Both answers agree on offloading optimizer states to the CPU and performing optimizer steps on the CPU.\\n - Both mention the offloading of gradients to the CPU, but the expert answer specifies it as optional with a parameter, while the submission does not specify this detail.\\n - The submission adds additional information about the trade-off involving CPU RAM and training speed, which is not mentioned in the expert answer.\\n\\n4. The submitted answer includes all the details from the expert answer and adds more information about the trade-offs, making it a superset of the expert answer.\\n\\nTherefore, the correct choice is \u001b[0m\u001b[32m(\u001b[0m\u001b[32mB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m The submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m\"1. **Expert Answer Analysis**: The expert answer provides a method to ensure only LoRA parameters are trainable by using torchtune's utility functions. It mentions fetching LoRA parameters with `get_adapter_params\u001b[0m\u001b[32m(\u001b[0m\u001b[32mlora_model\u001b[0m\u001b[32m)\u001b[0m\u001b[32m` and setting them as trainable with `set_trainable_params\u001b[0m\u001b[32m(\u001b[0m\u001b[32mlora_model, lora_params\u001b[0m\u001b[32m)\u001b[0m\u001b[32m`. It also notes that the LoRA recipe handles this automatically.\\n\\n2. **Submitted Answer Analysis**: The submitted answer provides a similar method using `set_trainable_params` to set the `requires_grad` attribute of LoRA parameters to `True` and other parameters to `False`. It includes a code example demonstrating this process. Additionally, it mentions using the `lora_finetune` recipe in torchtune, which automatically sets the LoRA parameters to trainable.\\n\\n3. **Comparison**: The submitted answer includes all the details from the expert answer regarding the use of `get_adapter_params` and `set_trainable_params`. It also provides additional information about setting the `requires_grad` attribute and using the `lora_finetune` recipe, which is not mentioned in the expert answer.\\n\\n4. **Conclusion**: The submitted answer is a superset of the expert answer as it contains all the information from the expert answer and additional details. There is no conflict between the two answers, and the additional information in the submission is consistent with the expert's explanation.\\n\\nTherefore, the correct choice is \u001b[0m\u001b[32m(\u001b[0m\u001b[32mB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m The submitted answer is a superset of the expert answer and is fully consistent with it.\"\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[1m]\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n",
"\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"eval_rows = []\n",
"for i, session_id in enumerate(rag_attachment_agent.sessions):\n",
" session_response = client.agents.session.retrieve(agent_id=rag_attachment_agent.agent_id, session_id=session_id)\n",
" for turn in session_response.turns:\n",
" eval_rows.append({\n",
" \"input_query\": examples[i][\"input_query\"],\n",
" \"expected_answer\": examples[i][\"expected_answer\"],\n",
" \"generated_answer\": turn.output_message.content,\n",
" })\n",
"\n",
"scoring_params = {\n",
" \"braintrust::factuality\": None,\n",
"}\n",
"scoring_response = client.scoring.score(\n",
" input_rows=eval_rows,\n",
" scoring_functions=scoring_params,\n",
")\n",
"pprint(scoring_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},

View file

@ -14,7 +14,7 @@ Agents are configured using the `AgentConfig` class, which includes:
- **Safety Shields**: Guardrails to ensure responsible AI behavior
```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client import Agent
# Create the agent
@ -44,14 +44,14 @@ Each interaction with an agent is called a "turn" and consists of:
- **Output Message**: The agent's response
```python
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client import AgentEventLogger
# Create a turn with streaming response
turn_response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": "Tell me about Llama models"}],
)
for log in EventLogger().log(turn_response):
for log in AgentEventLogger().log(turn_response):
log.print()
```
### Non-Streaming

View file

@ -67,9 +67,7 @@ sequenceDiagram
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
from rich.pretty import pprint
# Replace host and port
@ -113,7 +111,7 @@ response = agent.create_turn(
)
# Monitor each step of execution
for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()
# Using non-streaming API, the response contains input, steps, and output.

View file

@ -23,9 +23,7 @@ In this example, we will show you how to:
##### Building a Search Agent
```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
@ -54,7 +52,7 @@ for prompt in user_prompts:
session_id=session_id,
)
for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()
```

View file

@ -55,11 +55,11 @@ chunks_response = client.vector_io.query(
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces.
```python
from llama_stack_client.types import Document
from llama_stack_client import RAGDocument
urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"]
documents = [
Document(
RAGDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
@ -86,7 +86,7 @@ results = client.tool_runtime.rag_tool.query(
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client import Agent
# Create agent with memory
agent = Agent(
@ -140,9 +140,9 @@ response = agent.create_turn(
You can print the response with below.
```python
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client import AgentEventLogger
for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()
```

View file

@ -57,7 +57,7 @@ The `otel` sink works with any service compatible with the OpenTelemetry collect
Start a Jaeger instance with the OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686 using the following command:
```bash
$ docker run --rm --name jaeger \
$ docker run --pull always --rm --name jaeger \
-p 16686:16686 -p 4318:4318 \
jaegertracing/jaeger:2.1.0
```

View file

@ -110,10 +110,18 @@ MCP tools are special tools that can interact with llama stack over model contex
Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers.
```shell
# start your MCP server
mkdir /tmp/content
touch /tmp/content/foo
touch /tmp/content/bar
npx -y supergateway --port 8000 --stdio 'npx -y @modelcontextprotocol/server-filesystem /tmp/content'
```
Then register the MCP server as a tool group,
```python
# Register MCP tools
client.toolgroups.register(
toolgroup_id="builtin::filesystem",
toolgroup_id="mcp::filesystem",
provider_id="model-context-protocol",
mcp_endpoint=URL(uri="http://localhost:8000/sse"),
)
@ -181,7 +189,7 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
## Simple Example: Using an Agent with the Code-Interpreter Tool
```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client import Agent
# Instantiate the AI agent with the given configuration
agent = Agent(

View file

@ -55,7 +55,7 @@ llama stack run llama_stack/templates/open-benchmark/run.yaml
There are 3 necessary inputs to run a benchmark eval
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
- `model-id`: The model id to evaluate on
- `utput_dir`: Path to store the evaluate results
- `output_dir`: Path to store the evaluate results
```
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
--model_id <model id to evaluate on> \
@ -69,7 +69,7 @@ llama-stack-client eval run-benchmark help
to see the description of all the flags that eval run-benchmark has
In the output log, you can find the file path that has your evaluation results. Open that file and you can see you aggrgate
In the output log, you can find the file path that has your evaluation results. Open that file and you can see you aggregate
evaluation results over there.

View file

@ -58,9 +58,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-nvidia \
@ -74,7 +75,7 @@ docker run \
```bash
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
--env INFERENCE_MODEL=$INFERENCE_MODEL
```

View file

@ -28,7 +28,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
### Models
@ -53,9 +53,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-bedrock \
--port $LLAMA_STACK_PORT \

View file

@ -20,7 +20,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
### Models
@ -45,9 +45,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-cerebras \
@ -61,6 +62,6 @@ docker run \
```bash
llama stack build --template cerebras --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--port 8321 \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```

View file

@ -53,7 +53,7 @@ docker compose down
#### Start Dell-TGI server locally
```
docker run -it --shm-size 1g -p 80:80 --gpus 4 \
docker run -it --pull always --shm-size 1g -p 80:80 --gpus 4 \
-e NUM_SHARD=4
-e MAX_BATCH_PREFILL_TOKENS=32768 \
-e MAX_INPUT_TOKENS=8000 \
@ -65,7 +65,7 @@ registry.dell.huggingface.co/enterprise-dell-inference-meta-llama-meta-llama-3.1
#### Start Llama Stack server pointing to TGI server
```
docker run --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml
docker run --pull always --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml
```
Make sure in you `run.yaml` file, you inference provider is pointing to the correct TGI server endpoint. E.g.

View file

@ -55,6 +55,7 @@ export CUDA_VISIBLE_DEVICES=0
export LLAMA_STACK_PORT=8321
docker run --rm -it \
--pull always \
--network host \
-v $HOME/.cache/huggingface:/data \
-e HF_TOKEN=$HF_TOKEN \
@ -78,6 +79,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1
docker run --rm -it \
--pull always \
--network host \
-v $HOME/.cache/huggingface:/data \
-e HF_TOKEN=$HF_TOKEN \
@ -120,6 +122,7 @@ This method allows you to get started quickly without having to build the distri
```bash
docker run -it \
--pull always \
--network host \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \
@ -147,6 +150,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \

View file

@ -30,7 +30,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``)
### Models
@ -63,9 +63,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-fireworks \
--port $LLAMA_STACK_PORT \

View file

@ -30,7 +30,7 @@ The `llamastack/distribution-groq` distribution consists of the following provid
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `GROQ_API_KEY`: Groq API Key (default: ``)
### Models
@ -58,9 +58,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-groq \
--port $LLAMA_STACK_PORT \

View file

@ -32,7 +32,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
@ -77,9 +77,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-gpu \
@ -92,6 +93,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-gpu \
@ -107,7 +109,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash
llama stack build --template meta-reference-gpu --image-type conda
llama stack run distributions/meta-reference-gpu/run.yaml \
--port 5001 \
--port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```
@ -115,7 +117,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash
llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \
--port 5001 \
--port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```

View file

@ -34,7 +34,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
@ -77,9 +77,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-quantized-gpu \
@ -92,6 +93,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-quantized-gpu \

View file

@ -15,7 +15,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
### Models
@ -39,9 +39,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-nvidia \
@ -55,6 +56,6 @@ docker run \
```bash
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
```

View file

@ -32,7 +32,7 @@ You should use this distribution if you have a regular desktop machine without v
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
- `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
@ -71,9 +71,10 @@ Now you are ready to run Llama Stack with Ollama as the inference provider. You
This method allows you to get started quickly without having to build the distribution code.
```bash
export LLAMA_STACK_PORT=5001
export LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-ollama \
@ -91,6 +92,7 @@ cd /path/to/llama-stack
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
-v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
@ -107,7 +109,7 @@ docker run \
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
```bash
export LLAMA_STACK_PORT=5001
export LLAMA_STACK_PORT=8321
llama stack build --template ollama --image-type conda
llama stack run ./run.yaml \

View file

@ -30,7 +30,7 @@ The `llamastack/distribution-passthrough` distribution consists of the following
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `PASSTHROUGH_API_KEY`: Passthrough API Key (default: ``)
- `PASSTHROUGH_URL`: Passthrough URL (default: ``)

View file

@ -31,7 +31,7 @@ You can use this distribution if you have GPUs and want to run an independent vL
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`)
- `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`)
@ -49,6 +49,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0
docker run \
--pull always \
--runtime nvidia \
--gpus $CUDA_VISIBLE_DEVICES \
-v ~/.cache/huggingface:/root/.cache/huggingface \
@ -61,6 +62,8 @@ docker run \
--port $INFERENCE_PORT
```
Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
```bash
@ -69,6 +72,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1
docker run \
--pull always \
--runtime nvidia \
--gpus $CUDA_VISIBLE_DEVICES \
-v ~/.cache/huggingface:/root/.cache/huggingface \
@ -92,10 +96,11 @@ This method allows you to get started quickly without having to build the distri
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export LLAMA_STACK_PORT=5001
export LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-remote-vllm \
@ -117,6 +122,7 @@ cd /path/to/llama-stack
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
-v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
@ -137,7 +143,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export LLAMA_STACK_PORT=5001
export LLAMA_STACK_PORT=8321
cd distributions/remote-vllm
llama stack build --template remote-vllm --image-type conda

View file

@ -27,7 +27,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``)
### Models
@ -59,9 +59,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-sambanova \
--port $LLAMA_STACK_PORT \

View file

@ -33,7 +33,7 @@ You can use this distribution if you have GPUs and want to run an independent TG
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`)
- `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
@ -50,6 +50,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0
docker run --rm -it \
--pull always \
-v $HOME/.cache/huggingface:/data \
-p $INFERENCE_PORT:$INFERENCE_PORT \
--gpus $CUDA_VISIBLE_DEVICES \
@ -70,6 +71,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1
docker run --rm -it \
--pull always \
-v $HOME/.cache/huggingface:/data \
-p $SAFETY_PORT:$SAFETY_PORT \
--gpus $CUDA_VISIBLE_DEVICES \
@ -90,9 +92,10 @@ Now you are ready to run Llama Stack with TGI as the inference provider. You can
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-tgi \
--port $LLAMA_STACK_PORT \
@ -109,6 +112,7 @@ cd /path/to/llama-stack
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \

View file

@ -30,7 +30,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `TOGETHER_API_KEY`: Together.AI API Key (default: ``)
### Models
@ -64,9 +64,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-together \
--port $LLAMA_STACK_PORT \

View file

@ -54,6 +54,7 @@ mkdir -p ~/.llama
Then you can start the server using the container tool of your choice. For example, if you are running Docker you can use the following command:
```bash
docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-ollama \
@ -74,6 +75,7 @@ Docker containers run in their own isolated network namespaces on Linux. To allo
Linux users having issues running the above command should instead try the following:
```bash
docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
--network=host \
@ -197,9 +199,7 @@ import os
import uuid
from termcolor import cprint
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import Document
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
def create_http_client():
@ -225,7 +225,7 @@ client = (
# Documents to be used for RAG
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
RAGDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
@ -284,7 +284,7 @@ for prompt in user_prompts:
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
)
for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()
```

View file

@ -118,6 +118,7 @@ Playground can also be started in a docker image:
export LLAMA_STACK_URL=http://localhost:11434
docker run \
--pull always \
-p 8501:8501 \
-e LLAMA_STACK_ENDPOINT=$LLAMA_STACK_URL \
quay.io/jland/llama-stack-playground

View file

@ -48,7 +48,7 @@
"outputs": [],
"source": [
"HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n",
"PORT = 8321 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
]
},
@ -369,6 +369,9 @@
}
],
"metadata": {
"fileHeader": "",
"fileUid": "7da25939-a2a3-463c-958e-9cdfd710d158",
"isAdHoc": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
@ -386,7 +389,5 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View file

@ -43,7 +43,7 @@
"source": [
"#### 2. Set Up Local and Cloud Clients\n",
"\n",
"Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:5001`.\n"
"Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:8322`.\n"
]
},
{
@ -236,6 +236,9 @@
}
],
"metadata": {
"fileHeader": "",
"fileUid": "e11939ac-dfbc-4a1c-83be-e494c7f803b8",
"isAdHoc": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
@ -253,7 +256,5 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View file

@ -47,7 +47,7 @@
"outputs": [],
"source": [
"HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n",
"PORT = 8321 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
]
},
@ -281,6 +281,9 @@
}
],
"metadata": {
"fileHeader": "",
"fileUid": "b1b93b6e-22a2-4c24-8cb0-161fdafff29a",
"isAdHoc": false,
"kernelspec": {
"display_name": "base",
"language": "python",
@ -298,7 +301,5 @@
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View file

@ -45,7 +45,7 @@
"outputs": [],
"source": [
"HOST = \"localhost\" # Replace with your host\n",
"CLOUD_PORT = 5001 # Replace with your cloud distro port\n",
"CLOUD_PORT = 8321 # Replace with your cloud distro port\n",
"MODEL_NAME='Llama3.2-11B-Vision-Instruct'"
]
},
@ -180,6 +180,9 @@
}
],
"metadata": {
"fileHeader": "",
"fileUid": "37bbbfda-8e42-446c-89c7-59dd49e2d339",
"isAdHoc": false,
"kernelspec": {
"display_name": "base",
"language": "python",
@ -197,7 +200,5 @@
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View file

@ -46,7 +46,7 @@
"nest_asyncio.apply()\n",
"\n",
"HOST = \"localhost\"\n",
"PORT = 5001\n",
"PORT = 8321\n",
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
]
},
@ -296,7 +296,7 @@
"\n",
" # Create an agent instance with the client and configuration\n",
" agent = Agent(\n",
" client, \n",
" client,\n",
" model=MODEL_NAME,\n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n",
@ -335,6 +335,9 @@
}
],
"metadata": {
"fileHeader": "",
"fileUid": "f0abbf6d-ed52-40ad-afb4-f5ec99130249",
"isAdHoc": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
@ -352,7 +355,5 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View file

@ -45,7 +45,7 @@
"outputs": [],
"source": [
"HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n",
"PORT = 8321 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'\n",
"MEMORY_BANK_ID=\"tutorial_bank\""
]
@ -378,6 +378,9 @@
}
],
"metadata": {
"fileHeader": "",
"fileUid": "73bc3357-0e5e-42ff-95b1-40b916d24c4f",
"isAdHoc": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
@ -395,7 +398,5 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View file

@ -49,7 +49,7 @@
"outputs": [],
"source": [
"HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n",
"PORT = 8321 # Replace with your port\n",
"SHEILD_NAME=\"meta-llama/Llama-Guard-3-1B\""
]
},
@ -112,6 +112,9 @@
}
],
"metadata": {
"fileHeader": "",
"fileUid": "9afaddb7-c2fb-4309-8fa0-761697de53f0",
"isAdHoc": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
@ -129,7 +132,5 @@
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View file

@ -50,7 +50,7 @@
"outputs": [],
"source": [
"HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n",
"PORT = 8321 # Replace with your port\n",
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
]
},
@ -115,7 +115,7 @@
"async def agent_example():\n",
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
" agent = Agent(\n",
" client, \n",
" client,\n",
" model=MODEL_NAME,\n",
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
" sampling_params={\n",
@ -168,6 +168,9 @@
}
],
"metadata": {
"fileHeader": "",
"fileUid": "8de24775-c4a0-49c7-904e-608264f69292",
"isAdHoc": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
@ -185,7 +188,5 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View file

@ -96,7 +96,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
3. **Set the ENV variables by exporting them to the terminal**:
```bash
export OLLAMA_URL="http://localhost:11434"
export LLAMA_STACK_PORT=5001
export LLAMA_STACK_PORT=8321
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
```
@ -112,7 +112,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
```
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
The server will start and listen on `http://localhost:5001`.
The server will start and listen on `http://localhost:8321`.
---
## Test with `llama-stack-client` CLI
@ -120,11 +120,11 @@ After setting up the server, open a new terminal window and configure the llama-
1. Configure the CLI to point to the llama-stack server.
```bash
llama-stack-client configure --endpoint http://localhost:5001
llama-stack-client configure --endpoint http://localhost:8321
```
**Expected Output:**
```bash
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5001
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
```
2. Test the CLI by running inference:
```bash
@ -218,7 +218,7 @@ if INFERENCE_MODEL is None:
raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.")
# Initialize the clien
client = LlamaStackClient(base_url="http://localhost:5001")
client = LlamaStackClient(base_url="http://localhost:8321")
# Create a chat completion reques
response = client.inference.chat_completion(

View file

@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -189,13 +188,11 @@ class AgentToolGroupWithArgs(BaseModel):
args: Dict[str, Any]
AgentToolGroup = register_schema(
Union[
AgentToolGroup = Union[
str,
AgentToolGroupWithArgs,
],
name="AgentTool",
)
]
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
@ -312,8 +309,7 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
turn: Turn
AgentTurnResponseEventPayload = register_schema(
Annotated[
AgentTurnResponseEventPayload = Annotated[
Union[
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
@ -323,9 +319,8 @@ AgentTurnResponseEventPayload = register_schema(
AgentTurnResponseTurnAwaitingInputPayload,
],
Field(discriminator="event_type"),
],
name="AgentTurnResponseEventPayload",
)
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@json_schema_type
@ -387,7 +382,6 @@ class AgentStepResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Agents(Protocol):
"""Agents API for creating and interacting with agentic systems.
@ -399,7 +393,7 @@ class Agents(Protocol):
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
"""
@webmethod(route="/agents", method="POST")
@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
async def create_agent(
self,
agent_config: AgentConfig,
@ -411,7 +405,9 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
)
async def create_agent_turn(
self,
agent_id: str,
@ -443,6 +439,7 @@ class Agents(Protocol):
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
)
async def resume_agent_turn(
self,
@ -505,7 +502,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session", method="POST")
@webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
async def create_agent_session(
self,
agent_id: str,

View file

@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
# other modalities can be added here
InterleavedContentItem = register_schema(
Annotated[
InterleavedContentItem = Annotated[
Union[ImageContentItem, TextContentItem],
Field(discriminator="type"),
],
name="InterleavedContentItem",
)
]
register_schema(InterleavedContentItem, name="InterleavedContentItem")
# accept a single "str" as a special case since it is common
InterleavedContent = register_schema(
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
name="InterleavedContent",
)
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
register_schema(InterleavedContent, name="InterleavedContent")
@json_schema_type
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
# streaming completions send a stream of ContentDeltas
ContentDelta = register_schema(
Annotated[
ContentDelta = Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"),
],
name="ContentDelta",
)
]
register_schema(ContentDelta, name="ContentDelta")

View file

@ -10,14 +10,14 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Job(BaseModel):
job_id: str
@json_schema_type
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class Job(BaseModel):
job_id: str
status: JobStatus

View file

@ -72,8 +72,7 @@ class DialogType(BaseModel):
type: Literal["dialog"] = "dialog"
ParamType = register_schema(
Annotated[
ParamType = Annotated[
Union[
StringType,
NumberType,
@ -87,9 +86,8 @@ ParamType = register_schema(
AgentTurnInputType,
],
Field(discriminator="type"),
],
name="ParamType",
)
]
register_schema(ParamType, name="ParamType")
"""
# TODO: recursive definition of ParamType in these containers

View file

@ -84,13 +84,11 @@ class RowsDataSource(BaseModel):
rows: List[Dict[str, Any]]
DataSource = register_schema(
Annotated[
DataSource = Annotated[
Union[URIDataSource, RowsDataSource],
Field(discriminator="type"),
],
name="DataSource",
)
]
register_schema(DataSource, name="DataSource")
class CommonDatasetFields(BaseModel):

View file

@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
@ -43,10 +43,8 @@ class AgentCandidate(BaseModel):
config: AgentConfig
EvalCandidate = register_schema(
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
name="EvalCandidate",
)
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@json_schema_type
@ -117,7 +115,7 @@ class Eval(Protocol):
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.

View file

@ -144,8 +144,7 @@ class CompletionMessage(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
Message = register_schema(
Annotated[
Message = Annotated[
Union[
UserMessage,
SystemMessage,
@ -153,9 +152,8 @@ Message = register_schema(
CompletionMessage,
],
Field(discriminator="role"),
],
name="Message",
)
]
register_schema(Message, name="Message")
@json_schema_type
@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
bnf: Dict[str, Any]
ResponseFormat = register_schema(
Annotated[
ResponseFormat = Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
],
name="ResponseFormat",
)
]
register_schema(ResponseFormat, name="ResponseFormat")
# This is an internally used class

View file

@ -24,17 +24,6 @@ class HealthInfo(BaseModel):
# TODO: add a provider level status
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@json_schema_type
class VersionInfo(BaseModel):
version: str
@ -46,9 +35,6 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/inspect/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ...

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
group_size: int
AlgorithmConfig = register_schema(
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
name="AlgorithmConfig",
)
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
@ -184,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")

View file

@ -36,6 +36,7 @@ class ScoringFnParamsType(Enum):
@json_schema_type
class AggregationFunctionType(Enum):
average = "average"
weighted_average = "weighted_average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@ -78,17 +79,15 @@ class BasicScoringFnParams(BaseModel):
)
ScoringFnParams = register_schema(
Annotated[
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"),
],
name="ScoringFnParams",
)
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):

View file

@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel):
status: SpanStatus
StructuredLogPayload = register_schema(
Annotated[
StructuredLogPayload = Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
Field(discriminator="type"),
],
name="StructuredLogPayload",
)
]
register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type
@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon):
payload: StructuredLogPayload
Event = register_schema(
Annotated[
Event = Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
Field(discriminator="type"),
],
name="Event",
)
]
register_schema(Event, name="Event")
@json_schema_type

View file

@ -58,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
template: str
RAGQueryGeneratorConfig = register_schema(
Annotated[
RAGQueryGeneratorConfig = Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
Field(discriminator="type"),
],
name="RAGQueryGeneratorConfig",
)
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type

View file

@ -69,7 +69,7 @@ class ToolGroup(Resource):
@json_schema_type
class ToolInvocationResult(BaseModel):
content: InterleavedContent
content: Optional[InterleavedContent] = None
error_message: Optional[str] = None
error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore
tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")

View file

@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
@runtime_checkable
@trace_protocol
class VectorIO(Protocol):
vector_db_store: VectorDBStore
vector_db_store: VectorDBStore | None = None
# this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion

View file

@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
d = json.load(f)
manifest = Manifest(**d)
if datetime.now(timezone.utc) > manifest.expires_on:
if datetime.now(timezone.utc) > manifest.expires_on.astimezone(timezone.utc):
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(
obj_identifier: str,
obj_attributes: Optional[AccessAttributes],
user_attributes: Optional[Dict[str, Any]] = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
3. For each attribute category in the resource's access_attributes:
a. If the user lacks that category, access is DENIED
b. If the user has the category but none of the required values, access is DENIED
c. If the user has at least one matching value in each required category, access is GRANTED
Example:
# Resource requires:
access_attributes = AccessAttributes(
roles=["admin", "data-scientist"],
teams=["ml-team"]
)
# User has:
user_attributes = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "infra-team"],
"projects": ["llama-3"]
}
# Result: Access GRANTED
# - User has the "data-scientist" role (matches one of the required roles)
# - AND user is part of the "ml-team" (matches the required team)
# - The extra "projects" attribute is ignored
Args:
obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not obj_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not dict_attribs:
return True
# Check each attribute category (requires ALL categories to match)
# TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj_identifier}")
return True

View file

@ -90,6 +90,7 @@ RUN apt-get update && apt-get install -y \
procps psmisc lsof \
traceroute \
bubblewrap \
gcc \
&& rm -rf /var/lib/apt/lists/*
ENV UV_SYSTEM_PYTHON=1
@ -235,7 +236,7 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture
ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then
CLI_ARGS+=("--platform $BUILD_PLATFORM")
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then

View file

@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
@ -31,6 +32,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]]
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: Optional[List[str]] = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: Optional[List[str]] = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: Optional[List[str]] = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class ResourceWithACL(Resource):
"""Extension of Resource that adds attribute-based access control capabilities.
This class adds an optional access_attributes field that allows fine-grained control
over which users can access each resource. When attributes are defined, a user must have
matching attributes to access the resource.
Attribute Matching Algorithm:
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
4. Within each category, ANY value match is sufficient (OR relationship within a category)
Examples:
# Resource visible to everyone (no access control)
model = Model(identifier="llama-2", ...)
# Resource visible only to admins
model = Model(
identifier="gpt-4",
access_attributes=AccessAttributes(roles=["admin"])
)
# Resource visible to data scientists on the ML team
model = Model(
identifier="private-model",
access_attributes=AccessAttributes(
roles=["data-scientist", "researcher"],
teams=["ml-team"]
)
)
# ^ User must have at least one of the roles AND be on the ml-team
# Resource visible to users with specific project access
vector_db = VectorDB(
identifier="customer-embeddings",
access_attributes=AccessAttributes(
projects=["customer-insights"],
namespaces=["confidential"]
)
)
# ^ User must have access to the customer-insights project AND have confidential namespace
"""
access_attributes: Optional[AccessAttributes] = None
# Use the extended Resource for all routable objects
class ModelWithACL(Model, ResourceWithACL):
pass
class ShieldWithACL(Shield, ResourceWithACL):
pass
class VectorDBWithACL(VectorDB, ResourceWithACL):
pass
class DatasetWithACL(Dataset, ResourceWithACL):
pass
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
pass
class BenchmarkWithACL(Benchmark, ResourceWithACL):
pass
class ToolWithACL(Tool, ResourceWithACL):
pass
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
pass
RoutableObject = Union[
Model,
Shield,
@ -45,14 +155,14 @@ RoutableObject = Union[
RoutableObjectWithProvider = Annotated[
Union[
Model,
Shield,
VectorDB,
Dataset,
ScoringFn,
Benchmark,
Tool,
ToolGroup,
ModelWithACL,
ShieldWithACL,
VectorDBWithACL,
DatasetWithACL,
ScoringFnWithACL,
BenchmarkWithACL,
ToolWithACL,
ToolGroupWithACL,
],
Field(discriminator="type"),
]

View file

@ -11,9 +11,7 @@ from pydantic import BaseModel
from llama_stack.apis.inspect import (
HealthInfo,
Inspect,
ListProvidersResponse,
ListRoutesResponse,
ProviderInfo,
RouteInfo,
VersionInfo,
)
@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
async def initialize(self) -> None:
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
ret = []
for api, providers in run_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for p in providers
]
)
return ListProvidersResponse(data=ret)
async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config

View file

@ -9,7 +9,6 @@ import inspect
import json
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in self.impls:
continue
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
self.endpoint_impls = endpoint_impls
self.endpoint_impls = initialize_endpoint_impls(self.impls)
return True
async def request(
@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
Returns:
A tuple of (endpoint_function, path_params)
Raises:
ValueError: If no matching endpoint is found
"""
impls = self.endpoint_impls.get(method)
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, func in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params
raise ValueError(f"No endpoint found for {path}")
async def _call_non_streaming(
self,
*,
@ -326,10 +278,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {}
body |= options.json_data or {}
matched_func, path_params = self._find_matching_endpoint(options.method, path)
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
await start_trace(options.url, {"__location__": "library_client"})
await start_trace(route, {"__location__": "library_client"})
try:
result = await matched_func(**body)
finally:
@ -371,13 +323,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params = self._find_matching_endpoint(options.method, path)
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
async def gen():
await start_trace(options.url, {"__location__": "library_client"})
await start_trace(route, {"__location__": "library_client"})
try:
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body:
return {}
func, _ = self._find_matching_endpoint(method, path)
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature

View file

@ -7,21 +7,26 @@
import contextvars
import json
import logging
from typing import Any, ContextManager, Dict, Optional
from typing import Any, ContextManager, Dict, List, Optional
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
# Context variable for request provider data
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager):
"""Context manager for request provider data"""
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
self.provider_data = provider_data
def __init__(
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
):
self.provider_data = provider_data or {}
if auth_attributes:
self.provider_data["__auth_attributes"] = auth_attributes
self.token = None
def __enter__(self):
@ -80,7 +85,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
return None
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
"""Context manager that sets request provider data from headers for the duration of the context"""
def request_provider_data_context(
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
) -> ContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data)
return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__auth_attributes")

View file

@ -14,13 +14,7 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import (
BenchmarkConfig,
Eval,
EvaluateResponse,
Job,
JobStatus,
)
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@ -623,7 +617,7 @@ class EvalRouter(Eval):
self,
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

View file

@ -41,11 +41,22 @@ from llama_stack.apis.tools import (
ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
AccessAttributes,
BenchmarkWithACL,
DatasetWithACL,
ModelWithACL,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
ScoringFnWithACL,
ShieldWithACL,
ToolGroupWithACL,
ToolWithACL,
VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -186,6 +197,11 @@ class CommonRoutingTableImpl(RoutingTable):
if not obj:
return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
@ -202,6 +218,13 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
@ -214,7 +237,17 @@ class CommonRoutingTableImpl(RoutingTable):
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type]
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@ -251,7 +284,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = Model(
model = ModelWithACL(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
@ -297,7 +330,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
)
if params is None:
params = {}
shield = Shield(
shield = ShieldWithACL(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
@ -351,7 +384,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
@ -405,7 +438,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
if metadata is None:
metadata = {}
dataset = Dataset(
dataset = DatasetWithACL(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
@ -452,7 +485,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFn(
scoring_fn = ScoringFnWithACL(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
@ -494,7 +527,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = Benchmark(
benchmark = BenchmarkWithACL(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
@ -537,7 +570,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for tool_def in tool_defs:
tools.append(
Tool(
ToolWithACL(
identifier=tool_def.name,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
@ -562,7 +595,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self.register_object(tool)
await self.dist_registry.register(
ToolGroup(
ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
@ -575,7 +608,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id).data
tools = (await self.list_tools(toolgroup_id)).data
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)

View file

@ -5,16 +5,118 @@
# the root directory of this source tree.
import json
from typing import Dict, List, Optional
from urllib.parse import parse_qs
import httpx
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
message: Optional[str] = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthenticationMiddleware:
"""Middleware that authenticates requests using an external auth endpoint.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Sends it to the configured auth endpoint along with request details
3. Validates the response and extracts user attributes
4. Makes these attributes available to the route handlers for access control
Authentication Request Format:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
"request": {
"path": "/models/list",
"headers": {
"content-type": "application/json",
"user-agent": "..."
// All headers except Authorization
},
"params": {
"limit": ["100"],
"offset": ["0"]
// Query parameters as key -> list of values
}
}
}
```
Expected Auth Endpoint Response Format:
```json
{
"access_attributes": { // Structured attribute format
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
"namespaces": ["research"]
},
"message": "Optional message about auth result"
}
```
Attribute-Based Access Control:
The attributes returned by the auth endpoint are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth endpoint doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_endpoint):
self.app = app
self.auth_endpoint = auth_endpoint
@ -32,25 +134,57 @@ class AuthenticationMiddleware:
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
auth_data = {
"api_key": api_key,
"request": {
"path": path,
"headers": request_headers,
"params": params,
},
}
# Build the auth request model
auth_request = AuthRequest(
api_key=api_key,
request=AuthRequestContext(
path=path,
headers=request_headers,
params=params,
),
)
# Validate with authentication endpoint
try:
async with httpx.AsyncClient() as client:
response = await client.post(self.auth_endpoint, json=auth_data)
response = await client.post(
self.auth_endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed: {response.status_code}")
return await self._send_auth_error(send, "Authentication failed")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [api_key],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
except Exception:
logger.exception("Error parsing authentication response")
return await self._send_auth_error(send, "Invalid authentication response format")
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except Exception:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import inspect
import re
from typing import Dict, List
from pydantic import BaseModel
@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
route: str
method: str
name: str
descriptive_name: str | None = None
def toolgroup_protocol_map():
@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
endpoints.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
)
apis[api] = endpoints
return apis
def initialize_endpoint_impls(impls):
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in impls:
continue
for endpoint in api_endpoints:
impl = impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
func,
endpoint.descriptive_name or endpoint.route,
)
return endpoint_impls
def find_matching_endpoint(method, path, endpoint_impls):
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
Raises:
ValueError: If no matching endpoint is found
"""
impls = endpoint_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, descriptive_name
raise ValueError(f"No endpoint found for {path}")

View file

@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import (
construct_stack,
redact_sensitive_fields,
@ -179,8 +183,11 @@ async def sse_generator(event_gen):
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
# Use context manager for request provider data
with request_provider_data_context(request.headers):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
@ -219,14 +226,30 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
class TracingMiddleware:
def __init__(self, app):
def __init__(self, app, impls):
self.app = app
self.impls = impls
async def __call__(self, scope, receive, send):
path = scope.get("path", "")
await start_trace(path, {"__location__": "server"})
try:
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()
@ -348,7 +371,6 @@ def main():
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
@ -366,7 +388,7 @@ def main():
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig()))
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints()
@ -412,6 +434,7 @@ def main():
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
import uvicorn

View file

@ -12,9 +12,12 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
try:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
continue
return all_objects
@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str:
return None
try:
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
return None
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(

View file

@ -5,9 +5,7 @@
# the root directory of this source tree.
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.shared.document import Document
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file
@ -35,7 +33,7 @@ def rag_chat_page():
)
if st.button("Create Vector Database"):
documents = [
Document(
RAGDocument(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)
@ -167,7 +165,7 @@ def rag_chat_page():
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()

View file

@ -186,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
top_k: int = Field(..., ge=1)
SamplingStrategy = register_schema(
Annotated[
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
],
name="SamplingStrategy",
)
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type

View file

@ -244,6 +244,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
template_str = textwrap.dedent(
"""
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.

View file

@ -15,8 +15,11 @@ import json
import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@ -92,7 +95,15 @@ def parse_python_list_for_function_calls(input_string):
# Extract keyword arguments
for keyword in node.keywords:
try:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
except ValueError as e:
logger.error(
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
)
raise ValueError(
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
) from e
result.append((function_name, function_args))

View file

@ -6,14 +6,12 @@
import copy
import json
import os
import re
import secrets
import string
import uuid
from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
import httpx
@ -60,7 +58,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import (
RAGDocument,
ToolGroups,
ToolInvocationResult,
ToolRuntime,
@ -180,23 +177,27 @@ class ChatAgent(ShieldRunnerMixin):
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups)
async with tracing.span("create_and_execute_turn") as span:
span = tracing.get_current_span()
if span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
await self._initialize_tools()
async with tracing.span("resume_turn") as span:
span = tracing.get_current_span()
if span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
await self._initialize_tools()
async for chunk in self._run_turn(request):
yield chunk
@ -449,8 +450,16 @@ class ChatAgent(ShieldRunnerMixin):
stream: bool = False,
documents: Optional[List[Document]] = None,
) -> AsyncGenerator:
# if document is passed in a turn, we parse the raw text of the document
# and sent it as a user message
if documents:
await self.handle_documents(session_id, documents, input_messages)
contexts = []
for document in documents:
raw_document_text = await get_raw_document_text(document)
contexts.append(raw_document_text)
attached_context = "\n".join(contexts)
input_messages[-1].context = attached_context
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
@ -825,7 +834,10 @@ class ChatAgent(ShieldRunnerMixin):
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
tool_name_to_args,
)
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components.
@ -876,144 +888,27 @@ class ChatAgent(ShieldRunnerMixin):
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result
async def handle_documents(
self,
session_id: str,
documents: List[Document],
input_messages: List[Message],
) -> None:
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
for d in documents:
if isinstance(d.content, URL):
url_items.append(d.content)
elif pattern.match(d.content):
url_items.append(URL(uri=d.content))
else:
content_items.append(d)
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
await attachment_message(self.tempdir, url_items, input_messages[-1])
# Since memory is present, add all the data to the memory bank
await self.add_to_session_vector_db(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
await attachment_message(self.tempdir, url_items, input_messages[-1])
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_vector_db(session_id, documents)
else:
# if no memory or code_interpreter tool is available,
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items] + await load_data_from_urls(url_items)
)
async def _ensure_vector_db(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
if session_info.vector_db_id is None:
vector_db_id = f"vector_db_{session_id}"
# TODO: the semantic for registration is definitely not "creation"
# so we need to fix it if we expect the agent to create a new vector db
# for each session
await self.vector_io_api.register_vector_db(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
)
await self.storage.add_vector_db_to_session(session_id, vector_db_id)
else:
vector_db_id = session_info.vector_db_id
return vector_db_id
async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
vector_db_id = await self._ensure_vector_db(session_id)
documents = [
RAGDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in data
]
await self.tool_runtime_api.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
)
async def load_data_from_urls(urls: List[URL]) -> List[str]:
data = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
with open(filepath, "r") as f:
data.append(f.read())
elif uri.startswith("http"):
async def load_data_from_url(url: str) -> str:
if url.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(uri)
r = await client.get(url)
resp = r.text
data.append(resp)
return data
return resp
raise ValueError(f"Unexpected URL: {type(url)}")
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
contents = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
elif uri.startswith("http"):
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
logger.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
with open(filepath, "w") as fp:
fp.write(resp)
async def get_raw_document_text(document: Document) -> str:
if not document.mime_type.startswith("text/"):
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):
return document.content
elif isinstance(document.content, TextContentItem):
return document.content.text
else:
raise ValueError(f"Unsupported URL {url}")
contents.append(
TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
)
)
if isinstance(message.content, list):
message.content.extend(contents)
else:
if isinstance(message.content, str):
message.content = [TextContentItem(text=message.content)] + contents
else:
message.content = [message.content] + contents
raise ValueError(f"Unexpected document content type: {type(document.content)}")
def _interpret_content_as_attachment(

View file

@ -13,6 +13,9 @@ from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
# TODO: is this used anywhere?
vector_db_id: Optional[str] = None
started_at: datetime
access_attributes: Optional[AccessAttributes] = None
class AgentPersistence:
@ -33,11 +37,18 @@ class AgentPersistence:
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes()
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
@ -51,12 +62,34 @@ class AgentPersistence:
if not value:
return None
return AgentSessionInfo(**json.loads(value))
session_info = AgentSessionInfo(**json.loads(value))
# Check access to session
if not self._check_session_access(session_info):
return None
return session_info
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes"):
return True
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id)
if not session_info:
return None
return session_info
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
session_info = await self.get_session_info(session_id)
session_info = await self.get_session_if_accessible(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
raise ValueError(f"Session {session_id} not found or access denied")
session_info.vector_db_id = vector_db_id
await self.kvstore.set(
@ -65,12 +98,18 @@ class AgentPersistence:
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.model_dump_json(),
)
async def get_session_turns(self, session_id: str) -> List[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
@ -87,6 +126,9 @@ class AgentPersistence:
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)
@ -95,24 +137,36 @@ class AgentPersistence:
return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters),
)
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List
from tqdm import tqdm
@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:"
@ -102,7 +102,7 @@ class MetaReferenceEvalImpl(
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id)
return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@ -216,17 +216,18 @@ class MetaReferenceEvalImpl(
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
return JobStatus.completed
return Job(job_id=job_id, status=JobStatus.completed)
return None
raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(benchmark_id, job_id)
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")

View file

@ -23,7 +23,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
@ -36,6 +38,8 @@ FIXED_FNS = [
RegexParserScoringFn,
RegexParserMathResponseScoringFn,
BFCLScoringFn,
IfEvalScoringFn,
DocVQAScoringFn,
]

View file

@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.docvqa import docvqa
CONTRACTIONS = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
"1st": "first",
"2nd": "second",
"3rd": "third",
}
NUMBERS = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
ARTICLES = [
"a",
"an",
"the",
"to",
"in",
"from",
"by",
] # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
PUNCTUATION = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def normalize_answer(s: str) -> str:
# process punctuation
for p in PUNCTUATION:
if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
s = s.replace(p, "")
else:
s = s.replace(p, " ")
s = PERIOD_STRIP.sub("", s, re.UNICODE)
# process digits and articles
temp_text = s.lower().split()
out_text = []
for word in temp_text:
word = NUMBERS.setdefault(word, word)
if word not in ARTICLES:
out_text.append(word)
# standardize contractions
for word_id, word in enumerate(out_text):
if word in CONTRACTIONS:
out_text[word_id] = CONTRACTIONS[word]
return " ".join(out_text)
class DocVQAScoringFn(RegisteredBaseScoringFn):
"""
docvqa basically matches the generated answer against several allowed
choices, but we need to normalize the answer to avoid penalizing
trivial differences
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
docvqa.identifier: docvqa,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "docvqa",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
expected_answers = json.loads(input_row["expected_answer"])
generated_answer = input_row["generated_answer"]
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
return {
"score": score,
}

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
docvqa = ScoringFn(
identifier="basic::docvqa",
description="DocVQA Visual Question & Answer scoring function",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="docvqa",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
ifeval = ScoringFn(
identifier="basic::ifeval",
description="Eval intruction follow capacity by checkping how many instructions can be followed in each example",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="ifeval",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.weighted_average],
),
)

View file

@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.ifeval import (
ifeval,
)
class IfEvalScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn Instruction-Following Eval (IFEval) benchmark
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
ifeval.identifier: ifeval,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
instruction_list = input_row["instruction_id_list"]
generated_answer = input_row["generated_answer"].strip()
is_following_list = []
results = dict(
{k + "_correct": 0.0 for k in INSTRUCTION_LIST},
**{k + "_total": 0.0 for k in INSTRUCTION_LIST},
)
for index, instruction_id in enumerate(instruction_list):
instruction_cls = INSTRUCTION_DICT[instruction_id]
instruction = instruction_cls(instruction_id)
results[instruction_id + "_total"] += 1.0
results[instruction_id.split(":")[0] + "_total"] += 1.0
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
print(clean_input_row)
instruction.build_description(**clean_input_row)
args = instruction.get_instruction_args()
if args and "prompt" in args:
instruction.build_description(prompt=input_row["prompt"])
if generated_answer and instruction.check_following(generated_answer):
is_following_list.append(True)
results[instruction_id + "_correct"] += 1.0
results[instruction_id.split(":")[0] + "_correct"] += 1.0
else:
is_following_list.append(False)
if len(is_following_list) == 0:
return {
"score": 0.0,
"weight": 0.0,
}
return {
"score": float(sum(is_following_list)) / float(len(is_following_list)),
"weight": float(len(is_following_list)),
}

File diff suppressed because it is too large Load diff

View file

@ -6,12 +6,14 @@
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import TelemetryConfig, TelemetrySink
__all__ = ["TelemetryConfig", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
from .telemetry import TelemetryAdapter
impl = TelemetryAdapter(config, deps)

View file

@ -13,19 +13,20 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum):
OTEL = "otel"
OTEL_TRACE = "otel_trace"
OTEL_METRIC = "otel_metric"
SQLITE = "sqlite"
CONSOLE = "console"
class TelemetryConfig(BaseModel):
otel_endpoint: str = Field(
otel_trace_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
description="The OpenTelemetry collector endpoint URL for traces",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
otel_metric_endpoint: str = Field(
default="http://localhost:4318/v1/metrics",
description="The OpenTelemetry collector endpoint URL for metrics",
)
sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
@ -46,7 +47,6 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
}

View file

@ -101,6 +101,6 @@ class ConsoleSpanProcessor(SpanProcessor):
"""Shutdown the processor."""
pass
def force_flush(self, timeout_millis: float = None) -> bool:
def force_flush(self, timeout_millis: float | None = None) -> bool:
"""Force flush any pending spans."""
return True

View file

@ -12,6 +12,7 @@ from datetime import datetime, timezone
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
from opentelemetry.trace.span import format_span_id, format_trace_id
class SQLiteSpanProcessor(SpanProcessor):
@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor):
conn = self._get_connection()
cursor = conn.cursor()
trace_id = format(span.get_span_context().trace_id, "032x")
span_id = format(span.get_span_context().span_id, "016x")
trace_id = format_trace_id(span.get_span_context().trace_id)
span_id = format_span_id(span.get_span_context().span_id)
service_name = span.resource.attributes.get("service.name", "unknown")
parent_span_id = None
parent_context = span.parent
if parent_context:
parent_span_id = format(parent_context.span_id, "016x")
parent_span_id = format_span_id(parent_context.span_id)
# Insert into traces
cursor.execute(
@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor):
(
trace_id,
service_name,
(span_id if not parent_span_id else None),
(span_id if span.attributes.get("__root_span__") == "true" else None),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
),

View file

@ -44,7 +44,7 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace
from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = {
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"active_spans": {},
"counters": {},
"gauges": {},
@ -54,30 +54,21 @@ _global_lock = threading.Lock()
_TRACER_PROVIDER = None
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
self.meter = None
resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: self.config.service_name,
# service name is always the same, use zero-width space to avoid clutter
ResourceAttributes.SERVICE_NAME: "",
}
)
@ -91,15 +82,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
if TelemetrySink.OTEL in self.config.sinks:
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
if TelemetrySink.OTEL_TRACE in self.config.sinks:
span_exporter = OTLPSpanExporter(
endpoint=self.config.otel_trace_endpoint,
)
span_processor = BatchSpanProcessor(otlp_exporter)
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
if TelemetrySink.OTEL_METRIC in self.config.sinks:
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
endpoint=self.config.otel_metric_endpoint,
)
)
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
@ -109,7 +101,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
if TelemetrySink.OTEL in self.config.sinks:
if TelemetrySink.OTEL_METRIC in self.config.sinks:
self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks:
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
@ -135,7 +127,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
span_id = event.span_id
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
@ -146,7 +138,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**event.attributes,
**(event.attributes or {}),
},
timestamp=timestamp_ns,
)
@ -154,6 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
@ -163,6 +156,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
@ -182,6 +176,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name,
@ -192,8 +187,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
span_id = int(event.span_id, 16)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
@ -204,14 +198,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if span_id in _GLOBAL_STORAGE["active_spans"]:
return
parent_span = None
context = None
if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
context = trace.set_span_in_context(parent_span)
else:
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(event.trace_id, 16),
span_id=span_id,
is_remote=False,
trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
)
)
)
event.attributes["__root_span__"] = "true"
span = tracer.start_span(
name=event.payload.name,

View file

@ -69,7 +69,7 @@ def popen_not_allowed(*args, **kwargs):
)
_subprocess.Popen = popen_not_allowed
_subprocess.Popen = popen_not_allowed # type: ignore
import atexit as _atexit
@ -104,7 +104,7 @@ def _open_connections():
return _NETWORK_CONNECTIONS
_builtins._open_connections = _open_connections
_builtins._open_connections = _open_connections # type: ignore
@_atexit.register

View file

@ -161,9 +161,9 @@ _set_seeds()\
def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_data = response["image_data"]
# Convert the base64 string to a bytes object
images = [base64.b64decode(d["image_base64"]) for d in image_data]
images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
# Create a list of PIL images from the bytes objects
images = [Image.open(BytesIO(img)) for img in images]
images = [Image.open(BytesIO(img)) for img in images_raw]
# Create a list of image paths
image_paths = []
for i, img in enumerate(images):

Some files were not shown because too many files have changed in this diff Show more