Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
Jorge Piedrahita Ortiz 2025-03-24 08:02:40 -05:00 committed by GitHub
commit 8783dd8162
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
190 changed files with 8649 additions and 3304 deletions

View file

@ -14,6 +14,10 @@ on:
- 'requirements.txt' - 'requirements.txt'
- '.github/workflows/integration-tests.yml' # This workflow - '.github/workflows/integration-tests.yml' # This workflow
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
test-matrix: test-matrix:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -52,6 +56,7 @@ jobs:
# always test against the latest version of the client # always test against the latest version of the client
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
uv pip install -e . uv pip install -e .
llama stack build --template ollama --image-type venv
- name: Wait for Ollama to start - name: Wait for Ollama to start
run: | run: |

View file

@ -5,6 +5,10 @@ on:
push: push:
branches: [main] branches: [main]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
pre-commit: pre-commit:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -18,6 +18,10 @@ on:
- 'llama_stack/distribution/*.sh' - 'llama_stack/distribution/*.sh'
- '.github/workflows/providers-build.yml' - '.github/workflows/providers-build.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
generate-matrix: generate-matrix:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -8,6 +8,10 @@ on:
- reopened - reopened
- synchronize - synchronize
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
permissions: permissions:
contents: read contents: read

View file

@ -15,6 +15,10 @@ on:
- '.github/workflows/unit-tests.yml' # This workflow - '.github/workflows/unit-tests.yml' # This workflow
workflow_dispatch: workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
unit-tests: unit-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -22,6 +22,10 @@ on:
- 'pyproject.toml' - 'pyproject.toml'
- '.github/workflows/update-readthedocs.yml' - '.github/workflows/update-readthedocs.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
update-readthedocs: update-readthedocs:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -89,10 +89,11 @@ repos:
name: API Spec Codegen name: API Spec Codegen
additional_dependencies: additional_dependencies:
- uv==0.6.2 - uv==0.6.2
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null 2>&1' entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
language: python language: python
pass_filenames: false pass_filenames: false
require_serial: true require_serial: true
files: ^llama_stack/apis/|^docs/openapi_generator/
ci: ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

View file

@ -135,9 +135,11 @@ uv sync
## Coding Style ## Coding Style
* Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter, same goes for docstrings.
* Prefer comments to clarify surprising behavior and/or relationships between parts of the code rather than explain what the next line of code does.
* Catching exceptions, prefer using a specific exception type rather than a broad catch-all like `Exception`.
* Error messages should be prefixed with "Failed to ..."
* 4 spaces for indentation rather than tabs * 4 spaces for indentation rather than tabs
* 80 character line length
* ...
## Common Tasks ## Common Tasks
@ -166,7 +168,7 @@ If you have made changes to a provider's configuration in any form (introducing
If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme. If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.
```bash ```bash
cd llama-stack/docs cd docs
uv sync --extra docs uv sync --extra docs
# This rebuilds the documentation pages. # This rebuilds the documentation pages.

View file

@ -7,10 +7,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -23,6 +25,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -41,10 +44,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"nltk", "nltk",
"numpy", "numpy",
@ -56,6 +61,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -75,10 +81,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fastapi", "fastapi",
"fire", "fire",
"fireworks-ai", "fireworks-ai",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -91,6 +99,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -112,11 +121,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub", "huggingface_hub",
"langdetect",
"matplotlib", "matplotlib",
"nltk", "nltk",
"numpy", "numpy",
@ -128,6 +139,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -147,10 +159,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fastapi", "fastapi",
"fire", "fire",
"fireworks-ai", "fireworks-ai",
"httpx", "httpx",
"langdetect",
"litellm", "litellm",
"matplotlib", "matplotlib",
"mcp", "mcp",
@ -164,6 +178,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -184,11 +199,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"fireworks-ai", "fireworks-ai",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -201,6 +218,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -219,10 +237,12 @@
"blobfile", "blobfile",
"chardet", "chardet",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"litellm", "litellm",
"matplotlib", "matplotlib",
"nltk", "nltk",
@ -235,6 +255,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -253,11 +274,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub", "huggingface_hub",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -270,6 +293,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -288,11 +312,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub", "huggingface_hub",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -305,6 +331,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -325,11 +352,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fairscale", "fairscale",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"lm-format-enforcer", "lm-format-enforcer",
"matplotlib", "matplotlib",
"mcp", "mcp",
@ -343,6 +372,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -365,12 +395,14 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fairscale", "fairscale",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fbgemm-gpu", "fbgemm-gpu",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"lm-format-enforcer", "lm-format-enforcer",
"matplotlib", "matplotlib",
"mcp", "mcp",
@ -384,6 +416,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -403,10 +436,12 @@
"aiosqlite", "aiosqlite",
"blobfile", "blobfile",
"chardet", "chardet",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"nltk", "nltk",
"numpy", "numpy",
@ -418,6 +453,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -436,10 +472,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -453,6 +491,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -470,9 +509,11 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"litellm", "litellm",
"matplotlib", "matplotlib",
"mcp", "mcp",
@ -486,6 +527,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -505,10 +547,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -521,6 +565,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -540,10 +585,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -556,6 +603,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -608,11 +656,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub", "huggingface_hub",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -625,6 +675,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -644,10 +695,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -660,6 +713,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -680,10 +734,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -696,6 +752,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",

View file

@ -51,14 +51,14 @@ services:
- ~/local/llama-stack/:/app/llama-stack-source - ~/local/llama-stack/:/app/llama-stack-source
- ./run${SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml - ./run${SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
ports: ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}" - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
environment: environment:
- INFERENCE_MODEL=${INFERENCE_MODEL} - INFERENCE_MODEL=${INFERENCE_MODEL}
- SAFETY_MODEL=${SAFETY_MODEL:-} - SAFETY_MODEL=${SAFETY_MODEL:-}
- OLLAMA_URL=http://ollama:11434 - OLLAMA_URL=http://ollama:11434
entrypoint: > entrypoint: >
python -m llama_stack.distribution.server.server /root/my-run.yaml \ python -m llama_stack.distribution.server.server /root/my-run.yaml \
--port ${LLAMA_STACK_PORT:-5001} --port ${LLAMA_STACK_PORT:-8321}
deploy: deploy:
restart_policy: restart_policy:
condition: on-failure condition: on-failure

Binary file not shown.

View file

@ -84,9 +84,9 @@ services:
- SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm} - SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
- SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B} - SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
ports: ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}" - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
# Hack: wait for vLLM server to start before starting docker # Hack: wait for vLLM server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 5001" entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 8321"
deploy: deploy:
restart_policy: restart_policy:
condition: on-failure condition: on-failure

View file

@ -83,7 +83,7 @@ services:
- ~/.llama:/root/.llama - ~/.llama:/root/.llama
- ./run${TGI_SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml - ./run${TGI_SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
ports: ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}" - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
# Hack: wait for TGI server to start before starting docker # Hack: wait for TGI server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml" entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
restart_policy: restart_policy:

View file

@ -2183,7 +2183,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"$ref": "#/components/schemas/JobStatus" "$ref": "#/components/schemas/Job"
} }
} }
} }
@ -2650,7 +2650,7 @@
} }
}, },
"tags": [ "tags": [
"Inspect" "Providers"
], ],
"description": "", "description": "",
"parameters": [] "parameters": []
@ -6268,6 +6268,7 @@
"type": "string", "type": "string",
"enum": [ "enum": [
"average", "average",
"weighted_average",
"median", "median",
"categorical_count", "categorical_count",
"accuracy" "accuracy"
@ -7647,7 +7648,13 @@
"title": "PostTrainingJobArtifactsResponse", "title": "PostTrainingJobArtifactsResponse",
"description": "Artifacts of a finetuning job." "description": "Artifacts of a finetuning job."
}, },
"JobStatus": { "PostTrainingJobStatusResponse": {
"type": "object",
"properties": {
"job_uuid": {
"type": "string"
},
"status": {
"type": "string", "type": "string",
"enum": [ "enum": [
"completed", "completed",
@ -7657,15 +7664,6 @@
], ],
"title": "JobStatus" "title": "JobStatus"
}, },
"PostTrainingJobStatusResponse": {
"type": "object",
"properties": {
"job_uuid": {
"type": "string"
},
"status": {
"$ref": "#/components/schemas/JobStatus"
},
"scheduled_at": { "scheduled_at": {
"type": "string", "type": "string",
"format": "date-time" "format": "date-time"
@ -8068,9 +8066,6 @@
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [
"content"
],
"title": "ToolInvocationResult" "title": "ToolInvocationResult"
}, },
"IterrowsResponse": { "IterrowsResponse": {
@ -8117,6 +8112,30 @@
"title": "IterrowsResponse", "title": "IterrowsResponse",
"description": "A paginated list of rows from a dataset." "description": "A paginated list of rows from a dataset."
}, },
"Job": {
"type": "object",
"properties": {
"job_id": {
"type": "string"
},
"status": {
"type": "string",
"enum": [
"completed",
"in_progress",
"failed",
"scheduled"
],
"title": "JobStatus"
}
},
"additionalProperties": false,
"required": [
"job_id",
"status"
],
"title": "Job"
},
"ListAgentSessionsResponse": { "ListAgentSessionsResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9641,19 +9660,6 @@
], ],
"title": "RunEvalRequest" "title": "RunEvalRequest"
}, },
"Job": {
"type": "object",
"properties": {
"job_id": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"job_id"
],
"title": "Job"
},
"RunShieldRequest": { "RunShieldRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9862,6 +9868,23 @@
], ],
"title": "ScoreBatchResponse" "title": "ScoreBatchResponse"
}, },
"AlgorithmConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"LoRA": "#/components/schemas/LoraFinetuningConfig",
"QAT": "#/components/schemas/QATFinetuningConfig"
}
}
},
"LoraFinetuningConfig": { "LoraFinetuningConfig": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9997,14 +10020,7 @@
"type": "string" "type": "string"
}, },
"algorithm_config": { "algorithm_config": {
"oneOf": [ "$ref": "#/components/schemas/AlgorithmConfig"
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
]
} }
}, },
"additionalProperties": false, "additionalProperties": false,

View file

@ -1491,7 +1491,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/JobStatus' $ref: '#/components/schemas/Job'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1814,7 +1814,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Inspect - Providers
description: '' description: ''
parameters: [] parameters: []
/v1/inspect/routes: /v1/inspect/routes:
@ -4389,6 +4389,7 @@ components:
type: string type: string
enum: enum:
- average - average
- weighted_average
- median - median
- categorical_count - categorical_count
- accuracy - accuracy
@ -5276,7 +5277,12 @@ components:
- checkpoints - checkpoints
title: PostTrainingJobArtifactsResponse title: PostTrainingJobArtifactsResponse
description: Artifacts of a finetuning job. description: Artifacts of a finetuning job.
JobStatus: PostTrainingJobStatusResponse:
type: object
properties:
job_uuid:
type: string
status:
type: string type: string
enum: enum:
- completed - completed
@ -5284,13 +5290,6 @@ components:
- failed - failed
- scheduled - scheduled
title: JobStatus title: JobStatus
PostTrainingJobStatusResponse:
type: object
properties:
job_uuid:
type: string
status:
$ref: '#/components/schemas/JobStatus'
scheduled_at: scheduled_at:
type: string type: string
format: date-time format: date-time
@ -5528,8 +5527,6 @@ components:
- type: array - type: array
- type: object - type: object
additionalProperties: false additionalProperties: false
required:
- content
title: ToolInvocationResult title: ToolInvocationResult
IterrowsResponse: IterrowsResponse:
type: object type: object
@ -5557,6 +5554,24 @@ components:
- data - data
title: IterrowsResponse title: IterrowsResponse
description: A paginated list of rows from a dataset. description: A paginated list of rows from a dataset.
Job:
type: object
properties:
job_id:
type: string
status:
type: string
enum:
- completed
- in_progress
- failed
- scheduled
title: JobStatus
additionalProperties: false
required:
- job_id
- status
title: Job
ListAgentSessionsResponse: ListAgentSessionsResponse:
type: object type: object
properties: properties:
@ -6551,15 +6566,6 @@ components:
required: required:
- benchmark_config - benchmark_config
title: RunEvalRequest title: RunEvalRequest
Job:
type: object
properties:
job_id:
type: string
additionalProperties: false
required:
- job_id
title: Job
RunShieldRequest: RunShieldRequest:
type: object type: object
properties: properties:
@ -6688,6 +6694,15 @@ components:
required: required:
- results - results
title: ScoreBatchResponse title: ScoreBatchResponse
AlgorithmConfig:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
discriminator:
propertyName: type
mapping:
LoRA: '#/components/schemas/LoraFinetuningConfig'
QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig: LoraFinetuningConfig:
type: object type: object
properties: properties:
@ -6771,9 +6786,7 @@ components:
checkpoint_dir: checkpoint_dir:
type: string type: string
algorithm_config: algorithm_config:
oneOf: $ref: '#/components/schemas/AlgorithmConfig'
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
additionalProperties: false additionalProperties: false
required: required:
- job_uuid - job_uuid

View file

@ -4,6 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os
import time
def pytest_collection_modifyitems(items): def pytest_collection_modifyitems(items):
for item in items: for item in items:
item.name = item.name.replace(' ', '_') item.name = item.name.replace(' ', '_')
def pytest_runtest_teardown(item):
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
if interval_seconds:
time.sleep(float(interval_seconds))
def pytest_configure(config):
config.option.tbstyle = "short"
config.option.disable_warnings = True

View file

@ -123,6 +123,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# NBVAL_SKIP\n", "# NBVAL_SKIP\n",
"!pip uninstall pandas numpy -y\n",
"!pip install pandas numpy\n",
"# This will build all the dependencies you will need\n", "# This will build all the dependencies you will need\n",
"!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv" "!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv"
] ]
@ -1203,7 +1205,7 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n", "from llama_stack_client import InferenceEventLogger\n",
"\n", "\n",
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n", "message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
"print(f'User> {message[\"content\"]}', \"green\")\n", "print(f'User> {message[\"content\"]}', \"green\")\n",
@ -1215,7 +1217,7 @@
")\n", ")\n",
"\n", "\n",
"# Print the tokens while they are received\n", "# Print the tokens while they are received\n",
"for log in EventLogger().log(response):\n", "for log in InferenceEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -1632,8 +1634,7 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client import Agent, AgentEventLogger\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from termcolor import cprint\n", "from termcolor import cprint\n",
"\n", "\n",
"agent = Agent(\n", "agent = Agent(\n",
@ -1659,7 +1660,7 @@
" ],\n", " ],\n",
" session_id=session_id,\n", " session_id=session_id,\n",
" )\n", " )\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -1808,14 +1809,12 @@
], ],
"source": [ "source": [
"import uuid\n", "import uuid\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client import Agent, AgentEventLogger, RAGDocument\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from termcolor import cprint\n", "from termcolor import cprint\n",
"from llama_stack_client.types import Document\n",
"\n", "\n",
"urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n", "urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n",
"documents = [\n", "documents = [\n",
" Document(\n", " RAGDocument(\n",
" document_id=f\"num-{i}\",\n", " document_id=f\"num-{i}\",\n",
" content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n", " content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
" mime_type=\"text/plain\",\n", " mime_type=\"text/plain\",\n",
@ -1858,7 +1857,7 @@
" messages=[{\"role\": \"user\", \"content\": prompt}],\n", " messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" session_id=session_id,\n", " session_id=session_id,\n",
" )\n", " )\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()" " log.print()"
] ]
}, },
@ -1969,7 +1968,7 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client.types.agents.turn_create_params import Document\n", "from llama_stack_client import Document\n",
"\n", "\n",
"codex_agent = Agent(\n", "codex_agent = Agent(\n",
" client, \n", " client, \n",
@ -2013,7 +2012,7 @@
" # for chunk in response:\n", " # for chunk in response:\n",
" # print(chunk)\n", " # print(chunk)\n",
"\n", "\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -2891,8 +2890,7 @@
], ],
"source": [ "source": [
"# NBVAL_SKIP\n", "# NBVAL_SKIP\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client import Agent, AgentEventLogger\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from termcolor import cprint\n", "from termcolor import cprint\n",
"\n", "\n",
"agent = Agent(\n", "agent = Agent(\n",
@ -2918,7 +2916,7 @@
" ],\n", " ],\n",
" session_id=session_id,\n", " session_id=session_id,\n",
" )\n", " )\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -2993,8 +2991,7 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client import Agent, AgentEventLogger\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"\n", "\n",
"agent = Agent(\n", "agent = Agent(\n",
" client, \n", " client, \n",
@ -3021,7 +3018,7 @@
" session_id=session_id,\n", " session_id=session_id,\n",
" )\n", " )\n",
"\n", "\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -4355,7 +4352,7 @@
" session_id=session_id,\n", " session_id=session_id,\n",
")\n", ")\n",
"\n", "\n",
"for log in EventLogger().log(response):\n", "for log in AgentEventLogger().log(response):\n",
" log.print()\n", " log.print()\n",
" " " "
] ]

View file

@ -47,9 +47,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from llama_stack_client import LlamaStackClient\n", "from llama_stack_client import LlamaStackClient, Agent\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from rich.pretty import pprint\n", "from rich.pretty import pprint\n",
"import json\n", "import json\n",
"import uuid\n", "import uuid\n",

View file

@ -22,7 +22,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -34,10 +34,8 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client import LlamaStackClient\n", "from llama_stack_client import LlamaStackClient, Agent\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from rich.pretty import pprint\n", "from rich.pretty import pprint\n",
"import json\n", "import json\n",
"import uuid\n", "import uuid\n",
@ -70,7 +68,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1397,6 +1395,349 @@
"pprint(session_response.turns)" "pprint(session_response.turns)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.1 Improved RAG with Long Context\n",
"\n",
"- Instead of performing reteival tool, we send documents as attachments to the agent and let it use the entire document context. \n",
"- Note how that the model is able to understand the entire context from documentation and answers the question with better factuality with improved retrieval. "
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> What precision formats does torchtune support?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m What precision formats does torchtune support?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> Torchtune supports two precision formats: `fp32` <span style=\"font-weight: bold\">(</span>full-precision<span style=\"font-weight: bold\">)</span> and `bfloat16` <span style=\"font-weight: bold\">(</span>half-precision<span style=\"font-weight: bold\">)</span>. \n",
"The `bfloat16` format uses <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> bytes per model parameter, which is half the memory of `fp32`, and also improves \n",
"training speed.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m Torchtune supports two precision formats: `fp32` \u001b[1m(\u001b[0mfull-precision\u001b[1m)\u001b[0m and `bfloat16` \u001b[1m(\u001b[0mhalf-precision\u001b[1m)\u001b[0m. \n",
"The `bfloat16` format uses \u001b[1;36m2\u001b[0m bytes per model parameter, which is half the memory of `fp32`, and also improves \n",
"training speed.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> What does DoRA stand for in torchtune?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m What does DoRA stand for in torchtune?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> DoRA stands for Weight-Decomposed Low-Rank Adaptation. It is a variant of LoRA <span style=\"font-weight: bold\">(</span>Low-Rank Adaptation<span style=\"font-weight: bold\">)</span> \n",
"that further decomposes the pre-trained weights into two components: magnitude and direction. The magnitude \n",
"component is a scalar vector that adjusts the scale, while the direction component corresponds to the original LoRA\n",
"decomposition and updates the orientation of weights. DoRA adds a small overhead to LoRA training due to the \n",
"addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low \n",
"ranks.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m DoRA stands for Weight-Decomposed Low-Rank Adaptation. It is a variant of LoRA \u001b[1m(\u001b[0mLow-Rank Adaptation\u001b[1m)\u001b[0m \n",
"that further decomposes the pre-trained weights into two components: magnitude and direction. The magnitude \n",
"component is a scalar vector that adjusts the scale, while the direction component corresponds to the original LoRA\n",
"decomposition and updates the orientation of weights. DoRA adds a small overhead to LoRA training due to the \n",
"addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low \n",
"ranks.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> How does the CPUOffloadOptimizer reduce GPU memory usage?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m How does the CPUOffloadOptimizer reduce GPU memory usage?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> The CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the \n",
"CPU, and performing optimizer steps on the CPU. This can significantly reduce GPU memory usage at the cost of CPU \n",
"RAM and training speed.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m The CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the \n",
"CPU, and performing optimizer steps on the CPU. This can significantly reduce GPU memory usage at the cost of CPU \n",
"RAM and training speed.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> How do I ensure only LoRA parameters are trainable when fine-tuning?\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;36mQuestion:\u001b[0m How do I ensure only LoRA parameters are trainable when fine-tuning?\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> To ensure only LoRA parameters are trainable when fine-tuning, you can use the `set_trainable_params`\n",
"function from `torchtune.modules.peft.peft_utils` to set the `requires_grad` attribute of the LoRA parameters to \n",
"`<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-style: italic\">True</span>` and the `requires_grad` attribute of the other parameters to `<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>`.\n",
"\n",
"Here is an example:\n",
"```python\n",
"from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n",
"\n",
"# Get the LoRA parameters\n",
"lora_params = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">get_adapter_params</span><span style=\"font-weight: bold\">(</span>model<span style=\"font-weight: bold\">)</span>\n",
"\n",
"# Set the LoRA parameters to trainable and the other parameters to non-trainable\n",
"<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">set_trainable_params</span><span style=\"font-weight: bold\">(</span>model, lora_params<span style=\"font-weight: bold\">)</span>\n",
"```\n",
"This will ensure that only the LoRA parameters are updated during fine-tuning, while the other parameters remain \n",
"frozen.\n",
"\n",
"Alternatively, you can also use the `lora_finetune` recipe in torchtune, which automatically sets the LoRA \n",
"parameters to trainable and the other parameters to non-trainable. You can run the recipe using the following \n",
"command:\n",
"```bash\n",
"tune run lora_finetune --config llama2/7B_lora\n",
"```\n",
"This will fine-tune the LoRA parameters of the Llama2 model using the default settings. You can modify the config \n",
"file to change the hyperparameters or the model architecture.\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;33mAgent Answer:\u001b[0m To ensure only LoRA parameters are trainable when fine-tuning, you can use the `set_trainable_params`\n",
"function from `torchtune.modules.peft.peft_utils` to set the `requires_grad` attribute of the LoRA parameters to \n",
"`\u001b[3;92mTrue\u001b[0m` and the `requires_grad` attribute of the other parameters to `\u001b[3;91mFalse\u001b[0m`.\n",
"\n",
"Here is an example:\n",
"```python\n",
"from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n",
"\n",
"# Get the LoRA parameters\n",
"lora_params = \u001b[1;35mget_adapter_params\u001b[0m\u001b[1m(\u001b[0mmodel\u001b[1m)\u001b[0m\n",
"\n",
"# Set the LoRA parameters to trainable and the other parameters to non-trainable\n",
"\u001b[1;35mset_trainable_params\u001b[0m\u001b[1m(\u001b[0mmodel, lora_params\u001b[1m)\u001b[0m\n",
"```\n",
"This will ensure that only the LoRA parameters are updated during fine-tuning, while the other parameters remain \n",
"frozen.\n",
"\n",
"Alternatively, you can also use the `lora_finetune` recipe in torchtune, which automatically sets the LoRA \n",
"parameters to trainable and the other parameters to non-trainable. You can run the recipe using the following \n",
"command:\n",
"```bash\n",
"tune run lora_finetune --config llama2/7B_lora\n",
"```\n",
"This will fine-tune the LoRA parameters of the Llama2 model using the default settings. You can modify the config \n",
"file to change the hyperparameters or the model architecture.\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"urls = [\n",
" \"memory_optimizations.rst\",\n",
" \"chat.rst\",\n",
" \"llama3.rst\",\n",
" \"datasets.rst\",\n",
" \"qat_finetune.rst\",\n",
" \"lora_finetune.rst\",\n",
"]\n",
"\n",
"attachments = [\n",
" {\n",
" \"content\": {\n",
" \"uri\": f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
" },\n",
" \"mime_type\": \"text/plain\",\n",
" }\n",
"\n",
" for i, url in enumerate(urls)\n",
"]\n",
"\n",
"rag_attachment_agent = Agent(\n",
" client,\n",
" model=MODEL_ID,\n",
" instructions=\"You are a helpful assistant that can answer questions about the Torchtune project. Use context from attached documentation for Torchtune to answer questions.\",\n",
")\n",
"\n",
"for example in examples:\n",
" session_id = rag_attachment_agent.create_session(session_name=f\"rag_attachment_session_{uuid.uuid4()}\")\n",
" response = rag_attachment_agent.create_turn(\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": example[\"input_query\"]\n",
" }\n",
" ],\n",
" session_id=session_id,\n",
" documents=attachments,\n",
" stream=False\n",
" )\n",
" rich.print(f\"[bold cyan]Question:[/bold cyan] {example['input_query']}\")\n",
" rich.print(f\"[bold yellow]Agent Answer:[/bold yellow] {response.output_message.content}\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringScoreResponse</span><span style=\"font-weight: bold\">(</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">results</span>=<span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'braintrust::factuality'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringResult</span><span style=\"font-weight: bold\">(</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">aggregated_results</span>=<span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'average'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'average'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span><span style=\"font-weight: bold\">}}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">score_rows</span>=<span style=\"font-weight: bold\">[</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. Both the expert and the submitted answers mention that Torchtune supports two precision formats: `fp32` (full-precision) and `bfloat16` (half-precision).\\n2. The expert answer specifies that `fp32` uses 4 bytes per model and optimizer parameter, while `bfloat16` uses 2 bytes per model and optimizer parameter.\\n3. The submitted answer also mentions that `bfloat16` uses 2 bytes per model parameter, which is consistent with the expert answer.\\n4. The submitted answer adds that `bfloat16` improves training speed, which is additional information not present in the expert answer.\\n5. There is no conflict between the submitted answer and the expert answer; the submitted answer simply provides more information.\\n\\nBased on this analysis, the submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. The expert answer provides the definition of DoRA as \"Weight-Decomposed Low-Rank Adaptation.\"\\n2. The submitted answer also states that DoRA stands for \"Weight-Decomposed Low-Rank Adaptation,\" which matches the expert answer.\\n3. The submitted answer includes additional information about DoRA, explaining that it is a variant of LoRA and describing how it decomposes pre-trained weights into magnitude and direction components.\\n4. The submitted answer further explains the role of the magnitude component and the direction component, and mentions the performance improvement and overhead associated with DoRA.\\n5. The additional details in the submitted answer do not contradict the expert answer; instead, they expand upon it.\\n6. Therefore, the submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. The expert answer states that the CPUOffloadOptimizer reduces GPU memory usage by keeping optimizer states on CPU and performing optimizer steps on CPU. It also mentions the optional offloading of gradients to CPU with the parameter offload_gradients=True.\\n\\n2. The submitted answer states that the CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the CPU, and performing optimizer steps on the CPU. It adds that this can significantly reduce GPU memory usage at the cost of CPU RAM and training speed.\\n\\n3. Comparing both answers:\\n - Both answers agree on offloading optimizer states to the CPU and performing optimizer steps on the CPU.\\n - Both mention the offloading of gradients to the CPU, but the expert answer specifies it as optional with a parameter, while the submission does not specify this detail.\\n - The submission adds additional information about the trade-off involving CPU RAM and training speed, which is not mentioned in the expert answer.\\n\\n4. The submitted answer includes all the details from the expert answer and adds more information about the trade-offs, making it a superset of the expert answer.\\n\\nTherefore, the correct choice is (B) The submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"1. **Expert Answer Analysis**: The expert answer provides a method to ensure only LoRA parameters are trainable by using torchtune's utility functions. It mentions fetching LoRA parameters with `get_adapter_params(lora_model)` and setting them as trainable with `set_trainable_params(lora_model, lora_params)`. It also notes that the LoRA recipe handles this automatically.\\n\\n2. **Submitted Answer Analysis**: The submitted answer provides a similar method using `set_trainable_params` to set the `requires_grad` attribute of LoRA parameters to `True` and other parameters to `False`. It includes a code example demonstrating this process. Additionally, it mentions using the `lora_finetune` recipe in torchtune, which automatically sets the LoRA parameters to trainable.\\n\\n3. **Comparison**: The submitted answer includes all the details from the expert answer regarding the use of `get_adapter_params` and `set_trainable_params`. It also provides additional information about setting the `requires_grad` attribute and using the `lora_finetune` recipe, which is not mentioned in the expert answer.\\n\\n4. **Conclusion**: The submitted answer is a superset of the expert answer as it contains all the information from the expert answer and additional details. There is no conflict between the two answers, and the additional information in the submission is consistent with the expert's explanation.\\n\\nTherefore, the correct choice is (B) The submitted answer is a superset of the expert answer and is fully consistent with it.\"</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"font-weight: bold\">]</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">)</span>\n",
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">}</span>\n",
"<span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mScoringScoreResponse\u001b[0m\u001b[1m(\u001b[0m\n",
"\u001b[2;32m│ \u001b[0m\u001b[33mresults\u001b[0m=\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[32m'braintrust::factuality'\u001b[0m: \u001b[1;35mScoringResult\u001b[0m\u001b[1m(\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33maggregated_results\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'average'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'average'\u001b[0m: \u001b[1;36m0.6\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mscore_rows\u001b[0m=\u001b[1m[\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. Both the expert and the submitted answers mention that Torchtune supports two precision formats: `fp32` \u001b[0m\u001b[32m(\u001b[0m\u001b[32mfull-precision\u001b[0m\u001b[32m)\u001b[0m\u001b[32m and `bfloat16` \u001b[0m\u001b[32m(\u001b[0m\u001b[32mhalf-precision\u001b[0m\u001b[32m)\u001b[0m\u001b[32m.\\n2. The expert answer specifies that `fp32` uses 4 bytes per model and optimizer parameter, while `bfloat16` uses 2 bytes per model and optimizer parameter.\\n3. The submitted answer also mentions that `bfloat16` uses 2 bytes per model parameter, which is consistent with the expert answer.\\n4. The submitted answer adds that `bfloat16` improves training speed, which is additional information not present in the expert answer.\\n5. There is no conflict between the submitted answer and the expert answer; the submitted answer simply provides more information.\\n\\nBased on this analysis, the submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. The expert answer provides the definition of DoRA as \"Weight-Decomposed Low-Rank Adaptation.\"\\n2. The submitted answer also states that DoRA stands for \"Weight-Decomposed Low-Rank Adaptation,\" which matches the expert answer.\\n3. The submitted answer includes additional information about DoRA, explaining that it is a variant of LoRA and describing how it decomposes pre-trained weights into magnitude and direction components.\\n4. The submitted answer further explains the role of the magnitude component and the direction component, and mentions the performance improvement and overhead associated with DoRA.\\n5. The additional details in the submitted answer do not contradict the expert answer; instead, they expand upon it.\\n6. Therefore, the submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. The expert answer states that the CPUOffloadOptimizer reduces GPU memory usage by keeping optimizer states on CPU and performing optimizer steps on CPU. It also mentions the optional offloading of gradients to CPU with the parameter \u001b[0m\u001b[32moffload_gradients\u001b[0m\u001b[32m=\u001b[0m\u001b[32mTrue\u001b[0m\u001b[32m.\\n\\n2. The submitted answer states that the CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the CPU, and performing optimizer steps on the CPU. It adds that this can significantly reduce GPU memory usage at the cost of CPU RAM and training speed.\\n\\n3. Comparing both answers:\\n - Both answers agree on offloading optimizer states to the CPU and performing optimizer steps on the CPU.\\n - Both mention the offloading of gradients to the CPU, but the expert answer specifies it as optional with a parameter, while the submission does not specify this detail.\\n - The submission adds additional information about the trade-off involving CPU RAM and training speed, which is not mentioned in the expert answer.\\n\\n4. The submitted answer includes all the details from the expert answer and adds more information about the trade-offs, making it a superset of the expert answer.\\n\\nTherefore, the correct choice is \u001b[0m\u001b[32m(\u001b[0m\u001b[32mB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m The submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m\"1. **Expert Answer Analysis**: The expert answer provides a method to ensure only LoRA parameters are trainable by using torchtune's utility functions. It mentions fetching LoRA parameters with `get_adapter_params\u001b[0m\u001b[32m(\u001b[0m\u001b[32mlora_model\u001b[0m\u001b[32m)\u001b[0m\u001b[32m` and setting them as trainable with `set_trainable_params\u001b[0m\u001b[32m(\u001b[0m\u001b[32mlora_model, lora_params\u001b[0m\u001b[32m)\u001b[0m\u001b[32m`. It also notes that the LoRA recipe handles this automatically.\\n\\n2. **Submitted Answer Analysis**: The submitted answer provides a similar method using `set_trainable_params` to set the `requires_grad` attribute of LoRA parameters to `True` and other parameters to `False`. It includes a code example demonstrating this process. Additionally, it mentions using the `lora_finetune` recipe in torchtune, which automatically sets the LoRA parameters to trainable.\\n\\n3. **Comparison**: The submitted answer includes all the details from the expert answer regarding the use of `get_adapter_params` and `set_trainable_params`. It also provides additional information about setting the `requires_grad` attribute and using the `lora_finetune` recipe, which is not mentioned in the expert answer.\\n\\n4. **Conclusion**: The submitted answer is a superset of the expert answer as it contains all the information from the expert answer and additional details. There is no conflict between the two answers, and the additional information in the submission is consistent with the expert's explanation.\\n\\nTherefore, the correct choice is \u001b[0m\u001b[32m(\u001b[0m\u001b[32mB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m The submitted answer is a superset of the expert answer and is fully consistent with it.\"\u001b[0m\n",
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[2;32m│ │ │ \u001b[0m\u001b[1m]\u001b[0m\n",
"\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n",
"\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m\n",
"\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"eval_rows = []\n",
"for i, session_id in enumerate(rag_attachment_agent.sessions):\n",
" session_response = client.agents.session.retrieve(agent_id=rag_attachment_agent.agent_id, session_id=session_id)\n",
" for turn in session_response.turns:\n",
" eval_rows.append({\n",
" \"input_query\": examples[i][\"input_query\"],\n",
" \"expected_answer\": examples[i][\"expected_answer\"],\n",
" \"generated_answer\": turn.output_message.content,\n",
" })\n",
"\n",
"scoring_params = {\n",
" \"braintrust::factuality\": None,\n",
"}\n",
"scoring_response = client.scoring.score(\n",
" input_rows=eval_rows,\n",
" scoring_functions=scoring_params,\n",
")\n",
"pprint(scoring_response)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},

View file

@ -14,7 +14,7 @@ Agents are configured using the `AgentConfig` class, which includes:
- **Safety Shields**: Guardrails to ensure responsible AI behavior - **Safety Shields**: Guardrails to ensure responsible AI behavior
```python ```python
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent
# Create the agent # Create the agent
@ -44,14 +44,14 @@ Each interaction with an agent is called a "turn" and consists of:
- **Output Message**: The agent's response - **Output Message**: The agent's response
```python ```python
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client import AgentEventLogger
# Create a turn with streaming response # Create a turn with streaming response
turn_response = agent.create_turn( turn_response = agent.create_turn(
session_id=session_id, session_id=session_id,
messages=[{"role": "user", "content": "Tell me about Llama models"}], messages=[{"role": "user", "content": "Tell me about Llama models"}],
) )
for log in EventLogger().log(turn_response): for log in AgentEventLogger().log(turn_response):
log.print() log.print()
``` ```
### Non-Streaming ### Non-Streaming

View file

@ -67,9 +67,7 @@ sequenceDiagram
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution: Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
```python ```python
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from rich.pretty import pprint from rich.pretty import pprint
# Replace host and port # Replace host and port
@ -113,7 +111,7 @@ response = agent.create_turn(
) )
# Monitor each step of execution # Monitor each step of execution
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
# Using non-streaming API, the response contains input, steps, and output. # Using non-streaming API, the response contains input, steps, and output.

View file

@ -23,9 +23,7 @@ In this example, we will show you how to:
##### Building a Search Agent ##### Building a Search Agent
```python ```python
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}") client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
@ -54,7 +52,7 @@ for prompt in user_prompts:
session_id=session_id, session_id=session_id,
) )
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
``` ```

View file

@ -55,11 +55,11 @@ chunks_response = client.vector_io.query(
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces. A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces.
```python ```python
from llama_stack_client.types import Document from llama_stack_client import RAGDocument
urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"] urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"]
documents = [ documents = [
Document( RAGDocument(
document_id=f"num-{i}", document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain", mime_type="text/plain",
@ -86,7 +86,7 @@ results = client.tool_runtime.rag_tool.query(
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example: One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
```python ```python
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent
# Create agent with memory # Create agent with memory
agent = Agent( agent = Agent(
@ -140,9 +140,9 @@ response = agent.create_turn(
You can print the response with below. You can print the response with below.
```python ```python
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client import AgentEventLogger
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
``` ```

View file

@ -57,7 +57,7 @@ The `otel` sink works with any service compatible with the OpenTelemetry collect
Start a Jaeger instance with the OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686 using the following command: Start a Jaeger instance with the OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686 using the following command:
```bash ```bash
$ docker run --rm --name jaeger \ $ docker run --pull always --rm --name jaeger \
-p 16686:16686 -p 4318:4318 \ -p 16686:16686 -p 4318:4318 \
jaegertracing/jaeger:2.1.0 jaegertracing/jaeger:2.1.0
``` ```

View file

@ -110,10 +110,18 @@ MCP tools are special tools that can interact with llama stack over model contex
Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers. Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers.
```shell
# start your MCP server
mkdir /tmp/content
touch /tmp/content/foo
touch /tmp/content/bar
npx -y supergateway --port 8000 --stdio 'npx -y @modelcontextprotocol/server-filesystem /tmp/content'
```
Then register the MCP server as a tool group,
```python ```python
# Register MCP tools
client.toolgroups.register( client.toolgroups.register(
toolgroup_id="builtin::filesystem", toolgroup_id="mcp::filesystem",
provider_id="model-context-protocol", provider_id="model-context-protocol",
mcp_endpoint=URL(uri="http://localhost:8000/sse"), mcp_endpoint=URL(uri="http://localhost:8000/sse"),
) )
@ -181,7 +189,7 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
## Simple Example: Using an Agent with the Code-Interpreter Tool ## Simple Example: Using an Agent with the Code-Interpreter Tool
```python ```python
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent
# Instantiate the AI agent with the given configuration # Instantiate the AI agent with the given configuration
agent = Agent( agent = Agent(

View file

@ -55,7 +55,7 @@ llama stack run llama_stack/templates/open-benchmark/run.yaml
There are 3 necessary inputs to run a benchmark eval There are 3 necessary inputs to run a benchmark eval
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on - `list of benchmark_ids`: The list of benchmark ids to run evaluation on
- `model-id`: The model id to evaluate on - `model-id`: The model id to evaluate on
- `utput_dir`: Path to store the evaluate results - `output_dir`: Path to store the evaluate results
``` ```
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \ llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
--model_id <model id to evaluate on> \ --model_id <model id to evaluate on> \
@ -69,7 +69,7 @@ llama-stack-client eval run-benchmark help
to see the description of all the flags that eval run-benchmark has to see the description of all the flags that eval run-benchmark has
In the output log, you can find the file path that has your evaluation results. Open that file and you can see you aggrgate In the output log, you can find the file path that has your evaluation results. Open that file and you can see you aggregate
evaluation results over there. evaluation results over there.

View file

@ -58,9 +58,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \ -v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-nvidia \ llamastack/distribution-nvidia \
@ -74,7 +75,7 @@ docker run \
```bash ```bash
llama stack build --template nvidia --image-type conda llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \ llama stack run ./run.yaml \
--port 5001 \ --port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY --env NVIDIA_API_KEY=$NVIDIA_API_KEY
--env INFERENCE_MODEL=$INFERENCE_MODEL --env INFERENCE_MODEL=$INFERENCE_MODEL
``` ```

View file

@ -28,7 +28,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
### Models ### Models
@ -53,9 +53,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-bedrock \ llamastack/distribution-bedrock \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -20,7 +20,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``) - `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
### Models ### Models
@ -45,9 +45,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \ -v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-cerebras \ llamastack/distribution-cerebras \
@ -61,6 +62,6 @@ docker run \
```bash ```bash
llama stack build --template cerebras --image-type conda llama stack build --template cerebras --image-type conda
llama stack run ./run.yaml \ llama stack run ./run.yaml \
--port 5001 \ --port 8321 \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
``` ```

View file

@ -53,7 +53,7 @@ docker compose down
#### Start Dell-TGI server locally #### Start Dell-TGI server locally
``` ```
docker run -it --shm-size 1g -p 80:80 --gpus 4 \ docker run -it --pull always --shm-size 1g -p 80:80 --gpus 4 \
-e NUM_SHARD=4 -e NUM_SHARD=4
-e MAX_BATCH_PREFILL_TOKENS=32768 \ -e MAX_BATCH_PREFILL_TOKENS=32768 \
-e MAX_INPUT_TOKENS=8000 \ -e MAX_INPUT_TOKENS=8000 \
@ -65,7 +65,7 @@ registry.dell.huggingface.co/enterprise-dell-inference-meta-llama-meta-llama-3.1
#### Start Llama Stack server pointing to TGI server #### Start Llama Stack server pointing to TGI server
``` ```
docker run --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml docker run --pull always --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml
``` ```
Make sure in you `run.yaml` file, you inference provider is pointing to the correct TGI server endpoint. E.g. Make sure in you `run.yaml` file, you inference provider is pointing to the correct TGI server endpoint. E.g.

View file

@ -55,6 +55,7 @@ export CUDA_VISIBLE_DEVICES=0
export LLAMA_STACK_PORT=8321 export LLAMA_STACK_PORT=8321
docker run --rm -it \ docker run --rm -it \
--pull always \
--network host \ --network host \
-v $HOME/.cache/huggingface:/data \ -v $HOME/.cache/huggingface:/data \
-e HF_TOKEN=$HF_TOKEN \ -e HF_TOKEN=$HF_TOKEN \
@ -78,6 +79,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1 export CUDA_VISIBLE_DEVICES=1
docker run --rm -it \ docker run --rm -it \
--pull always \
--network host \ --network host \
-v $HOME/.cache/huggingface:/data \ -v $HOME/.cache/huggingface:/data \
-e HF_TOKEN=$HF_TOKEN \ -e HF_TOKEN=$HF_TOKEN \
@ -120,6 +122,7 @@ This method allows you to get started quickly without having to build the distri
```bash ```bash
docker run -it \ docker run -it \
--pull always \
--network host \ --network host \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \ -v $HOME/.llama:/root/.llama \
@ -147,6 +150,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \ -v $HOME/.llama:/root/.llama \
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \

View file

@ -30,7 +30,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``) - `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``)
### Models ### Models
@ -63,9 +63,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-fireworks \ llamastack/distribution-fireworks \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -30,7 +30,7 @@ The `llamastack/distribution-groq` distribution consists of the following provid
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `GROQ_API_KEY`: Groq API Key (default: ``) - `GROQ_API_KEY`: Groq API Key (default: ``)
### Models ### Models
@ -58,9 +58,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-groq \ llamastack/distribution-groq \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -32,7 +32,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`) - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`) - `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
@ -77,9 +77,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-gpu \ llamastack/distribution-meta-reference-gpu \
@ -92,6 +93,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash ```bash
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-gpu \ llamastack/distribution-meta-reference-gpu \
@ -107,7 +109,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash ```bash
llama stack build --template meta-reference-gpu --image-type conda llama stack build --template meta-reference-gpu --image-type conda
llama stack run distributions/meta-reference-gpu/run.yaml \ llama stack run distributions/meta-reference-gpu/run.yaml \
--port 5001 \ --port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
``` ```
@ -115,7 +117,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash ```bash
llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \ llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \
--port 5001 \ --port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
``` ```

View file

@ -34,7 +34,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`) - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
@ -77,9 +77,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-quantized-gpu \ llamastack/distribution-meta-reference-quantized-gpu \
@ -92,6 +93,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash ```bash
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-quantized-gpu \ llamastack/distribution-meta-reference-quantized-gpu \

View file

@ -15,7 +15,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``) - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
### Models ### Models
@ -39,9 +39,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \ -v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-nvidia \ llamastack/distribution-nvidia \
@ -55,6 +56,6 @@ docker run \
```bash ```bash
llama stack build --template nvidia --image-type conda llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \ llama stack run ./run.yaml \
--port 5001 \ --port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY --env NVIDIA_API_KEY=$NVIDIA_API_KEY
``` ```

View file

@ -32,7 +32,7 @@ You should use this distribution if you have a regular desktop machine without v
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`) - `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
- `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`) - `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
@ -71,9 +71,10 @@ Now you are ready to run Llama Stack with Ollama as the inference provider. You
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-ollama \ llamastack/distribution-ollama \
@ -91,6 +92,7 @@ cd /path/to/llama-stack
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
-v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \ -v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
@ -107,7 +109,7 @@ docker run \
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
```bash ```bash
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
llama stack build --template ollama --image-type conda llama stack build --template ollama --image-type conda
llama stack run ./run.yaml \ llama stack run ./run.yaml \

View file

@ -30,7 +30,7 @@ The `llamastack/distribution-passthrough` distribution consists of the following
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `PASSTHROUGH_API_KEY`: Passthrough API Key (default: ``) - `PASSTHROUGH_API_KEY`: Passthrough API Key (default: ``)
- `PASSTHROUGH_URL`: Passthrough URL (default: ``) - `PASSTHROUGH_URL`: Passthrough URL (default: ``)

View file

@ -31,7 +31,7 @@ You can use this distribution if you have GPUs and want to run an independent vL
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`) - `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`)
- `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`) - `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`)
@ -49,6 +49,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
docker run \ docker run \
--pull always \
--runtime nvidia \ --runtime nvidia \
--gpus $CUDA_VISIBLE_DEVICES \ --gpus $CUDA_VISIBLE_DEVICES \
-v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/huggingface:/root/.cache/huggingface \
@ -61,6 +62,8 @@ docker run \
--port $INFERENCE_PORT --port $INFERENCE_PORT
``` ```
Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like: If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
```bash ```bash
@ -69,6 +72,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1 export CUDA_VISIBLE_DEVICES=1
docker run \ docker run \
--pull always \
--runtime nvidia \ --runtime nvidia \
--gpus $CUDA_VISIBLE_DEVICES \ --gpus $CUDA_VISIBLE_DEVICES \
-v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/huggingface:/root/.cache/huggingface \
@ -92,10 +96,11 @@ This method allows you to get started quickly without having to build the distri
```bash ```bash
export INFERENCE_PORT=8000 export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \ -v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-remote-vllm \ llamastack/distribution-remote-vllm \
@ -117,6 +122,7 @@ cd /path/to/llama-stack
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
-v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \ -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
@ -137,7 +143,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash ```bash
export INFERENCE_PORT=8000 export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
cd distributions/remote-vllm cd distributions/remote-vllm
llama stack build --template remote-vllm --image-type conda llama stack build --template remote-vllm --image-type conda

View file

@ -27,7 +27,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``) - `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``)
### Models ### Models
@ -59,9 +59,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-sambanova \ llamastack/distribution-sambanova \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -33,7 +33,7 @@ You can use this distribution if you have GPUs and want to run an independent TG
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`) - `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`)
- `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`) - `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
@ -50,6 +50,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
docker run --rm -it \ docker run --rm -it \
--pull always \
-v $HOME/.cache/huggingface:/data \ -v $HOME/.cache/huggingface:/data \
-p $INFERENCE_PORT:$INFERENCE_PORT \ -p $INFERENCE_PORT:$INFERENCE_PORT \
--gpus $CUDA_VISIBLE_DEVICES \ --gpus $CUDA_VISIBLE_DEVICES \
@ -70,6 +71,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1 export CUDA_VISIBLE_DEVICES=1
docker run --rm -it \ docker run --rm -it \
--pull always \
-v $HOME/.cache/huggingface:/data \ -v $HOME/.cache/huggingface:/data \
-p $SAFETY_PORT:$SAFETY_PORT \ -p $SAFETY_PORT:$SAFETY_PORT \
--gpus $CUDA_VISIBLE_DEVICES \ --gpus $CUDA_VISIBLE_DEVICES \
@ -90,9 +92,10 @@ Now you are ready to run Llama Stack with TGI as the inference provider. You can
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-tgi \ llamastack/distribution-tgi \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \
@ -109,6 +112,7 @@ cd /path/to/llama-stack
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \

View file

@ -30,7 +30,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `TOGETHER_API_KEY`: Together.AI API Key (default: ``) - `TOGETHER_API_KEY`: Together.AI API Key (default: ``)
### Models ### Models
@ -64,9 +64,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-together \ llamastack/distribution-together \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -54,6 +54,7 @@ mkdir -p ~/.llama
Then you can start the server using the container tool of your choice. For example, if you are running Docker you can use the following command: Then you can start the server using the container tool of your choice. For example, if you are running Docker you can use the following command:
```bash ```bash
docker run -it \ docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-ollama \ llamastack/distribution-ollama \
@ -74,6 +75,7 @@ Docker containers run in their own isolated network namespaces on Linux. To allo
Linux users having issues running the above command should instead try the following: Linux users having issues running the above command should instead try the following:
```bash ```bash
docker run -it \ docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
--network=host \ --network=host \
@ -197,9 +199,7 @@ import os
import uuid import uuid
from termcolor import cprint from termcolor import cprint
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import Document
def create_http_client(): def create_http_client():
@ -225,7 +225,7 @@ client = (
# Documents to be used for RAG # Documents to be used for RAG
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [ documents = [
Document( RAGDocument(
document_id=f"num-{i}", document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain", mime_type="text/plain",
@ -284,7 +284,7 @@ for prompt in user_prompts:
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
session_id=session_id, session_id=session_id,
) )
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
``` ```

View file

@ -118,6 +118,7 @@ Playground can also be started in a docker image:
export LLAMA_STACK_URL=http://localhost:11434 export LLAMA_STACK_URL=http://localhost:11434
docker run \ docker run \
--pull always \
-p 8501:8501 \ -p 8501:8501 \
-e LLAMA_STACK_ENDPOINT=$LLAMA_STACK_URL \ -e LLAMA_STACK_ENDPOINT=$LLAMA_STACK_URL \
quay.io/jland/llama-stack-playground quay.io/jland/llama-stack-playground

View file

@ -48,7 +48,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'" "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
] ]
}, },
@ -369,6 +369,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "7da25939-a2a3-463c-958e-9cdfd710d158",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -386,7 +389,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -43,7 +43,7 @@
"source": [ "source": [
"#### 2. Set Up Local and Cloud Clients\n", "#### 2. Set Up Local and Cloud Clients\n",
"\n", "\n",
"Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:5001`.\n" "Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:8322`.\n"
] ]
}, },
{ {
@ -236,6 +236,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "e11939ac-dfbc-4a1c-83be-e494c7f803b8",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -253,7 +256,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -47,7 +47,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'" "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
] ]
}, },
@ -281,6 +281,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "b1b93b6e-22a2-4c24-8cb0-161fdafff29a",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "base", "display_name": "base",
"language": "python", "language": "python",
@ -298,7 +301,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.12.2" "version": "3.12.2"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -45,7 +45,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"CLOUD_PORT = 5001 # Replace with your cloud distro port\n", "CLOUD_PORT = 8321 # Replace with your cloud distro port\n",
"MODEL_NAME='Llama3.2-11B-Vision-Instruct'" "MODEL_NAME='Llama3.2-11B-Vision-Instruct'"
] ]
}, },
@ -180,6 +180,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "37bbbfda-8e42-446c-89c7-59dd49e2d339",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "base", "display_name": "base",
"language": "python", "language": "python",
@ -197,7 +200,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.12.2" "version": "3.12.2"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -46,7 +46,7 @@
"nest_asyncio.apply()\n", "nest_asyncio.apply()\n",
"\n", "\n",
"HOST = \"localhost\"\n", "HOST = \"localhost\"\n",
"PORT = 5001\n", "PORT = 8321\n",
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n" "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
] ]
}, },
@ -296,7 +296,7 @@
"\n", "\n",
" # Create an agent instance with the client and configuration\n", " # Create an agent instance with the client and configuration\n",
" agent = Agent(\n", " agent = Agent(\n",
" client, \n", " client,\n",
" model=MODEL_NAME,\n", " model=MODEL_NAME,\n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n", " instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n", " sampling_params={\n",
@ -335,6 +335,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "f0abbf6d-ed52-40ad-afb4-f5ec99130249",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -352,7 +355,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -45,7 +45,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'\n", "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'\n",
"MEMORY_BANK_ID=\"tutorial_bank\"" "MEMORY_BANK_ID=\"tutorial_bank\""
] ]
@ -378,6 +378,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "73bc3357-0e5e-42ff-95b1-40b916d24c4f",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -395,7 +398,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 4
} }

View file

@ -49,7 +49,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"SHEILD_NAME=\"meta-llama/Llama-Guard-3-1B\"" "SHEILD_NAME=\"meta-llama/Llama-Guard-3-1B\""
] ]
}, },
@ -112,6 +112,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "9afaddb7-c2fb-4309-8fa0-761697de53f0",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -129,7 +132,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.10" "version": "3.11.10"
} }
}, }
"nbformat": 4,
"nbformat_minor": 4
} }

View file

@ -50,7 +50,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n" "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
] ]
}, },
@ -115,7 +115,7 @@
"async def agent_example():\n", "async def agent_example():\n",
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n", " client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
" agent = Agent(\n", " agent = Agent(\n",
" client, \n", " client,\n",
" model=MODEL_NAME,\n", " model=MODEL_NAME,\n",
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n", " instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
" sampling_params={\n", " sampling_params={\n",
@ -168,6 +168,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "8de24775-c4a0-49c7-904e-608264f69292",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -185,7 +188,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 4
} }

View file

@ -96,7 +96,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
3. **Set the ENV variables by exporting them to the terminal**: 3. **Set the ENV variables by exporting them to the terminal**:
```bash ```bash
export OLLAMA_URL="http://localhost:11434" export OLLAMA_URL="http://localhost:11434"
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B" export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
``` ```
@ -112,7 +112,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
``` ```
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model. Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
The server will start and listen on `http://localhost:5001`. The server will start and listen on `http://localhost:8321`.
--- ---
## Test with `llama-stack-client` CLI ## Test with `llama-stack-client` CLI
@ -120,11 +120,11 @@ After setting up the server, open a new terminal window and configure the llama-
1. Configure the CLI to point to the llama-stack server. 1. Configure the CLI to point to the llama-stack server.
```bash ```bash
llama-stack-client configure --endpoint http://localhost:5001 llama-stack-client configure --endpoint http://localhost:8321
``` ```
**Expected Output:** **Expected Output:**
```bash ```bash
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5001 Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
``` ```
2. Test the CLI by running inference: 2. Test the CLI by running inference:
```bash ```bash
@ -218,7 +218,7 @@ if INFERENCE_MODEL is None:
raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.") raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.")
# Initialize the clien # Initialize the clien
client = LlamaStackClient(base_url="http://localhost:5001") client = LlamaStackClient(base_url="http://localhost:8321")
# Create a chat completion reques # Create a chat completion reques
response = client.inference.chat_completion( response = client.inference.chat_completion(

View file

@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
) )
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -189,13 +188,11 @@ class AgentToolGroupWithArgs(BaseModel):
args: Dict[str, Any] args: Dict[str, Any]
AgentToolGroup = register_schema( AgentToolGroup = Union[
Union[
str, str,
AgentToolGroupWithArgs, AgentToolGroupWithArgs,
], ]
name="AgentTool", register_schema(AgentToolGroup, name="AgentTool")
)
class AgentConfigCommon(BaseModel): class AgentConfigCommon(BaseModel):
@ -312,8 +309,7 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
turn: Turn turn: Turn
AgentTurnResponseEventPayload = register_schema( AgentTurnResponseEventPayload = Annotated[
Annotated[
Union[ Union[
AgentTurnResponseStepStartPayload, AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload, AgentTurnResponseStepProgressPayload,
@ -323,9 +319,8 @@ AgentTurnResponseEventPayload = register_schema(
AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnAwaitingInputPayload,
], ],
Field(discriminator="event_type"), Field(discriminator="event_type"),
], ]
name="AgentTurnResponseEventPayload", register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
)
@json_schema_type @json_schema_type
@ -387,7 +382,6 @@ class AgentStepResponse(BaseModel):
@runtime_checkable @runtime_checkable
@trace_protocol
class Agents(Protocol): class Agents(Protocol):
"""Agents API for creating and interacting with agentic systems. """Agents API for creating and interacting with agentic systems.
@ -399,7 +393,7 @@ class Agents(Protocol):
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
""" """
@webmethod(route="/agents", method="POST") @webmethod(route="/agents", method="POST", descriptive_name="create_agent")
async def create_agent( async def create_agent(
self, self,
agent_config: AgentConfig, agent_config: AgentConfig,
@ -411,7 +405,9 @@ class Agents(Protocol):
""" """
... ...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST") @webmethod(
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
)
async def create_agent_turn( async def create_agent_turn(
self, self,
agent_id: str, agent_id: str,
@ -443,6 +439,7 @@ class Agents(Protocol):
@webmethod( @webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST", method="POST",
descriptive_name="resume_agent_turn",
) )
async def resume_agent_turn( async def resume_agent_turn(
self, self,
@ -505,7 +502,7 @@ class Agents(Protocol):
""" """
... ...
@webmethod(route="/agents/{agent_id}/session", method="POST") @webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
async def create_agent_session( async def create_agent_session(
self, self,
agent_id: str, agent_id: str,

View file

@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
# other modalities can be added here # other modalities can be added here
InterleavedContentItem = register_schema( InterleavedContentItem = Annotated[
Annotated[
Union[ImageContentItem, TextContentItem], Union[ImageContentItem, TextContentItem],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="InterleavedContentItem", register_schema(InterleavedContentItem, name="InterleavedContentItem")
)
# accept a single "str" as a special case since it is common # accept a single "str" as a special case since it is common
InterleavedContent = register_schema( InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
Union[str, InterleavedContentItem, List[InterleavedContentItem]], register_schema(InterleavedContent, name="InterleavedContent")
name="InterleavedContent",
)
@json_schema_type @json_schema_type
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
# streaming completions send a stream of ContentDeltas # streaming completions send a stream of ContentDeltas
ContentDelta = register_schema( ContentDelta = Annotated[
Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta], Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="ContentDelta", register_schema(ContentDelta, name="ContentDelta")
)

View file

@ -10,14 +10,14 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Job(BaseModel):
job_id: str
@json_schema_type
class JobStatus(Enum): class JobStatus(Enum):
completed = "completed" completed = "completed"
in_progress = "in_progress" in_progress = "in_progress"
failed = "failed" failed = "failed"
scheduled = "scheduled" scheduled = "scheduled"
@json_schema_type
class Job(BaseModel):
job_id: str
status: JobStatus

View file

@ -72,8 +72,7 @@ class DialogType(BaseModel):
type: Literal["dialog"] = "dialog" type: Literal["dialog"] = "dialog"
ParamType = register_schema( ParamType = Annotated[
Annotated[
Union[ Union[
StringType, StringType,
NumberType, NumberType,
@ -87,9 +86,8 @@ ParamType = register_schema(
AgentTurnInputType, AgentTurnInputType,
], ],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="ParamType", register_schema(ParamType, name="ParamType")
)
""" """
# TODO: recursive definition of ParamType in these containers # TODO: recursive definition of ParamType in these containers

View file

@ -84,13 +84,11 @@ class RowsDataSource(BaseModel):
rows: List[Dict[str, Any]] rows: List[Dict[str, Any]]
DataSource = register_schema( DataSource = Annotated[
Annotated[
Union[URIDataSource, RowsDataSource], Union[URIDataSource, RowsDataSource],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="DataSource", register_schema(DataSource, name="DataSource")
)
class CommonDatasetFields(BaseModel): class CommonDatasetFields(BaseModel):

View file

@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
@ -43,10 +43,8 @@ class AgentCandidate(BaseModel):
config: AgentConfig config: AgentConfig
EvalCandidate = register_schema( EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")], register_schema(EvalCandidate, name="EvalCandidate")
name="EvalCandidate",
)
@json_schema_type @json_schema_type
@ -117,7 +115,7 @@ class Eval(Protocol):
""" """
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus: async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job. """Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on. :param benchmark_id: The ID of the benchmark to run the evaluation on.

View file

@ -144,8 +144,7 @@ class CompletionMessage(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list) tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
Message = register_schema( Message = Annotated[
Annotated[
Union[ Union[
UserMessage, UserMessage,
SystemMessage, SystemMessage,
@ -153,9 +152,8 @@ Message = register_schema(
CompletionMessage, CompletionMessage,
], ],
Field(discriminator="role"), Field(discriminator="role"),
], ]
name="Message", register_schema(Message, name="Message")
)
@json_schema_type @json_schema_type
@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
bnf: Dict[str, Any] bnf: Dict[str, Any]
ResponseFormat = register_schema( ResponseFormat = Annotated[
Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat], Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="ResponseFormat", register_schema(ResponseFormat, name="ResponseFormat")
)
# This is an internally used class # This is an internally used class

View file

@ -24,17 +24,6 @@ class HealthInfo(BaseModel):
# TODO: add a provider level status # TODO: add a provider level status
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@json_schema_type @json_schema_type
class VersionInfo(BaseModel): class VersionInfo(BaseModel):
version: str version: str
@ -46,9 +35,6 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Inspect(Protocol): class Inspect(Protocol):
@webmethod(route="/inspect/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/inspect/routes", method="GET") @webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ... async def list_routes(self) -> ListRoutesResponse: ...

View file

@ -6,7 +6,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
group_size: int group_size: int
AlgorithmConfig = register_schema( AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")], register_schema(AlgorithmConfig, name="AlgorithmConfig")
name="AlgorithmConfig",
)
@json_schema_type @json_schema_type
@ -184,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`", description="Model descriptor from `llama model list`",
), ),
checkpoint_dir: Optional[str] = None, checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None, algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST") @webmethod(route="/post-training/preference-optimize", method="POST")

View file

@ -36,6 +36,7 @@ class ScoringFnParamsType(Enum):
@json_schema_type @json_schema_type
class AggregationFunctionType(Enum): class AggregationFunctionType(Enum):
average = "average" average = "average"
weighted_average = "weighted_average"
median = "median" median = "median"
categorical_count = "categorical_count" categorical_count = "categorical_count"
accuracy = "accuracy" accuracy = "accuracy"
@ -78,17 +79,15 @@ class BasicScoringFnParams(BaseModel):
) )
ScoringFnParams = register_schema( ScoringFnParams = Annotated[
Annotated[
Union[ Union[
LLMAsJudgeScoringFnParams, LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams, RegexParserScoringFnParams,
BasicScoringFnParams, BasicScoringFnParams,
], ],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="ScoringFnParams", register_schema(ScoringFnParams, name="ScoringFnParams")
)
class CommonScoringFnFields(BaseModel): class CommonScoringFnFields(BaseModel):

View file

@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel):
status: SpanStatus status: SpanStatus
StructuredLogPayload = register_schema( StructuredLogPayload = Annotated[
Annotated[
Union[ Union[
SpanStartPayload, SpanStartPayload,
SpanEndPayload, SpanEndPayload,
], ],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="StructuredLogPayload", register_schema(StructuredLogPayload, name="StructuredLogPayload")
)
@json_schema_type @json_schema_type
@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon):
payload: StructuredLogPayload payload: StructuredLogPayload
Event = register_schema( Event = Annotated[
Annotated[
Union[ Union[
UnstructuredLogEvent, UnstructuredLogEvent,
MetricEvent, MetricEvent,
StructuredLogEvent, StructuredLogEvent,
], ],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="Event", register_schema(Event, name="Event")
)
@json_schema_type @json_schema_type

View file

@ -58,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
template: str template: str
RAGQueryGeneratorConfig = register_schema( RAGQueryGeneratorConfig = Annotated[
Annotated[
Union[ Union[
DefaultRAGQueryGeneratorConfig, DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig, LLMRAGQueryGeneratorConfig,
], ],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="RAGQueryGeneratorConfig", register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
)
@json_schema_type @json_schema_type

View file

@ -69,7 +69,7 @@ class ToolGroup(Resource):
@json_schema_type @json_schema_type
class ToolInvocationResult(BaseModel): class ToolInvocationResult(BaseModel):
content: InterleavedContent content: Optional[InterleavedContent] = None
error_message: Optional[str] = None error_message: Optional[str] = None
error_code: Optional[int] = None error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None
@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class ToolRuntime(Protocol): class ToolRuntime(Protocol):
tool_store: ToolStore tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET") @webmethod(route="/tool-runtime/list-tools", method="GET")

View file

@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class VectorIO(Protocol): class VectorIO(Protocol):
vector_db_store: VectorDBStore vector_db_store: VectorDBStore | None = None
# this will just block now until chunks are inserted, but it should # this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion # probably return a Job instance which can be polled for completion

View file

@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
d = json.load(f) d = json.load(f)
manifest = Manifest(**d) manifest = Manifest(**d)
if datetime.now(timezone.utc) > manifest.expires_on: if datetime.now(timezone.utc) > manifest.expires_on.astimezone(timezone.utc):
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}") raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console() console = Console()

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(
obj_identifier: str,
obj_attributes: Optional[AccessAttributes],
user_attributes: Optional[Dict[str, Any]] = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
3. For each attribute category in the resource's access_attributes:
a. If the user lacks that category, access is DENIED
b. If the user has the category but none of the required values, access is DENIED
c. If the user has at least one matching value in each required category, access is GRANTED
Example:
# Resource requires:
access_attributes = AccessAttributes(
roles=["admin", "data-scientist"],
teams=["ml-team"]
)
# User has:
user_attributes = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "infra-team"],
"projects": ["llama-3"]
}
# Result: Access GRANTED
# - User has the "data-scientist" role (matches one of the required roles)
# - AND user is part of the "ml-team" (matches the required team)
# - The extra "projects" attribute is ignored
Args:
obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not obj_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not dict_attribs:
return True
# Check each attribute category (requires ALL categories to match)
# TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj_identifier}")
return True

View file

@ -90,6 +90,7 @@ RUN apt-get update && apt-get install -y \
procps psmisc lsof \ procps psmisc lsof \
traceroute \ traceroute \
bubblewrap \ bubblewrap \
gcc \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
ENV UV_SYSTEM_PYTHON=1 ENV UV_SYSTEM_PYTHON=1
@ -235,7 +236,7 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture # Detect platform architecture
ARCH=$(uname -m) ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then if [ -n "$BUILD_PLATFORM" ]; then
CLI_ARGS+=("--platform $BUILD_PLATFORM") CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
CLI_ARGS+=("--platform" "linux/arm64") CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then elif [ "$ARCH" = "x86_64" ]; then

View file

@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
@ -31,6 +32,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]] RoutingKey = Union[str, List[str]]
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: Optional[List[str]] = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: Optional[List[str]] = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: Optional[List[str]] = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class ResourceWithACL(Resource):
"""Extension of Resource that adds attribute-based access control capabilities.
This class adds an optional access_attributes field that allows fine-grained control
over which users can access each resource. When attributes are defined, a user must have
matching attributes to access the resource.
Attribute Matching Algorithm:
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
4. Within each category, ANY value match is sufficient (OR relationship within a category)
Examples:
# Resource visible to everyone (no access control)
model = Model(identifier="llama-2", ...)
# Resource visible only to admins
model = Model(
identifier="gpt-4",
access_attributes=AccessAttributes(roles=["admin"])
)
# Resource visible to data scientists on the ML team
model = Model(
identifier="private-model",
access_attributes=AccessAttributes(
roles=["data-scientist", "researcher"],
teams=["ml-team"]
)
)
# ^ User must have at least one of the roles AND be on the ml-team
# Resource visible to users with specific project access
vector_db = VectorDB(
identifier="customer-embeddings",
access_attributes=AccessAttributes(
projects=["customer-insights"],
namespaces=["confidential"]
)
)
# ^ User must have access to the customer-insights project AND have confidential namespace
"""
access_attributes: Optional[AccessAttributes] = None
# Use the extended Resource for all routable objects
class ModelWithACL(Model, ResourceWithACL):
pass
class ShieldWithACL(Shield, ResourceWithACL):
pass
class VectorDBWithACL(VectorDB, ResourceWithACL):
pass
class DatasetWithACL(Dataset, ResourceWithACL):
pass
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
pass
class BenchmarkWithACL(Benchmark, ResourceWithACL):
pass
class ToolWithACL(Tool, ResourceWithACL):
pass
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
pass
RoutableObject = Union[ RoutableObject = Union[
Model, Model,
Shield, Shield,
@ -45,14 +155,14 @@ RoutableObject = Union[
RoutableObjectWithProvider = Annotated[ RoutableObjectWithProvider = Annotated[
Union[ Union[
Model, ModelWithACL,
Shield, ShieldWithACL,
VectorDB, VectorDBWithACL,
Dataset, DatasetWithACL,
ScoringFn, ScoringFnWithACL,
Benchmark, BenchmarkWithACL,
Tool, ToolWithACL,
ToolGroup, ToolGroupWithACL,
], ],
Field(discriminator="type"), Field(discriminator="type"),
] ]

View file

@ -11,9 +11,7 @@ from pydantic import BaseModel
from llama_stack.apis.inspect import ( from llama_stack.apis.inspect import (
HealthInfo, HealthInfo,
Inspect, Inspect,
ListProvidersResponse,
ListRoutesResponse, ListRoutesResponse,
ProviderInfo,
RouteInfo, RouteInfo,
VersionInfo, VersionInfo,
) )
@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
ret = []
for api, providers in run_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for p in providers
]
)
return ListProvidersResponse(data=ret)
async def list_routes(self) -> ListRoutesResponse: async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config run_config = self.config.run_config

View file

@ -9,7 +9,6 @@ import inspect
import json import json
import logging import logging
import os import os
import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
get_stack_run_config_from_template, get_stack_run_config_from_template,
@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump()) safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2)) console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints() self.endpoint_impls = initialize_endpoint_impls(self.impls)
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in self.impls:
continue
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
self.endpoint_impls = endpoint_impls
return True return True
async def request( async def request(
@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return response return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
Returns:
A tuple of (endpoint_function, path_params)
Raises:
ValueError: If no matching endpoint is found
"""
impls = self.endpoint_impls.get(method)
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, func in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params
raise ValueError(f"No endpoint found for {path}")
async def _call_non_streaming( async def _call_non_streaming(
self, self,
*, *,
@ -326,10 +278,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
matched_func, path_params = self._find_matching_endpoint(options.method, path) matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
await start_trace(options.url, {"__location__": "library_client"}) await start_trace(route, {"__location__": "library_client"})
try: try:
result = await matched_func(**body) result = await matched_func(**body)
finally: finally:
@ -371,13 +323,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func, path_params = self._find_matching_endpoint(options.method, path) func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
async def gen(): async def gen():
await start_trace(options.url, {"__location__": "library_client"}) await start_trace(route, {"__location__": "library_client"})
try: try:
async for chunk in await func(**body): async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk)) data = json.dumps(convert_pydantic_to_json_value(chunk))
@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body: if not body:
return {} return {}
func, _ = self._find_matching_endpoint(method, path) func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
sig = inspect.signature(func) sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature # Strip NOT_GIVENs to use the defaults in signature

View file

@ -7,21 +7,26 @@
import contextvars import contextvars
import json import json
import logging import logging
from typing import Any, ContextManager, Dict, Optional from typing import Any, ContextManager, Dict, List, Optional
from .utils.dynamic import instantiate_class_type from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Context variable for request provider data # Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager): class RequestProviderDataContext(ContextManager):
"""Context manager for request provider data""" """Context manager for request provider data"""
def __init__(self, provider_data: Optional[Dict[str, Any]] = None): def __init__(
self.provider_data = provider_data self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
):
self.provider_data = provider_data or {}
if auth_attributes:
self.provider_data["__auth_attributes"] = auth_attributes
self.token = None self.token = None
def __enter__(self): def __enter__(self):
@ -80,7 +85,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
return None return None
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager: def request_provider_data_context(
"""Context manager that sets request provider data from headers for the duration of the context""" headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
) -> ContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers) provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data) return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__auth_attributes")

View file

@ -14,13 +14,7 @@ from llama_stack.apis.common.content_types import (
) )
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import ( from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
BenchmarkConfig,
Eval,
EvaluateResponse,
Job,
JobStatus,
)
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
@ -623,7 +617,7 @@ class EvalRouter(Eval):
self, self,
benchmark_id: str, benchmark_id: str,
job_id: str, job_id: str,
) -> Optional[JobStatus]: ) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

View file

@ -41,11 +41,22 @@ from llama_stack.apis.tools import (
ToolHost, ToolHost,
) )
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
AccessAttributes,
BenchmarkWithACL,
DatasetWithACL,
ModelWithACL,
RoutableObject, RoutableObject,
RoutableObjectWithProvider, RoutableObjectWithProvider,
RoutedProtocol, RoutedProtocol,
ScoringFnWithACL,
ShieldWithACL,
ToolGroupWithACL,
ToolWithACL,
VectorDBWithACL,
) )
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.datatypes import Api, RoutingTable
@ -186,6 +197,11 @@ class CommonRoutingTableImpl(RoutingTable):
if not obj: if not obj:
return None return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
@ -202,6 +218,13 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id] p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p) registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object # TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value: if obj.type == ResourceType.model.value:
@ -214,7 +237,17 @@ class CommonRoutingTableImpl(RoutingTable):
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]: async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all() objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type] filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs
class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@ -251,7 +284,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_type = ModelType.llm model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding: if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata") raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = Model( model = ModelWithACL(
identifier=model_id, identifier=model_id,
provider_resource_id=provider_model_id, provider_resource_id=provider_model_id,
provider_id=provider_id, provider_id=provider_id,
@ -297,7 +330,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
) )
if params is None: if params is None:
params = {} params = {}
shield = Shield( shield = ShieldWithACL(
identifier=shield_id, identifier=shield_id,
provider_resource_id=provider_shield_id, provider_resource_id=provider_shield_id,
provider_id=provider_id, provider_id=provider_id,
@ -351,7 +384,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"embedding_model": embedding_model, "embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"], "embedding_dimension": model.metadata["embedding_dimension"],
} }
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data) vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
await self.register_object(vector_db) await self.register_object(vector_db)
return vector_db return vector_db
@ -405,7 +438,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
if metadata is None: if metadata is None:
metadata = {} metadata = {}
dataset = Dataset( dataset = DatasetWithACL(
identifier=dataset_id, identifier=dataset_id,
provider_resource_id=provider_dataset_id, provider_resource_id=provider_dataset_id,
provider_id=provider_id, provider_id=provider_id,
@ -452,7 +485,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
raise ValueError( raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id." "No provider specified and multiple providers available. Please specify a provider_id."
) )
scoring_fn = ScoringFn( scoring_fn = ScoringFnWithACL(
identifier=scoring_fn_id, identifier=scoring_fn_id,
description=description, description=description,
return_type=return_type, return_type=return_type,
@ -494,7 +527,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
) )
if provider_benchmark_id is None: if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id provider_benchmark_id = benchmark_id
benchmark = Benchmark( benchmark = BenchmarkWithACL(
identifier=benchmark_id, identifier=benchmark_id,
dataset_id=dataset_id, dataset_id=dataset_id,
scoring_functions=scoring_functions, scoring_functions=scoring_functions,
@ -537,7 +570,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for tool_def in tool_defs: for tool_def in tool_defs:
tools.append( tools.append(
Tool( ToolWithACL(
identifier=tool_def.name, identifier=tool_def.name,
toolgroup_id=toolgroup_id, toolgroup_id=toolgroup_id,
description=tool_def.description or "", description=tool_def.description or "",
@ -562,7 +595,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self.register_object(tool) await self.register_object(tool)
await self.dist_registry.register( await self.dist_registry.register(
ToolGroup( ToolGroupWithACL(
identifier=toolgroup_id, identifier=toolgroup_id,
provider_id=provider_id, provider_id=provider_id,
provider_resource_id=toolgroup_id, provider_resource_id=toolgroup_id,
@ -575,7 +608,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_group = await self.get_tool_group(toolgroup_id) tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None: if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found") raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id).data tools = (await self.list_tools(toolgroup_id)).data
for tool in tools: for tool in tools:
await self.unregister_object(tool) await self.unregister_object(tool)
await self.unregister_object(tool_group) await self.unregister_object(tool_group)

View file

@ -5,16 +5,118 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Dict, List, Optional
from urllib.parse import parse_qs from urllib.parse import parse_qs
import httpx import httpx
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") logger = get_logger(name=__name__, category="auth")
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
message: Optional[str] = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthenticationMiddleware: class AuthenticationMiddleware:
"""Middleware that authenticates requests using an external auth endpoint.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Sends it to the configured auth endpoint along with request details
3. Validates the response and extracts user attributes
4. Makes these attributes available to the route handlers for access control
Authentication Request Format:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
"request": {
"path": "/models/list",
"headers": {
"content-type": "application/json",
"user-agent": "..."
// All headers except Authorization
},
"params": {
"limit": ["100"],
"offset": ["0"]
// Query parameters as key -> list of values
}
}
}
```
Expected Auth Endpoint Response Format:
```json
{
"access_attributes": { // Structured attribute format
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
"namespaces": ["research"]
},
"message": "Optional message about auth result"
}
```
Attribute-Based Access Control:
The attributes returned by the auth endpoint are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth endpoint doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_endpoint): def __init__(self, app, auth_endpoint):
self.app = app self.app = app
self.auth_endpoint = auth_endpoint self.auth_endpoint = auth_endpoint
@ -32,25 +134,57 @@ class AuthenticationMiddleware:
path = scope.get("path", "") path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()} request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode() query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string) params = parse_qs(query_string)
auth_data = { # Build the auth request model
"api_key": api_key, auth_request = AuthRequest(
"request": { api_key=api_key,
"path": path, request=AuthRequestContext(
"headers": request_headers, path=path,
"params": params, headers=request_headers,
}, params=params,
} ),
)
# Validate with authentication endpoint # Validate with authentication endpoint
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(self.auth_endpoint, json=auth_data) response = await client.post(
self.auth_endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200: if response.status_code != 200:
logger.warning(f"Authentication failed: {response.status_code}") logger.warning(f"Authentication failed: {response.status_code}")
return await self._send_auth_error(send, "Authentication failed") return await self._send_auth_error(send, "Authentication failed")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [api_key],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
except Exception:
logger.exception("Error parsing authentication response")
return await self._send_auth_error(send, "Invalid authentication response format")
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except Exception: except Exception:
logger.exception("Error during authentication") logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error") return await self._send_auth_error(send, "Authentication service error")

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import inspect import inspect
import re
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel from pydantic import BaseModel
@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
route: str route: str
method: str method: str
name: str name: str
descriptive_name: str | None = None
def toolgroup_protocol_map(): def toolgroup_protocol_map():
@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
method = "delete" method = "delete"
else: else:
method = "post" method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name)) endpoints.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
)
apis[api] = endpoints apis[api] = endpoints
return apis return apis
def initialize_endpoint_impls(impls):
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in impls:
continue
for endpoint in api_endpoints:
impl = impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
func,
endpoint.descriptive_name or endpoint.route,
)
return endpoint_impls
def find_matching_endpoint(method, path, endpoint_impls):
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
Raises:
ValueError: If no matching endpoint is found
"""
impls = endpoint_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, descriptive_name
raise ValueError(f"No endpoint found for {path}")

View file

@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
redact_sensitive_fields, redact_sensitive_fields,
@ -179,8 +183,11 @@ async def sse_generator(event_gen):
def create_dynamic_typed_route(func: Any, method: str, route: str): def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
# Use context manager for request provider data # Get auth attributes from the request scope
with request_provider_data_context(request.headers): user_attributes = request.scope.get("user_attributes", {})
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try: try:
@ -219,14 +226,30 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
class TracingMiddleware: class TracingMiddleware:
def __init__(self, app): def __init__(self, app, impls):
self.app = app self.app = app
self.impls = impls
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
path = scope.get("path", "") if scope.get("type") == "lifespan":
await start_trace(path, {"__location__": "server"})
try:
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
path = scope.get("path", "")
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally: finally:
await end_trace() await end_trace()
@ -348,7 +371,6 @@ def main():
logger.info(yaml.dump(safe_config, indent=2)) logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware) app.add_middleware(ClientVersionMiddleware)
@ -366,7 +388,7 @@ def main():
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
else: else:
setup_logger(TelemetryAdapter(TelemetryConfig())) setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_endpoints()
@ -412,6 +434,7 @@ def main():
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
import uvicorn import uvicorn

View file

@ -12,9 +12,12 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
class DistributionRegistry(Protocol): class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ... async def get_all(self) -> List[RoutableObjectWithProvider]: ...
@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects.""" """Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = [] all_objects = []
for value in values: for value in values:
try:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj) all_objects.append(obj)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
continue
return all_objects return all_objects
@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str: if not json_str:
return None return None
try:
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
return None
async def update(self, obj: RoutableObjectWithProvider) -> None: async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set( await self.kvstore.set(

View file

@ -5,9 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.shared.document import Document
from llama_stack.distribution.ui.modules.api import llama_stack_api from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file from llama_stack.distribution.ui.modules.utils import data_url_from_file
@ -35,7 +33,7 @@ def rag_chat_page():
) )
if st.button("Create Vector Database"): if st.button("Create Vector Database"):
documents = [ documents = [
Document( RAGDocument(
document_id=uploaded_file.name, document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file), content=data_url_from_file(uploaded_file),
) )
@ -167,7 +165,7 @@ def rag_chat_page():
message_placeholder = st.empty() message_placeholder = st.empty()
full_response = "" full_response = ""
retrieval_response = "" retrieval_response = ""
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
if log.role == "tool_execution": if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip() retrieval_response += log.content.replace("====", "").strip()

View file

@ -186,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
top_k: int = Field(..., ge=1) top_k: int = Field(..., ge=1)
SamplingStrategy = register_schema( SamplingStrategy = Annotated[
Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="SamplingStrategy", register_schema(SamplingStrategy, name="SamplingStrategy")
)
@json_schema_type @json_schema_type

View file

@ -244,6 +244,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
You SHOULD NOT include any other text in the response. You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke. Here is a list of functions in JSON format that you can invoke.

View file

@ -15,8 +15,11 @@ import json
import re import re
from typing import Optional, Tuple from typing import Optional, Tuple
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)' BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})") CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@ -92,7 +95,15 @@ def parse_python_list_for_function_calls(input_string):
# Extract keyword arguments # Extract keyword arguments
for keyword in node.keywords: for keyword in node.keywords:
try:
function_args[keyword.arg] = ast.literal_eval(keyword.value) function_args[keyword.arg] = ast.literal_eval(keyword.value)
except ValueError as e:
logger.error(
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
)
raise ValueError(
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
) from e
result.append((function_name, function_args)) result.append((function_name, function_args))

View file

@ -6,14 +6,12 @@
import copy import copy
import json import json
import os
import re import re
import secrets import secrets
import string import string
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
import httpx import httpx
@ -60,7 +58,6 @@ from llama_stack.apis.inference import (
) )
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
RAGDocument,
ToolGroups, ToolGroups,
ToolInvocationResult, ToolInvocationResult,
ToolRuntime, ToolRuntime,
@ -180,23 +177,27 @@ class ChatAgent(ShieldRunnerMixin):
return messages return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups) span = tracing.get_current_span()
async with tracing.span("create_and_execute_turn") as span: if span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4()) turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id) span.set_attribute("turn_id", turn_id)
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id): async for chunk in self._run_turn(request, turn_id):
yield chunk yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
await self._initialize_tools() span = tracing.get_current_span()
async with tracing.span("resume_turn") as span: if span:
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
await self._initialize_tools()
async for chunk in self._run_turn(request): async for chunk in self._run_turn(request):
yield chunk yield chunk
@ -449,8 +450,16 @@ class ChatAgent(ShieldRunnerMixin):
stream: bool = False, stream: bool = False,
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# if document is passed in a turn, we parse the raw text of the document
# and sent it as a user message
if documents: if documents:
await self.handle_documents(session_id, documents, input_messages) contexts = []
for document in documents:
raw_document_text = await get_raw_document_text(document)
contexts.append(raw_document_text)
attached_context = "\n".join(contexts)
input_messages[-1].context = attached_context
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it # if the session has a memory bank id, let the memory tool use it
@ -825,7 +834,10 @@ class ChatAgent(ShieldRunnerMixin):
) )
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {}) tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
tool_name_to_args,
)
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components. """Parse a toolgroup name into its components.
@ -876,144 +888,27 @@ class ChatAgent(ShieldRunnerMixin):
logger.debug(f"tool call {tool_name_str} completed with result: {result}") logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result return result
async def handle_documents(
self,
session_id: str,
documents: List[Document],
input_messages: List[Message],
) -> None:
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
for d in documents:
if isinstance(d.content, URL):
url_items.append(d.content)
elif pattern.match(d.content):
url_items.append(URL(uri=d.content))
else:
content_items.append(d)
# Save the contents to a tempdir and use its path as a URL if code interpreter is present async def load_data_from_url(url: str) -> str:
if code_interpreter_tool: if url.startswith("http"):
for c in content_items:
temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
await attachment_message(self.tempdir, url_items, input_messages[-1])
# Since memory is present, add all the data to the memory bank
await self.add_to_session_vector_db(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
await attachment_message(self.tempdir, url_items, input_messages[-1])
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_vector_db(session_id, documents)
else:
# if no memory or code_interpreter tool is available,
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items] + await load_data_from_urls(url_items)
)
async def _ensure_vector_db(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
if session_info.vector_db_id is None:
vector_db_id = f"vector_db_{session_id}"
# TODO: the semantic for registration is definitely not "creation"
# so we need to fix it if we expect the agent to create a new vector db
# for each session
await self.vector_io_api.register_vector_db(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
)
await self.storage.add_vector_db_to_session(session_id, vector_db_id)
else:
vector_db_id = session_info.vector_db_id
return vector_db_id
async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
vector_db_id = await self._ensure_vector_db(session_id)
documents = [
RAGDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in data
]
await self.tool_runtime_api.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
)
async def load_data_from_urls(urls: List[URL]) -> List[str]:
data = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
with open(filepath, "r") as f:
data.append(f.read())
elif uri.startswith("http"):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(uri) r = await client.get(url)
resp = r.text resp = r.text
data.append(resp) return resp
return data raise ValueError(f"Unexpected URL: {type(url)}")
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None: async def get_raw_document_text(document: Document) -> str:
contents = [] if not document.mime_type.startswith("text/"):
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
for url in urls: if isinstance(document.content, URL):
uri = url.uri return await load_data_from_url(document.content.uri)
if uri.startswith("file://"): elif isinstance(document.content, str):
filepath = uri[len("file://") :] return document.content
elif uri.startswith("http"): elif isinstance(document.content, TextContentItem):
path = urlparse(uri).path return document.content.text
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
logger.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
with open(filepath, "w") as fp:
fp.write(resp)
else: else:
raise ValueError(f"Unsupported URL {url}") raise ValueError(f"Unexpected document content type: {type(document.content)}")
contents.append(
TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
)
)
if isinstance(message.content, list):
message.content.extend(contents)
else:
if isinstance(message.content, str):
message.content = [TextContentItem(text=message.content)] + contents
else:
message.content = [message.content] + contents
def _interpret_content_as_attachment( def _interpret_content_as_attachment(

View file

@ -13,6 +13,9 @@ from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.agents import ToolExecutionStep, Turn from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
# TODO: is this used anywhere? # TODO: is this used anywhere?
vector_db_id: Optional[str] = None vector_db_id: Optional[str] = None
started_at: datetime started_at: datetime
access_attributes: Optional[AccessAttributes] = None
class AgentPersistence: class AgentPersistence:
@ -33,11 +37,18 @@ class AgentPersistence:
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes()
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo( session_info = AgentSessionInfo(
session_id=session_id, session_id=session_id,
session_name=name, session_name=name,
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
) )
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(), value=session_info.model_dump_json(),
@ -51,12 +62,34 @@ class AgentPersistence:
if not value: if not value:
return None return None
return AgentSessionInfo(**json.loads(value)) session_info = AgentSessionInfo(**json.loads(value))
# Check access to session
if not self._check_session_access(session_info):
return None
return session_info
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes"):
return True
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id)
if not session_info:
return None
return session_info
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str): async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
session_info = await self.get_session_info(session_id) session_info = await self.get_session_if_accessible(session_id)
if session_info is None: if session_info is None:
raise ValueError(f"Session {session_id} not found") raise ValueError(f"Session {session_id} not found or access denied")
session_info.vector_db_id = vector_db_id session_info.vector_db_id = vector_db_id
await self.kvstore.set( await self.kvstore.set(
@ -65,12 +98,18 @@ class AgentPersistence:
) )
async def add_turn_to_session(self, session_id: str, turn: Turn): async def add_turn_to_session(self, session_id: str, turn: Turn):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.model_dump_json(), value=turn.model_dump_json(),
) )
async def get_session_turns(self, session_id: str) -> List[Turn]: async def get_session_turns(self, session_id: str) -> List[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range( values = await self.kvstore.range(
start_key=f"session:{self.agent_id}:{session_id}:", start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
@ -87,6 +126,9 @@ class AgentPersistence:
return turns return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]: async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}", key=f"session:{self.agent_id}:{session_id}:{turn_id}",
) )
@ -95,24 +137,36 @@ class AgentPersistence:
return Turn(**json.loads(value)) return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set( await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(), value=step.model_dump_json(),
) )
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
) )
return ToolExecutionStep(**json.loads(value)) if value else None return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int): async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set( await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters), value=str(num_infer_iters),
) )
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]: async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
) )

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List
from tqdm import tqdm from tqdm import tqdm
@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
from llama_stack.providers.utils.common.data_schema_validator import ColumnName from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:" EVAL_TASKS_PREFIX = "benchmarks:"
@ -102,7 +102,7 @@ class MetaReferenceEvalImpl(
# need job scheduler queue (ray/celery) w/ jobs api # need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs)) job_id = str(len(self.jobs))
self.jobs[job_id] = res self.jobs[job_id] = res
return Job(job_id=job_id) return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation( async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@ -216,17 +216,18 @@ class MetaReferenceEvalImpl(
return EvaluateResponse(generations=generations, scores=score_response.results) return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs: if job_id in self.jobs:
return JobStatus.completed return Job(job_id=job_id, status=JobStatus.completed)
return None raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None: async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet") raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(benchmark_id, job_id) job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed: if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}") raise ValueError(f"Job is not completed, Status: {status.value}")

View file

@ -23,7 +23,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import ( from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn, RegexParserMathResponseScoringFn,
) )
@ -36,6 +38,8 @@ FIXED_FNS = [
RegexParserScoringFn, RegexParserScoringFn,
RegexParserMathResponseScoringFn, RegexParserMathResponseScoringFn,
BFCLScoringFn, BFCLScoringFn,
IfEvalScoringFn,
DocVQAScoringFn,
] ]

View file

@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.docvqa import docvqa
CONTRACTIONS = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
"1st": "first",
"2nd": "second",
"3rd": "third",
}
NUMBERS = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
ARTICLES = [
"a",
"an",
"the",
"to",
"in",
"from",
"by",
] # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
PUNCTUATION = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def normalize_answer(s: str) -> str:
# process punctuation
for p in PUNCTUATION:
if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
s = s.replace(p, "")
else:
s = s.replace(p, " ")
s = PERIOD_STRIP.sub("", s, re.UNICODE)
# process digits and articles
temp_text = s.lower().split()
out_text = []
for word in temp_text:
word = NUMBERS.setdefault(word, word)
if word not in ARTICLES:
out_text.append(word)
# standardize contractions
for word_id, word in enumerate(out_text):
if word in CONTRACTIONS:
out_text[word_id] = CONTRACTIONS[word]
return " ".join(out_text)
class DocVQAScoringFn(RegisteredBaseScoringFn):
"""
docvqa basically matches the generated answer against several allowed
choices, but we need to normalize the answer to avoid penalizing
trivial differences
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
docvqa.identifier: docvqa,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "docvqa",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
expected_answers = json.loads(input_row["expected_answer"])
generated_answer = input_row["generated_answer"]
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
return {
"score": score,
}

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
docvqa = ScoringFn(
identifier="basic::docvqa",
description="DocVQA Visual Question & Answer scoring function",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="docvqa",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
ifeval = ScoringFn(
identifier="basic::ifeval",
description="Eval intruction follow capacity by checkping how many instructions can be followed in each example",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="ifeval",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.weighted_average],
),
)

View file

@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.ifeval import (
ifeval,
)
class IfEvalScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn Instruction-Following Eval (IFEval) benchmark
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
ifeval.identifier: ifeval,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
instruction_list = input_row["instruction_id_list"]
generated_answer = input_row["generated_answer"].strip()
is_following_list = []
results = dict(
{k + "_correct": 0.0 for k in INSTRUCTION_LIST},
**{k + "_total": 0.0 for k in INSTRUCTION_LIST},
)
for index, instruction_id in enumerate(instruction_list):
instruction_cls = INSTRUCTION_DICT[instruction_id]
instruction = instruction_cls(instruction_id)
results[instruction_id + "_total"] += 1.0
results[instruction_id.split(":")[0] + "_total"] += 1.0
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
print(clean_input_row)
instruction.build_description(**clean_input_row)
args = instruction.get_instruction_args()
if args and "prompt" in args:
instruction.build_description(prompt=input_row["prompt"])
if generated_answer and instruction.check_following(generated_answer):
is_following_list.append(True)
results[instruction_id + "_correct"] += 1.0
results[instruction_id.split(":")[0] + "_correct"] += 1.0
else:
is_following_list.append(False)
if len(is_following_list) == 0:
return {
"score": 0.0,
"weight": 0.0,
}
return {
"score": float(sum(is_following_list)) / float(len(is_following_list)),
"weight": float(len(is_following_list)),
}

File diff suppressed because it is too large Load diff

View file

@ -6,12 +6,14 @@
from typing import Any, Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import TelemetryConfig, TelemetrySink from .config import TelemetryConfig, TelemetrySink
__all__ = ["TelemetryConfig", "TelemetrySink"] __all__ = ["TelemetryConfig", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]): async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
from .telemetry import TelemetryAdapter from .telemetry import TelemetryAdapter
impl = TelemetryAdapter(config, deps) impl = TelemetryAdapter(config, deps)

View file

@ -13,19 +13,20 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum): class TelemetrySink(str, Enum):
OTEL = "otel" OTEL_TRACE = "otel_trace"
OTEL_METRIC = "otel_metric"
SQLITE = "sqlite" SQLITE = "sqlite"
CONSOLE = "console" CONSOLE = "console"
class TelemetryConfig(BaseModel): class TelemetryConfig(BaseModel):
otel_endpoint: str = Field( otel_trace_endpoint: str = Field(
default="http://localhost:4318/v1/traces", default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL", description="The OpenTelemetry collector endpoint URL for traces",
) )
service_name: str = Field( otel_metric_endpoint: str = Field(
default="llama-stack", default="http://localhost:4318/v1/metrics",
description="The service name to use for telemetry", description="The OpenTelemetry collector endpoint URL for metrics",
) )
sinks: List[TelemetrySink] = Field( sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE], default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
@ -46,7 +47,6 @@ class TelemetryConfig(BaseModel):
@classmethod @classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]: def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return { return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}", "sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}", "sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
} }

View file

@ -101,6 +101,6 @@ class ConsoleSpanProcessor(SpanProcessor):
"""Shutdown the processor.""" """Shutdown the processor."""
pass pass
def force_flush(self, timeout_millis: float = None) -> bool: def force_flush(self, timeout_millis: float | None = None) -> bool:
"""Force flush any pending spans.""" """Force flush any pending spans."""
return True return True

View file

@ -12,6 +12,7 @@ from datetime import datetime, timezone
from opentelemetry.sdk.trace import SpanProcessor from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span from opentelemetry.trace import Span
from opentelemetry.trace.span import format_span_id, format_trace_id
class SQLiteSpanProcessor(SpanProcessor): class SQLiteSpanProcessor(SpanProcessor):
@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor):
conn = self._get_connection() conn = self._get_connection()
cursor = conn.cursor() cursor = conn.cursor()
trace_id = format(span.get_span_context().trace_id, "032x") trace_id = format_trace_id(span.get_span_context().trace_id)
span_id = format(span.get_span_context().span_id, "016x") span_id = format_span_id(span.get_span_context().span_id)
service_name = span.resource.attributes.get("service.name", "unknown") service_name = span.resource.attributes.get("service.name", "unknown")
parent_span_id = None parent_span_id = None
parent_context = span.parent parent_context = span.parent
if parent_context: if parent_context:
parent_span_id = format(parent_context.span_id, "016x") parent_span_id = format_span_id(parent_context.span_id)
# Insert into traces # Insert into traces
cursor.execute( cursor.execute(
@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor):
( (
trace_id, trace_id,
service_name, service_name,
(span_id if not parent_span_id else None), (span_id if span.attributes.get("__root_span__") == "true" else None),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(), datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(), datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
), ),

View file

@ -44,7 +44,7 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace
from .config import TelemetryConfig, TelemetrySink from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = { _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"active_spans": {}, "active_spans": {},
"counters": {}, "counters": {},
"gauges": {}, "gauges": {},
@ -54,30 +54,21 @@ _global_lock = threading.Lock()
_TRACER_PROVIDER = None _TRACER_PROVIDER = None
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer): def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span: with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording() return span.is_recording()
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None: def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
self.config = config self.config = config
self.datasetio_api = deps.get(Api.datasetio) self.datasetio_api = deps.get(Api.datasetio)
self.meter = None self.meter = None
resource = Resource.create( resource = Resource.create(
{ {
ResourceAttributes.SERVICE_NAME: self.config.service_name, # service name is always the same, use zero-width space to avoid clutter
ResourceAttributes.SERVICE_NAME: "",
} }
) )
@ -91,15 +82,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
provider = TracerProvider(resource=resource) provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider) trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider _TRACER_PROVIDER = provider
if TelemetrySink.OTEL in self.config.sinks: if TelemetrySink.OTEL_TRACE in self.config.sinks:
otlp_exporter = OTLPSpanExporter( span_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint, endpoint=self.config.otel_trace_endpoint,
) )
span_processor = BatchSpanProcessor(otlp_exporter) span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor) trace.get_tracer_provider().add_span_processor(span_processor)
if TelemetrySink.OTEL_METRIC in self.config.sinks:
metric_reader = PeriodicExportingMetricReader( metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter( OTLPMetricExporter(
endpoint=self.config.otel_endpoint, endpoint=self.config.otel_metric_endpoint,
) )
) )
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
@ -109,7 +101,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if TelemetrySink.CONSOLE in self.config.sinks: if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
if TelemetrySink.OTEL in self.config.sinks: if TelemetrySink.OTEL_METRIC in self.config.sinks:
self.meter = metrics.get_meter(__name__) self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks: if TelemetrySink.SQLITE in self.config.sinks:
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
@ -135,7 +127,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock: with self._lock:
# Use global storage instead of instance storage # Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id) span_id = event.span_id
span = _GLOBAL_STORAGE["active_spans"].get(span_id) span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span: if span:
@ -146,7 +138,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
"message": event.message, "message": event.message,
"severity": event.severity.value, "severity": event.severity.value,
"__ttl__": ttl_seconds, "__ttl__": ttl_seconds,
**event.attributes, **(event.attributes or {}),
}, },
timestamp=timestamp_ns, timestamp=timestamp_ns,
) )
@ -154,6 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}") print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter: def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["counters"]: if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter( _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name, name=name,
@ -163,6 +156,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["counters"][name] return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["gauges"]: if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge( _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name, name=name,
@ -182,6 +176,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
up_down_counter.add(event.value, attributes=event.attributes) up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter: def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["up_down_counters"]: if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter( _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name, name=name,
@ -192,8 +187,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock: with self._lock:
span_id = string_to_span_id(event.span_id) span_id = int(event.span_id, 16)
trace_id = string_to_trace_id(event.trace_id)
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
if event.attributes is None: if event.attributes is None:
event.attributes = {} event.attributes = {}
@ -204,14 +198,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if span_id in _GLOBAL_STORAGE["active_spans"]: if span_id in _GLOBAL_STORAGE["active_spans"]:
return return
parent_span = None context = None
if event.payload.parent_span_id: if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id) parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.set_span_in_context(parent_span)
context = trace.Context(trace_id=trace_id) else:
if parent_span: context = trace.set_span_in_context(
context = trace.set_span_in_context(parent_span, context) trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(event.trace_id, 16),
span_id=span_id,
is_remote=False,
trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
)
)
)
event.attributes["__root_span__"] = "true"
span = tracer.start_span( span = tracer.start_span(
name=event.payload.name, name=event.payload.name,

View file

@ -69,7 +69,7 @@ def popen_not_allowed(*args, **kwargs):
) )
_subprocess.Popen = popen_not_allowed _subprocess.Popen = popen_not_allowed # type: ignore
import atexit as _atexit import atexit as _atexit
@ -104,7 +104,7 @@ def _open_connections():
return _NETWORK_CONNECTIONS return _NETWORK_CONNECTIONS
_builtins._open_connections = _open_connections _builtins._open_connections = _open_connections # type: ignore
@_atexit.register @_atexit.register

View file

@ -161,9 +161,9 @@ _set_seeds()\
def process_matplotlib_response(response, matplotlib_dump_dir: str): def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_data = response["image_data"] image_data = response["image_data"]
# Convert the base64 string to a bytes object # Convert the base64 string to a bytes object
images = [base64.b64decode(d["image_base64"]) for d in image_data] images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
# Create a list of PIL images from the bytes objects # Create a list of PIL images from the bytes objects
images = [Image.open(BytesIO(img)) for img in images] images = [Image.open(BytesIO(img)) for img in images_raw]
# Create a list of image paths # Create a list of image paths
image_paths = [] image_paths = []
for i, img in enumerate(images): for i, img in enumerate(images):

Some files were not shown because too many files have changed in this diff Show more