Xi Yan 2025-03-23 15:48:14 -07:00
commit a54d757ade
197 changed files with 9392 additions and 3089 deletions

.github/TRIAGERS.md (new file)
View file

@ -0,0 +1,2 @@
# This file documents Triage members in the Llama Stack community
@franciscojavierarceo @leseb

View file

@ -14,6 +14,10 @@ on:
- 'requirements.txt' - 'requirements.txt'
- '.github/workflows/integration-tests.yml' # This workflow - '.github/workflows/integration-tests.yml' # This workflow
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
test-matrix: test-matrix:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -52,6 +56,7 @@ jobs:
# always test against the latest version of the client # always test against the latest version of the client
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
uv pip install -e . uv pip install -e .
llama stack build --template ollama --image-type venv
- name: Wait for Ollama to start - name: Wait for Ollama to start
run: | run: |

View file

@ -5,6 +5,10 @@ on:
push: push:
branches: [main] branches: [main]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
pre-commit: pre-commit:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -18,6 +18,10 @@ on:
- 'llama_stack/distribution/*.sh' - 'llama_stack/distribution/*.sh'
- '.github/workflows/providers-build.yml' - '.github/workflows/providers-build.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
generate-matrix: generate-matrix:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -8,6 +8,10 @@ on:
- reopened - reopened
- synchronize - synchronize
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
permissions: permissions:
contents: read contents: read

View file

@ -15,6 +15,10 @@ on:
- '.github/workflows/unit-tests.yml' # This workflow - '.github/workflows/unit-tests.yml' # This workflow
workflow_dispatch: workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
unit-tests: unit-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -22,6 +22,10 @@ on:
- 'pyproject.toml' - 'pyproject.toml'
- '.github/workflows/update-readthedocs.yml' - '.github/workflows/update-readthedocs.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
update-readthedocs: update-readthedocs:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -89,10 +89,11 @@ repos:
name: API Spec Codegen name: API Spec Codegen
additional_dependencies: additional_dependencies:
- uv==0.6.2 - uv==0.6.2
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null 2>&1' entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
language: python language: python
pass_filenames: false pass_filenames: false
require_serial: true require_serial: true
files: ^llama_stack/apis/|^docs/openapi_generator/
ci: ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

View file

@ -135,9 +135,11 @@ uv sync
## Coding Style ## Coding Style
* Comments should provide meaningful insight into the code. Avoid filler comments that simply describe the next step; they create unnecessary clutter. The same applies to docstrings.
* Prefer comments to clarify surprising behavior and/or relationships between parts of the code rather than explain what the next line of code does.
* When catching exceptions, prefer a specific exception type over a broad catch-all like `Exception` (see the sketch after this list).
* Error messages should be prefixed with "Failed to ..."
* 4 spaces for indentation rather than tabs * 4 spaces for indentation rather than tabs
* 80 character line length
* ...
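
A minimal sketch of the exception and error-message conventions above (the `load_run_config` helper and its path argument are illustrative only, not part of the codebase):

```python
import json


def load_run_config(path: str) -> dict:
    """Load a run configuration from a JSON file."""
    try:
        with open(path) as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        # Specific exception types, and an error message prefixed with "Failed to ..."
        raise ValueError(f"Failed to load run config from {path}: {e}") from e
```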
## Common Tasks ## Common Tasks
@ -166,7 +168,7 @@ If you have made changes to a provider's configuration in any form (introducing
If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme. If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.
```bash ```bash
cd llama-stack/docs cd docs
uv sync --extra docs uv sync --extra docs
# This rebuilds the documentation pages. # This rebuilds the documentation pages.

View file

@ -6,10 +6,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -21,6 +23,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -37,10 +40,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"nltk", "nltk",
"numpy", "numpy",
@ -51,6 +56,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -68,10 +74,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fastapi", "fastapi",
"fire", "fire",
"fireworks-ai", "fireworks-ai",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -83,6 +91,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -102,11 +111,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub", "huggingface_hub",
"langdetect",
"matplotlib", "matplotlib",
"nltk", "nltk",
"numpy", "numpy",
@ -117,6 +128,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -134,10 +146,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fastapi", "fastapi",
"fire", "fire",
"fireworks-ai", "fireworks-ai",
"httpx", "httpx",
"langdetect",
"litellm", "litellm",
"matplotlib", "matplotlib",
"mcp", "mcp",
@ -150,6 +164,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -168,11 +183,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"fireworks-ai", "fireworks-ai",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -184,6 +201,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -200,10 +218,12 @@
"blobfile", "blobfile",
"chardet", "chardet",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"litellm", "litellm",
"matplotlib", "matplotlib",
"nltk", "nltk",
@ -215,6 +235,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -231,11 +252,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub", "huggingface_hub",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -247,6 +270,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -263,11 +287,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub", "huggingface_hub",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -279,6 +305,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -297,11 +324,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fairscale", "fairscale",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"lm-format-enforcer", "lm-format-enforcer",
"matplotlib", "matplotlib",
"mcp", "mcp",
@ -314,6 +343,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -334,12 +364,14 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fairscale", "fairscale",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fbgemm-gpu", "fbgemm-gpu",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"lm-format-enforcer", "lm-format-enforcer",
"matplotlib", "matplotlib",
"mcp", "mcp",
@ -352,6 +384,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -370,10 +403,12 @@
"aiosqlite", "aiosqlite",
"blobfile", "blobfile",
"chardet", "chardet",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"nltk", "nltk",
"numpy", "numpy",
@ -385,6 +420,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -401,10 +437,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -417,6 +455,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -432,9 +471,11 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"litellm", "litellm",
"matplotlib", "matplotlib",
"mcp", "mcp",
@ -447,6 +488,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -464,10 +506,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -479,6 +523,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -496,10 +541,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -512,6 +559,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -559,11 +607,13 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"huggingface_hub", "huggingface_hub",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -575,6 +625,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -592,10 +643,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -607,6 +660,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",
@ -625,10 +679,12 @@
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",
"emoji",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"langdetect",
"matplotlib", "matplotlib",
"mcp", "mcp",
"nltk", "nltk",
@ -640,6 +696,7 @@
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
"pypdf", "pypdf",
"pythainlp",
"redis", "redis",
"requests", "requests",
"scikit-learn", "scikit-learn",

View file

@ -51,14 +51,14 @@ services:
- ~/local/llama-stack/:/app/llama-stack-source - ~/local/llama-stack/:/app/llama-stack-source
- ./run${SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml - ./run${SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
ports: ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}" - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
environment: environment:
- INFERENCE_MODEL=${INFERENCE_MODEL} - INFERENCE_MODEL=${INFERENCE_MODEL}
- SAFETY_MODEL=${SAFETY_MODEL:-} - SAFETY_MODEL=${SAFETY_MODEL:-}
- OLLAMA_URL=http://ollama:11434 - OLLAMA_URL=http://ollama:11434
entrypoint: > entrypoint: >
python -m llama_stack.distribution.server.server /root/my-run.yaml \ python -m llama_stack.distribution.server.server /root/my-run.yaml \
--port ${LLAMA_STACK_PORT:-5001} --port ${LLAMA_STACK_PORT:-8321}
deploy: deploy:
restart_policy: restart_policy:
condition: on-failure condition: on-failure

Binary file not shown.

View file

@ -84,9 +84,9 @@ services:
- SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm} - SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
- SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B} - SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
ports: ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}" - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
# Hack: wait for vLLM server to start before starting docker # Hack: wait for vLLM server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 5001" entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 8321"
deploy: deploy:
restart_policy: restart_policy:
condition: on-failure condition: on-failure

View file

@ -83,7 +83,7 @@ services:
- ~/.llama:/root/.llama - ~/.llama:/root/.llama
- ./run${TGI_SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml - ./run${TGI_SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
ports: ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}" - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
# Hack: wait for TGI server to start before starting docker # Hack: wait for TGI server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml" entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
restart_policy: restart_policy:

View file

@ -2285,7 +2285,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"$ref": "#/components/schemas/ListAgentSessionsResponse" "$ref": "#/components/schemas/Job"
} }
} }
} }
@ -2719,7 +2719,7 @@
} }
}, },
"tags": [ "tags": [
"Inspect" "Providers"
], ],
"description": "", "description": "",
"parameters": [] "parameters": []
@ -4108,6 +4108,11 @@
] ]
}, },
"arguments": { "arguments": {
"oneOf": [
{
"type": "string"
},
{
"type": "object", "type": "object",
"additionalProperties": { "additionalProperties": {
"oneOf": [ "oneOf": [
@ -4173,6 +4178,11 @@
] ]
} }
} }
]
},
"arguments_json": {
"type": "string"
}
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
@ -6182,6 +6192,382 @@
"title": "EmbeddingsResponse", "title": "EmbeddingsResponse",
"description": "Response containing generated embeddings." "description": "Response containing generated embeddings."
}, },
"AgentCandidate": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "agent",
"default": "agent"
},
"config": {
"$ref": "#/components/schemas/AgentConfig",
"description": "The configuration for the agent candidate."
}
},
"additionalProperties": false,
"required": [
"type",
"config"
],
"title": "AgentCandidate",
"description": "An agent candidate for evaluation."
},
"AggregationFunctionType": {
"type": "string",
"enum": [
"average",
"weighted_average",
"median",
"categorical_count",
"accuracy"
],
"title": "AggregationFunctionType"
},
"BasicScoringFnParams": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "basic",
"default": "basic"
},
"aggregation_functions": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AggregationFunctionType"
}
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "BasicScoringFnParams"
},
"BenchmarkConfig": {
"type": "object",
"properties": {
"eval_candidate": {
"$ref": "#/components/schemas/EvalCandidate",
"description": "The candidate to evaluate."
},
"scoring_params": {
"type": "object",
"additionalProperties": {
"$ref": "#/components/schemas/ScoringFnParams"
},
"description": "Map between scoring function id and parameters for each scoring function you want to run"
},
"num_examples": {
"type": "integer",
"description": "(Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated"
}
},
"additionalProperties": false,
"required": [
"eval_candidate",
"scoring_params"
],
"title": "BenchmarkConfig",
"description": "A benchmark configuration for evaluation."
},
"EvalCandidate": {
"oneOf": [
{
"$ref": "#/components/schemas/ModelCandidate"
},
{
"$ref": "#/components/schemas/AgentCandidate"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"model": "#/components/schemas/ModelCandidate",
"agent": "#/components/schemas/AgentCandidate"
}
}
},
"LLMAsJudgeScoringFnParams": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "llm_as_judge",
"default": "llm_as_judge"
},
"judge_model": {
"type": "string"
},
"prompt_template": {
"type": "string"
},
"judge_score_regexes": {
"type": "array",
"items": {
"type": "string"
}
},
"aggregation_functions": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AggregationFunctionType"
}
}
},
"additionalProperties": false,
"required": [
"type",
"judge_model"
],
"title": "LLMAsJudgeScoringFnParams"
},
"ModelCandidate": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "model",
"default": "model"
},
"model": {
"type": "string",
"description": "The model ID to evaluate."
},
"sampling_params": {
"$ref": "#/components/schemas/SamplingParams",
"description": "The sampling parameters for the model."
},
"system_message": {
"$ref": "#/components/schemas/SystemMessage",
"description": "(Optional) The system message providing instructions or context to the model."
}
},
"additionalProperties": false,
"required": [
"type",
"model",
"sampling_params"
],
"title": "ModelCandidate",
"description": "A model candidate for evaluation."
},
"RegexParserScoringFnParams": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "regex_parser",
"default": "regex_parser"
},
"parsing_regexes": {
"type": "array",
"items": {
"type": "string"
}
},
"aggregation_functions": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AggregationFunctionType"
}
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "RegexParserScoringFnParams"
},
"ScoringFnParams": {
"oneOf": [
{
"$ref": "#/components/schemas/LLMAsJudgeScoringFnParams"
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
},
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"llm_as_judge": "#/components/schemas/LLMAsJudgeScoringFnParams",
"regex_parser": "#/components/schemas/RegexParserScoringFnParams",
"basic": "#/components/schemas/BasicScoringFnParams"
}
}
},
"EvaluateRowsRequest": {
"type": "object",
"properties": {
"input_rows": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"description": "The rows to evaluate."
},
"scoring_functions": {
"type": "array",
"items": {
"type": "string"
},
"description": "The scoring functions to use for the evaluation."
},
"benchmark_config": {
"$ref": "#/components/schemas/BenchmarkConfig",
"description": "The configuration for the benchmark."
}
},
"additionalProperties": false,
"required": [
"input_rows",
"scoring_functions",
"benchmark_config"
],
"title": "EvaluateRowsRequest"
},
"EvaluateResponse": {
"type": "object",
"properties": {
"generations": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"description": "The generations from the evaluation."
},
"scores": {
"type": "object",
"additionalProperties": {
"$ref": "#/components/schemas/ScoringResult"
},
"description": "The scores from the evaluation."
}
},
"additionalProperties": false,
"required": [
"generations",
"scores"
],
"title": "EvaluateResponse",
"description": "The response from an evaluation."
},
"ScoringResult": {
"type": "object",
"properties": {
"score_rows": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"description": "The scoring result for each row. Each row is a map of column name to value."
},
"aggregated_results": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Map of metric name to aggregated value"
}
},
"additionalProperties": false,
"required": [
"score_rows",
"aggregated_results"
],
"title": "ScoringResult",
"description": "A scoring result for a single row."
},
"Agent": { "Agent": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -7319,8 +7705,7 @@
"completed", "completed",
"in_progress", "in_progress",
"failed", "failed",
"scheduled", "scheduled"
"cancelled"
], ],
"title": "JobStatus" "title": "JobStatus"
}, },
@ -7698,7 +8083,8 @@
"type": "object", "type": "object",
"properties": { "properties": {
"document_id": { "document_id": {
"type": "string" "type": "string",
"description": "The unique identifier for the document."
}, },
"content": { "content": {
"oneOf": [ "oneOf": [
@ -7717,10 +8103,12 @@
{ {
"$ref": "#/components/schemas/URL" "$ref": "#/components/schemas/URL"
} }
] ],
"description": "The content of the document."
}, },
"mime_type": { "mime_type": {
"type": "string" "type": "string",
"description": "The MIME type of the document."
}, },
"metadata": { "metadata": {
"type": "object", "type": "object",
@ -7745,7 +8133,8 @@
"type": "object" "type": "object"
} }
] ]
} },
"description": "Additional metadata for the document."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -7754,7 +8143,8 @@
"content", "content",
"metadata" "metadata"
], ],
"title": "RAGDocument" "title": "RAGDocument",
"description": "A document to be used for document ingestion in the RAG Tool."
}, },
"InsertRequest": { "InsertRequest": {
"type": "object", "type": "object",
@ -7964,9 +8354,6 @@
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [
"content"
],
"title": "ToolInvocationResult" "title": "ToolInvocationResult"
}, },
"IterrowsResponse": { "IterrowsResponse": {
@ -8013,6 +8400,30 @@
"title": "IterrowsResponse", "title": "IterrowsResponse",
"description": "A paginated list of rows from a dataset." "description": "A paginated list of rows from a dataset."
}, },
"Job": {
"type": "object",
"properties": {
"job_id": {
"type": "string"
},
"status": {
"type": "string",
"enum": [
"completed",
"in_progress",
"failed",
"scheduled"
],
"title": "JobStatus"
}
},
"additionalProperties": false,
"required": [
"job_id",
"status"
],
"title": "Job"
},
"ListAgentSessionsResponse": { "ListAgentSessionsResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9596,21 +10007,16 @@
"RunRequest": { "RunRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
"task": { "benchmark_config": {
"$ref": "#/components/schemas/EvaluationTask", "$ref": "#/components/schemas/BenchmarkConfig",
"description": "The task to evaluate. To specify a task, one of the following must be provided: - `benchmark_id`: Run evaluation task against a benchmark_id - `dataset_id` and `grader_ids`: Run evaluation task against a dataset_id and a list of grader_ids - `data_source` and `grader_ids`: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids" "description": "The configuration for the benchmark."
},
"candidate": {
"$ref": "#/components/schemas/EvaluationCandidate",
"description": "The candidate to evaluate."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"task", "benchmark_config"
"candidate"
], ],
"title": "RunRequest" "title": "RunEvalRequest"
}, },
"RunShieldRequest": { "RunShieldRequest": {
"type": "object", "type": "object",
@ -9717,6 +10123,145 @@
], ],
"title": "SaveSpansToDatasetRequest" "title": "SaveSpansToDatasetRequest"
}, },
"ScoreRequest": {
"type": "object",
"properties": {
"input_rows": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"description": "The rows to score."
},
"scoring_functions": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"$ref": "#/components/schemas/ScoringFnParams"
},
{
"type": "null"
}
]
},
"description": "The scoring functions to use for the scoring."
}
},
"additionalProperties": false,
"required": [
"input_rows",
"scoring_functions"
],
"title": "ScoreRequest"
},
"ScoreResponse": {
"type": "object",
"properties": {
"results": {
"type": "object",
"additionalProperties": {
"$ref": "#/components/schemas/ScoringResult"
},
"description": "A map of scoring function name to ScoringResult."
}
},
"additionalProperties": false,
"required": [
"results"
],
"title": "ScoreResponse",
"description": "The response from scoring."
},
"ScoreBatchRequest": {
"type": "object",
"properties": {
"dataset_id": {
"type": "string"
},
"scoring_functions": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"$ref": "#/components/schemas/ScoringFnParams"
},
{
"type": "null"
}
]
}
},
"save_results_dataset": {
"type": "boolean"
}
},
"additionalProperties": false,
"required": [
"dataset_id",
"scoring_functions",
"save_results_dataset"
],
"title": "ScoreBatchRequest"
},
"ScoreBatchResponse": {
"type": "object",
"properties": {
"dataset_id": {
"type": "string"
},
"results": {
"type": "object",
"additionalProperties": {
"$ref": "#/components/schemas/ScoringResult"
}
}
},
"additionalProperties": false,
"required": [
"results"
],
"title": "ScoreBatchResponse"
},
"AlgorithmConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"LoRA": "#/components/schemas/LoraFinetuningConfig",
"QAT": "#/components/schemas/QATFinetuningConfig"
}
}
},
"LoraFinetuningConfig": { "LoraFinetuningConfig": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9852,14 +10397,7 @@
"type": "string" "type": "string"
}, },
"algorithm_config": { "algorithm_config": {
"oneOf": [ "$ref": "#/components/schemas/AlgorithmConfig"
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
]
} }
}, },
"additionalProperties": false, "additionalProperties": false,

View file

@ -1562,6 +1562,109 @@ paths:
required: false required: false
schema: schema:
type: integer type: integer
/v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}:
get:
responses:
'200':
description: The status of the evaluation job.
content:
application/json:
schema:
$ref: '#/components/schemas/Job'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Eval
description: Get the status of a job.
parameters:
- name: benchmark_id
in: path
description: >-
The ID of the benchmark to run the evaluation on.
required: true
schema:
type: string
- name: job_id
in: path
description: The ID of the job to get the status of.
required: true
schema:
type: string
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Eval
description: Cancel a job.
parameters:
- name: benchmark_id
in: path
description: >-
The ID of the benchmark to run the evaluation on.
required: true
schema:
type: string
- name: job_id
in: path
description: The ID of the job to cancel.
required: true
schema:
type: string
/v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result:
get:
responses:
'200':
description: The result of the job.
content:
application/json:
schema:
$ref: '#/components/schemas/EvaluateResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Eval
description: Get the result of a job.
parameters:
- name: benchmark_id
in: path
description: >-
The ID of the benchmark to run the evaluation on.
required: true
schema:
type: string
- name: job_id
in: path
description: The ID of the job to get the result of.
required: true
schema:
type: string
/v1/agents/{agent_id}/sessions: /v1/agents/{agent_id}/sessions:
get: get:
responses: responses:
@ -1820,7 +1923,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Models - Providers
description: '' description: ''
parameters: [] parameters: []
post: post:
@ -2841,7 +2944,9 @@ components:
title: BuiltinTool title: BuiltinTool
- type: string - type: string
arguments: arguments:
type: object oneOf:
- type: string
- type: object
additionalProperties: additionalProperties:
oneOf: oneOf:
- type: string - type: string
@ -2865,6 +2970,8 @@ components:
- type: number - type: number
- type: boolean - type: boolean
- type: 'null' - type: 'null'
arguments_json:
type: string
additionalProperties: false additionalProperties: false
required: required:
- call_id - call_id
@ -4341,6 +4448,252 @@ components:
title: EmbeddingsResponse title: EmbeddingsResponse
description: >- description: >-
Response containing generated embeddings. Response containing generated embeddings.
AgentCandidate:
type: object
properties:
type:
type: string
const: agent
default: agent
config:
$ref: '#/components/schemas/AgentConfig'
description: >-
The configuration for the agent candidate.
additionalProperties: false
required:
- type
- config
title: AgentCandidate
description: An agent candidate for evaluation.
AggregationFunctionType:
type: string
enum:
- average
- weighted_average
- median
- categorical_count
- accuracy
title: AggregationFunctionType
BasicScoringFnParams:
type: object
properties:
type:
type: string
const: basic
default: basic
aggregation_functions:
type: array
items:
$ref: '#/components/schemas/AggregationFunctionType'
additionalProperties: false
required:
- type
title: BasicScoringFnParams
BenchmarkConfig:
type: object
properties:
eval_candidate:
$ref: '#/components/schemas/EvalCandidate'
description: The candidate to evaluate.
scoring_params:
type: object
additionalProperties:
$ref: '#/components/schemas/ScoringFnParams'
description: >-
Map between scoring function id and parameters for each scoring function
you want to run
num_examples:
type: integer
description: >-
(Optional) The number of examples to evaluate. If not provided, all examples
in the dataset will be evaluated
additionalProperties: false
required:
- eval_candidate
- scoring_params
title: BenchmarkConfig
description: >-
A benchmark configuration for evaluation.
EvalCandidate:
oneOf:
- $ref: '#/components/schemas/ModelCandidate'
- $ref: '#/components/schemas/AgentCandidate'
discriminator:
propertyName: type
mapping:
model: '#/components/schemas/ModelCandidate'
agent: '#/components/schemas/AgentCandidate'
LLMAsJudgeScoringFnParams:
type: object
properties:
type:
type: string
const: llm_as_judge
default: llm_as_judge
judge_model:
type: string
prompt_template:
type: string
judge_score_regexes:
type: array
items:
type: string
aggregation_functions:
type: array
items:
$ref: '#/components/schemas/AggregationFunctionType'
additionalProperties: false
required:
- type
- judge_model
title: LLMAsJudgeScoringFnParams
ModelCandidate:
type: object
properties:
type:
type: string
const: model
default: model
model:
type: string
description: The model ID to evaluate.
sampling_params:
$ref: '#/components/schemas/SamplingParams'
description: The sampling parameters for the model.
system_message:
$ref: '#/components/schemas/SystemMessage'
description: >-
(Optional) The system message providing instructions or context to the
model.
additionalProperties: false
required:
- type
- model
- sampling_params
title: ModelCandidate
description: A model candidate for evaluation.
RegexParserScoringFnParams:
type: object
properties:
type:
type: string
const: regex_parser
default: regex_parser
parsing_regexes:
type: array
items:
type: string
aggregation_functions:
type: array
items:
$ref: '#/components/schemas/AggregationFunctionType'
additionalProperties: false
required:
- type
title: RegexParserScoringFnParams
ScoringFnParams:
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams'
discriminator:
propertyName: type
mapping:
llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams'
regex_parser: '#/components/schemas/RegexParserScoringFnParams'
basic: '#/components/schemas/BasicScoringFnParams'
EvaluateRowsRequest:
type: object
properties:
input_rows:
type: array
items:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The rows to evaluate.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the evaluation.
benchmark_config:
$ref: '#/components/schemas/BenchmarkConfig'
description: The configuration for the benchmark.
additionalProperties: false
required:
- input_rows
- scoring_functions
- benchmark_config
title: EvaluateRowsRequest
EvaluateResponse:
type: object
properties:
generations:
type: array
items:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The generations from the evaluation.
scores:
type: object
additionalProperties:
$ref: '#/components/schemas/ScoringResult'
description: The scores from the evaluation.
additionalProperties: false
required:
- generations
- scores
title: EvaluateResponse
description: The response from an evaluation.
ScoringResult:
type: object
properties:
score_rows:
type: array
items:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The scoring result for each row. Each row is a map of column name to value.
aggregated_results:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Map of metric name to aggregated value
additionalProperties: false
required:
- score_rows
- aggregated_results
title: ScoringResult
description: A scoring result for a single row.
Agent: Agent:
type: object type: object
properties: properties:
@ -5098,7 +5451,6 @@ components:
- in_progress - in_progress
- failed - failed
- scheduled - scheduled
- cancelled
title: JobStatus title: JobStatus
scheduled_at: scheduled_at:
type: string type: string
@ -5373,6 +5725,7 @@ components:
properties: properties:
document_id: document_id:
type: string type: string
description: The unique identifier for the document.
content: content:
oneOf: oneOf:
- type: string - type: string
@ -5381,8 +5734,10 @@ components:
items: items:
$ref: '#/components/schemas/InterleavedContentItem' $ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL' - $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type: mime_type:
type: string type: string
description: The MIME type of the document.
metadata: metadata:
type: object type: object
additionalProperties: additionalProperties:
@ -5393,12 +5748,15 @@ components:
- type: string - type: string
- type: array - type: array
- type: object - type: object
description: Additional metadata for the document.
additionalProperties: false additionalProperties: false
required: required:
- document_id - document_id
- content - content
- metadata - metadata
title: RAGDocument title: RAGDocument
description: >-
A document to be used for document ingestion in the RAG Tool.
InsertRequest: InsertRequest:
type: object type: object
properties: properties:
@ -5516,8 +5874,6 @@ components:
- type: array - type: array
- type: object - type: object
additionalProperties: false additionalProperties: false
required:
- content
title: ToolInvocationResult title: ToolInvocationResult
IterrowsResponse: IterrowsResponse:
type: object type: object
@ -5545,6 +5901,24 @@ components:
- data - data
title: IterrowsResponse title: IterrowsResponse
description: A paginated list of rows from a dataset. description: A paginated list of rows from a dataset.
Job:
type: object
properties:
job_id:
type: string
status:
type: string
enum:
- completed
- in_progress
- failed
- scheduled
title: JobStatus
additionalProperties: false
required:
- job_id
- status
title: Job
ListAgentSessionsResponse: ListAgentSessionsResponse:
type: object type: object
properties: properties:
@ -6610,9 +6984,8 @@ components:
description: The candidate to evaluate. description: The candidate to evaluate.
additionalProperties: false additionalProperties: false
required: required:
- task - benchmark_config
- candidate title: RunEvalRequest
title: RunRequest
RunShieldRequest: RunShieldRequest:
type: object type: object
properties: properties:
@ -6685,6 +7058,90 @@ components:
- attributes_to_save - attributes_to_save
- dataset_id - dataset_id
title: SaveSpansToDatasetRequest title: SaveSpansToDatasetRequest
ScoreRequest:
type: object
properties:
input_rows:
type: array
items:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The rows to score.
scoring_functions:
type: object
additionalProperties:
oneOf:
- $ref: '#/components/schemas/ScoringFnParams'
- type: 'null'
description: >-
The scoring functions to use for the scoring.
additionalProperties: false
required:
- input_rows
- scoring_functions
title: ScoreRequest
ScoreResponse:
type: object
properties:
results:
type: object
additionalProperties:
$ref: '#/components/schemas/ScoringResult'
description: >-
A map of scoring function name to ScoringResult.
additionalProperties: false
required:
- results
title: ScoreResponse
description: The response from scoring.
ScoreBatchRequest:
type: object
properties:
dataset_id:
type: string
scoring_functions:
type: object
additionalProperties:
oneOf:
- $ref: '#/components/schemas/ScoringFnParams'
- type: 'null'
save_results_dataset:
type: boolean
additionalProperties: false
required:
- dataset_id
- scoring_functions
- save_results_dataset
title: ScoreBatchRequest
ScoreBatchResponse:
type: object
properties:
dataset_id:
type: string
results:
type: object
additionalProperties:
$ref: '#/components/schemas/ScoringResult'
additionalProperties: false
required:
- results
title: ScoreBatchResponse
AlgorithmConfig:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
discriminator:
propertyName: type
mapping:
LoRA: '#/components/schemas/LoraFinetuningConfig'
QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig: LoraFinetuningConfig:
type: object type: object
properties: properties:
@ -6768,9 +7225,7 @@ components:
checkpoint_dir: checkpoint_dir:
type: string type: string
algorithm_config: algorithm_config:
oneOf: $ref: '#/components/schemas/AlgorithmConfig'
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
additionalProperties: false additionalProperties: false
required: required:
- job_uuid - job_uuid

View file

@ -4,6 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os
import time
def pytest_collection_modifyitems(items): def pytest_collection_modifyitems(items):
for item in items: for item in items:
item.name = item.name.replace(' ', '_') item.name = item.name.replace(' ', '_')
def pytest_runtest_teardown(item):
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
if interval_seconds:
time.sleep(float(interval_seconds))
def pytest_configure(config):
config.option.tbstyle = "short"
config.option.disable_warnings = True

View file

@ -123,6 +123,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# NBVAL_SKIP\n", "# NBVAL_SKIP\n",
"!pip uninstall pandas numpy -y\n",
"!pip install pandas numpy\n",
"# This will build all the dependencies you will need\n", "# This will build all the dependencies you will need\n",
"!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv" "!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv"
] ]
@ -1203,7 +1205,7 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n", "from llama_stack_client import InferenceEventLogger\n",
"\n", "\n",
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n", "message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
"print(f'User> {message[\"content\"]}', \"green\")\n", "print(f'User> {message[\"content\"]}', \"green\")\n",
@ -1215,7 +1217,7 @@
")\n", ")\n",
"\n", "\n",
"# Print the tokens while they are received\n", "# Print the tokens while they are received\n",
"for log in EventLogger().log(response):\n", "for log in InferenceEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -1632,8 +1634,7 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client import Agent, AgentEventLogger\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from termcolor import cprint\n", "from termcolor import cprint\n",
"\n", "\n",
"agent = Agent(\n", "agent = Agent(\n",
@ -1659,7 +1660,7 @@
" ],\n", " ],\n",
" session_id=session_id,\n", " session_id=session_id,\n",
" )\n", " )\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -1808,14 +1809,12 @@
], ],
"source": [ "source": [
"import uuid\n", "import uuid\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client import Agent, AgentEventLogger, RAGDocument\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from termcolor import cprint\n", "from termcolor import cprint\n",
"from llama_stack_client.types import Document\n",
"\n", "\n",
"urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n", "urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n",
"documents = [\n", "documents = [\n",
" Document(\n", " RAGDocument(\n",
" document_id=f\"num-{i}\",\n", " document_id=f\"num-{i}\",\n",
" content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n", " content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
" mime_type=\"text/plain\",\n", " mime_type=\"text/plain\",\n",
@ -1858,7 +1857,7 @@
" messages=[{\"role\": \"user\", \"content\": prompt}],\n", " messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" session_id=session_id,\n", " session_id=session_id,\n",
" )\n", " )\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()" " log.print()"
] ]
}, },
@ -1969,7 +1968,7 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client.types.agents.turn_create_params import Document\n", "from llama_stack_client import Document\n",
"\n", "\n",
"codex_agent = Agent(\n", "codex_agent = Agent(\n",
" client, \n", " client, \n",
@ -2013,7 +2012,7 @@
" # for chunk in response:\n", " # for chunk in response:\n",
" # print(chunk)\n", " # print(chunk)\n",
"\n", "\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -2891,8 +2890,7 @@
], ],
"source": [ "source": [
"# NBVAL_SKIP\n", "# NBVAL_SKIP\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client import Agent, AgentEventLogger\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from termcolor import cprint\n", "from termcolor import cprint\n",
"\n", "\n",
"agent = Agent(\n", "agent = Agent(\n",
@ -2918,7 +2916,7 @@
" ],\n", " ],\n",
" session_id=session_id,\n", " session_id=session_id,\n",
" )\n", " )\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -2993,8 +2991,7 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client import Agent, AgentEventLogger\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"\n", "\n",
"agent = Agent(\n", "agent = Agent(\n",
" client, \n", " client, \n",
@ -3021,7 +3018,7 @@
" session_id=session_id,\n", " session_id=session_id,\n",
" )\n", " )\n",
"\n", "\n",
" for log in EventLogger().log(response):\n", " for log in AgentEventLogger().log(response):\n",
" log.print()\n" " log.print()\n"
] ]
}, },
@ -4355,7 +4352,7 @@
" session_id=session_id,\n", " session_id=session_id,\n",
")\n", ")\n",
"\n", "\n",
"for log in EventLogger().log(response):\n", "for log in AgentEventLogger().log(response):\n",
" log.print()\n", " log.print()\n",
" " " "
] ]

View file

@ -47,9 +47,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from llama_stack_client import LlamaStackClient\n", "from llama_stack_client import LlamaStackClient, Agent\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from rich.pretty import pprint\n", "from rich.pretty import pprint\n",
"import json\n", "import json\n",
"import uuid\n", "import uuid\n",

View file

@ -34,10 +34,8 @@
} }
], ],
"source": [ "source": [
"from llama_stack_client import LlamaStackClient\n", "from llama_stack_client import LlamaStackClient, Agent\n",
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from rich.pretty import pprint\n", "from rich.pretty import pprint\n",
"import json\n", "import json\n",
"import uuid\n", "import uuid\n",

View file

@ -14,7 +14,7 @@ Agents are configured using the `AgentConfig` class, which includes:
- **Safety Shields**: Guardrails to ensure responsible AI behavior - **Safety Shields**: Guardrails to ensure responsible AI behavior
```python ```python
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent
# Create the agent # Create the agent
@ -44,14 +44,14 @@ Each interaction with an agent is called a "turn" and consists of:
- **Output Message**: The agent's response - **Output Message**: The agent's response
```python ```python
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client import AgentEventLogger
# Create a turn with streaming response # Create a turn with streaming response
turn_response = agent.create_turn( turn_response = agent.create_turn(
session_id=session_id, session_id=session_id,
messages=[{"role": "user", "content": "Tell me about Llama models"}], messages=[{"role": "user", "content": "Tell me about Llama models"}],
) )
for log in EventLogger().log(turn_response): for log in AgentEventLogger().log(turn_response):
log.print() log.print()
``` ```
### Non-Streaming ### Non-Streaming

View file

@ -67,9 +67,7 @@ sequenceDiagram
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution: Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
```python ```python
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from rich.pretty import pprint from rich.pretty import pprint
# Replace host and port # Replace host and port
@ -113,7 +111,7 @@ response = agent.create_turn(
) )
# Monitor each step of execution # Monitor each step of execution
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
# Using non-streaming API, the response contains input, steps, and output. # Using non-streaming API, the response contains input, steps, and output.

View file

@ -23,9 +23,7 @@ In this example, we will show you how to:
##### Building a Search Agent ##### Building a Search Agent
```python ```python
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}") client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
@ -54,7 +52,7 @@ for prompt in user_prompts:
session_id=session_id, session_id=session_id,
) )
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
``` ```

View file

@ -55,11 +55,11 @@ chunks_response = client.vector_io.query(
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces. A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces.
```python ```python
from llama_stack_client.types import Document from llama_stack_client import RAGDocument
urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"] urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"]
documents = [ documents = [
Document( RAGDocument(
document_id=f"num-{i}", document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain", mime_type="text/plain",
@ -86,7 +86,7 @@ results = client.tool_runtime.rag_tool.query(
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example: One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
```python ```python
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent
# Create agent with memory # Create agent with memory
agent = Agent( agent = Agent(
@ -140,9 +140,9 @@ response = agent.create_turn(
You can print the response with below. You can print the response with below.
```python ```python
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client import AgentEventLogger
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
``` ```
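
Tying the ingestion and query steps above together, here is a rough end-to-end sketch. It assumes a Llama Stack server on `localhost:8321` and a vector DB already registered under the id `my_documents` (the registration call depends on your distribution), so treat the names as placeholders:

```python
from llama_stack_client import LlamaStackClient, RAGDocument

client = LlamaStackClient(base_url="http://localhost:8321")

urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"]
documents = [
    RAGDocument(
        document_id=f"num-{i}",
        content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
        mime_type="text/plain",
        metadata={},
    )
    for i, url in enumerate(urls)
]

# Ingest the documents; the RAG tool chunks them automatically.
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id="my_documents",  # assumed to be registered already
    chunk_size_in_tokens=512,
)

# Query the ingested content.
results = client.tool_runtime.rag_tool.query(
    vector_db_ids=["my_documents"],
    content="How do I reduce memory usage during fine-tuning?",
)
print(results)
```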

View file

@ -57,7 +57,7 @@ The `otel` sink works with any service compatible with the OpenTelemetry collect
Start a Jaeger instance with the OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686 using the following command: Start a Jaeger instance with the OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686 using the following command:
```bash ```bash
$ docker run --rm --name jaeger \ $ docker run --pull always --rm --name jaeger \
-p 16686:16686 -p 4318:4318 \ -p 16686:16686 -p 4318:4318 \
jaegertracing/jaeger:2.1.0 jaegertracing/jaeger:2.1.0
``` ```

View file

@ -110,10 +110,18 @@ MCP tools are special tools that can interact with llama stack over model contex
Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers. Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers.
```shell
# start your MCP server
mkdir /tmp/content
touch /tmp/content/foo
touch /tmp/content/bar
npx -y supergateway --port 8000 --stdio 'npx -y @modelcontextprotocol/server-filesystem /tmp/content'
```
Then register the MCP server as a tool group,
```python ```python
# Register MCP tools
client.toolgroups.register( client.toolgroups.register(
toolgroup_id="builtin::filesystem", toolgroup_id="mcp::filesystem",
provider_id="model-context-protocol", provider_id="model-context-protocol",
mcp_endpoint=URL(uri="http://localhost:8000/sse"), mcp_endpoint=URL(uri="http://localhost:8000/sse"),
) )
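
Once registered, the MCP tool group can be handed to an agent like any other tool group. A rough sketch, reusing the `client` from above (the model id is a placeholder; `mcp::filesystem` matches the registration shown here):

```python
from llama_stack_client import Agent, AgentEventLogger

agent = Agent(
    client,
    model="meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
    instructions="You are a helpful assistant with access to a filesystem.",
    tools=["mcp::filesystem"],
)
session_id = agent.create_session("mcp-demo")
response = agent.create_turn(
    messages=[{"role": "user", "content": "List the files you can see."}],
    session_id=session_id,
)
for log in AgentEventLogger().log(response):
    log.print()
```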
@ -181,7 +189,7 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
## Simple Example: Using an Agent with the Code-Interpreter Tool ## Simple Example: Using an Agent with the Code-Interpreter Tool
```python ```python
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent
# Instantiate the AI agent with the given configuration # Instantiate the AI agent with the given configuration
agent = Agent( agent = Agent(

View file

@ -55,7 +55,7 @@ llama stack run llama_stack/templates/open-benchmark/run.yaml
There are 3 necessary inputs to run a benchmark eval There are 3 necessary inputs to run a benchmark eval
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on - `list of benchmark_ids`: The list of benchmark ids to run evaluation on
- `model-id`: The model id to evaluate on - `model-id`: The model id to evaluate on
- `utput_dir`: Path to store the evaluate results - `output_dir`: Path to store the evaluate results
``` ```
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \ llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
--model_id <model id to evaluate on> \ --model_id <model id to evaluate on> \
@ -69,7 +69,7 @@ llama-stack-client eval run-benchmark help
to see the description of all the flags that eval run-benchmark has to see the description of all the flags that eval run-benchmark has
In the output log, you can find the file path that has your evaluation results. Open that file and you can see you aggrgate In the output log, you can find the file path that has your evaluation results. Open that file and you can see your aggregate
evaluation results over there. evaluation results over there.

View file

@ -56,9 +56,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \ -v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-nvidia \ llamastack/distribution-nvidia \
@ -72,7 +73,7 @@ docker run \
```bash ```bash
llama stack build --template nvidia --image-type conda llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \ llama stack run ./run.yaml \
--port 5001 \ --port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY --env NVIDIA_API_KEY=$NVIDIA_API_KEY
--env INFERENCE_MODEL=$INFERENCE_MODEL --env INFERENCE_MODEL=$INFERENCE_MODEL
``` ```
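Once the server is running (via either method above), you can sanity-check it with a short client script. This is a minimal sketch assuming the server listens on `localhost:8321`; the `models.list()` call and the printed response attribute are assumptions based on the client SDK, so adjust them if your client version differs:
```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Print the model identifiers the distribution has registered.
for model in client.models.list():
    print(model.identifier)

# Run a small chat completion against one of the registered models.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder; use an id printed above
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(response.completion_message.content)
```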

View file

@ -26,7 +26,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
### Models ### Models
@ -51,9 +51,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-bedrock \ llamastack/distribution-bedrock \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -18,7 +18,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``) - `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
### Models ### Models
@ -43,9 +43,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \ -v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-cerebras \ llamastack/distribution-cerebras \
@ -59,6 +60,6 @@ docker run \
```bash ```bash
llama stack build --template cerebras --image-type conda llama stack build --template cerebras --image-type conda
llama stack run ./run.yaml \ llama stack run ./run.yaml \
--port 5001 \ --port 8321 \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
``` ```

View file

@ -53,7 +53,7 @@ docker compose down
#### Start Dell-TGI server locally #### Start Dell-TGI server locally
``` ```
docker run -it --shm-size 1g -p 80:80 --gpus 4 \ docker run -it --pull always --shm-size 1g -p 80:80 --gpus 4 \
-e NUM_SHARD=4 -e NUM_SHARD=4
-e MAX_BATCH_PREFILL_TOKENS=32768 \ -e MAX_BATCH_PREFILL_TOKENS=32768 \
-e MAX_INPUT_TOKENS=8000 \ -e MAX_INPUT_TOKENS=8000 \
@ -65,7 +65,7 @@ registry.dell.huggingface.co/enterprise-dell-inference-meta-llama-meta-llama-3.1
#### Start Llama Stack server pointing to TGI server #### Start Llama Stack server pointing to TGI server
``` ```
docker run --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml docker run --pull always --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml
``` ```
Make sure that in your `run.yaml` file, your inference provider is pointing to the correct TGI server endpoint. E.g. Make sure that in your `run.yaml` file, your inference provider is pointing to the correct TGI server endpoint. E.g.

View file

@ -55,6 +55,7 @@ export CUDA_VISIBLE_DEVICES=0
export LLAMA_STACK_PORT=8321 export LLAMA_STACK_PORT=8321
docker run --rm -it \ docker run --rm -it \
--pull always \
--network host \ --network host \
-v $HOME/.cache/huggingface:/data \ -v $HOME/.cache/huggingface:/data \
-e HF_TOKEN=$HF_TOKEN \ -e HF_TOKEN=$HF_TOKEN \
@ -78,6 +79,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1 export CUDA_VISIBLE_DEVICES=1
docker run --rm -it \ docker run --rm -it \
--pull always \
--network host \ --network host \
-v $HOME/.cache/huggingface:/data \ -v $HOME/.cache/huggingface:/data \
-e HF_TOKEN=$HF_TOKEN \ -e HF_TOKEN=$HF_TOKEN \
@ -120,6 +122,7 @@ This method allows you to get started quickly without having to build the distri
```bash ```bash
docker run -it \ docker run -it \
--pull always \
--network host \ --network host \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \ -v $HOME/.llama:/root/.llama \
@ -147,6 +150,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v $HOME/.llama:/root/.llama \ -v $HOME/.llama:/root/.llama \
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \

View file

@ -28,7 +28,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``) - `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``)
### Models ### Models
@ -61,9 +61,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-fireworks \ llamastack/distribution-fireworks \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -28,7 +28,7 @@ The `llamastack/distribution-groq` distribution consists of the following provid
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `GROQ_API_KEY`: Groq API Key (default: ``) - `GROQ_API_KEY`: Groq API Key (default: ``)
### Models ### Models
@ -56,9 +56,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-groq \ llamastack/distribution-groq \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -30,7 +30,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`) - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`) - `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
@ -75,9 +75,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-gpu \ llamastack/distribution-meta-reference-gpu \
@ -90,6 +91,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash ```bash
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-gpu \ llamastack/distribution-meta-reference-gpu \
@ -105,7 +107,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash ```bash
llama stack build --template meta-reference-gpu --image-type conda llama stack build --template meta-reference-gpu --image-type conda
llama stack run distributions/meta-reference-gpu/run.yaml \ llama stack run distributions/meta-reference-gpu/run.yaml \
--port 5001 \ --port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
``` ```
@ -113,7 +115,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash ```bash
llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \ llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \
--port 5001 \ --port 8321 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
``` ```

View file

@ -32,7 +32,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`) - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
@ -75,9 +75,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-quantized-gpu \ llamastack/distribution-meta-reference-quantized-gpu \
@ -90,6 +91,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
```bash ```bash
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-quantized-gpu \ llamastack/distribution-meta-reference-quantized-gpu \

View file

@ -15,7 +15,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``) - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
### Models ### Models
@ -39,9 +39,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \ -v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-nvidia \ llamastack/distribution-nvidia \
@ -55,6 +56,6 @@ docker run \
```bash ```bash
llama stack build --template nvidia --image-type conda llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \ llama stack run ./run.yaml \
--port 5001 \ --port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY --env NVIDIA_API_KEY=$NVIDIA_API_KEY
``` ```

View file

@ -30,7 +30,7 @@ You should use this distribution if you have a regular desktop machine without v
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`) - `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
- `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`) - `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
@ -69,9 +69,10 @@ Now you are ready to run Llama Stack with Ollama as the inference provider. You
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-ollama \ llamastack/distribution-ollama \
@ -89,6 +90,7 @@ cd /path/to/llama-stack
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
-v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \ -v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
@ -105,7 +107,7 @@ docker run \
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
```bash ```bash
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
llama stack build --template ollama --image-type conda llama stack build --template ollama --image-type conda
llama stack run ./run.yaml \ llama stack run ./run.yaml \

View file

@ -28,7 +28,7 @@ The `llamastack/distribution-passthrough` distribution consists of the following
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `PASSTHROUGH_API_KEY`: Passthrough API Key (default: ``) - `PASSTHROUGH_API_KEY`: Passthrough API Key (default: ``)
- `PASSTHROUGH_URL`: Passthrough URL (default: ``) - `PASSTHROUGH_URL`: Passthrough URL (default: ``)

View file

@ -29,7 +29,7 @@ You can use this distribution if you have GPUs and want to run an independent vL
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`) - `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`)
- `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`) - `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`)
@ -47,6 +47,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
docker run \ docker run \
--pull always \
--runtime nvidia \ --runtime nvidia \
--gpus $CUDA_VISIBLE_DEVICES \ --gpus $CUDA_VISIBLE_DEVICES \
-v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/huggingface:/root/.cache/huggingface \
@ -59,6 +60,8 @@ docker run \
--port $INFERENCE_PORT --port $INFERENCE_PORT
``` ```
Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like: If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
```bash ```bash
@ -67,6 +70,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1 export CUDA_VISIBLE_DEVICES=1
docker run \ docker run \
--pull always \
--runtime nvidia \ --runtime nvidia \
--gpus $CUDA_VISIBLE_DEVICES \ --gpus $CUDA_VISIBLE_DEVICES \
-v ~/.cache/huggingface:/root/.cache/huggingface \ -v ~/.cache/huggingface:/root/.cache/huggingface \
@ -90,10 +94,11 @@ This method allows you to get started quickly without having to build the distri
```bash ```bash
export INFERENCE_PORT=8000 export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \ -v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-remote-vllm \ llamastack/distribution-remote-vllm \
@ -115,6 +120,7 @@ cd /path/to/llama-stack
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
-v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \ -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
@ -135,7 +141,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
```bash ```bash
export INFERENCE_PORT=8000 export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
cd distributions/remote-vllm cd distributions/remote-vllm
llama stack build --template remote-vllm --image-type conda llama stack build --template remote-vllm --image-type conda

View file

@ -27,7 +27,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``) - `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``)
### Models ### Models
@ -59,9 +59,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-sambanova \ llamastack/distribution-sambanova \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -31,7 +31,7 @@ You can use this distribution if you have GPUs and want to run an independent TG
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`) - `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`)
- `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`) - `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
@ -48,6 +48,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
docker run --rm -it \ docker run --rm -it \
--pull always \
-v $HOME/.cache/huggingface:/data \ -v $HOME/.cache/huggingface:/data \
-p $INFERENCE_PORT:$INFERENCE_PORT \ -p $INFERENCE_PORT:$INFERENCE_PORT \
--gpus $CUDA_VISIBLE_DEVICES \ --gpus $CUDA_VISIBLE_DEVICES \
@ -68,6 +69,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1 export CUDA_VISIBLE_DEVICES=1
docker run --rm -it \ docker run --rm -it \
--pull always \
-v $HOME/.cache/huggingface:/data \ -v $HOME/.cache/huggingface:/data \
-p $SAFETY_PORT:$SAFETY_PORT \ -p $SAFETY_PORT:$SAFETY_PORT \
--gpus $CUDA_VISIBLE_DEVICES \ --gpus $CUDA_VISIBLE_DEVICES \
@ -88,9 +90,10 @@ Now you are ready to run Llama Stack with TGI as the inference provider. You can
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-tgi \ llamastack/distribution-tgi \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \
@ -107,6 +110,7 @@ cd /path/to/llama-stack
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \

View file

@ -28,7 +28,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `TOGETHER_API_KEY`: Together.AI API Key (default: ``) - `TOGETHER_API_KEY`: Together.AI API Key (default: ``)
### Models ### Models
@ -62,9 +62,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code. This method allows you to get started quickly without having to build the distribution code.
```bash ```bash
LLAMA_STACK_PORT=5001 LLAMA_STACK_PORT=8321
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-together \ llamastack/distribution-together \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \

View file

@ -54,6 +54,7 @@ mkdir -p ~/.llama
Then you can start the server using the container tool of your choice. For example, if you are running Docker you can use the following command: Then you can start the server using the container tool of your choice. For example, if you are running Docker you can use the following command:
```bash ```bash
docker run -it \ docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
llamastack/distribution-ollama \ llamastack/distribution-ollama \
@ -74,6 +75,7 @@ Docker containers run in their own isolated network namespaces on Linux. To allo
Linux users having issues running the above command should instead try the following: Linux users having issues running the above command should instead try the following:
```bash ```bash
docker run -it \ docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \ -v ~/.llama:/root/.llama \
--network=host \ --network=host \
@ -197,9 +199,7 @@ import os
import uuid import uuid
from termcolor import cprint from termcolor import cprint
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import Document
def create_http_client(): def create_http_client():
@ -225,7 +225,7 @@ client = (
# Documents to be used for RAG # Documents to be used for RAG
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [ documents = [
Document( RAGDocument(
document_id=f"num-{i}", document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain", mime_type="text/plain",
@ -284,7 +284,7 @@ for prompt in user_prompts:
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
session_id=session_id, session_id=session_id,
) )
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
``` ```
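If you want to inspect what the agent retrieves, you can also query the RAG tool directly, without going through an agent turn. A minimal sketch, assuming the `client` and the vector DB id registered in the full example above; the parameter names mirror the `rag_tool.query` snippet shown earlier in these docs, so verify them against your client version:
```python
# Query the ingested documents directly through the RAG tool runtime.
results = client.tool_runtime.rag_tool.query(
    content="How do I fine-tune with LoRA in torchtune?",
    vector_db_ids=[vector_db_id],
)

# The result contains the text chunks retrieved from the vector DB.
print(results.content)
```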

View file

@ -15,8 +15,6 @@ Llama Stack defines and standardizes the core building blocks needed to bring ge
- **Multiple developer interfaces** like CLI and SDKs for Python, Node, iOS, and Android - **Multiple developer interfaces** like CLI and SDKs for Python, Node, iOS, and Android
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack - **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack
We focus on making it easy to build production applications with the Llama model family - from the latest Llama 3.3 to specialized models like Llama Guard for safety.
```{image} ../_static/llama-stack.png ```{image} ../_static/llama-stack.png
:alt: Llama Stack :alt: Llama Stack
:width: 400px :width: 400px

View file

@ -48,7 +48,7 @@ Llama Stack addresses these challenges through a service-oriented, API-first app
**Robust Ecosystem** **Robust Ecosystem**
- Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies). - Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies).
- Ecosystem offers tailored infrastructure, software, and services for deploying Llama models. - Ecosystem offers tailored infrastructure, software, and services for deploying a variety of models.
### Our Philosophy ### Our Philosophy
@ -57,7 +57,6 @@ Llama Stack addresses these challenges through a service-oriented, API-first app
- **Composability**: Every component is independent but works together seamlessly - **Composability**: Every component is independent but works together seamlessly
- **Production Ready**: Built for real-world applications, not just demos - **Production Ready**: Built for real-world applications, not just demos
- **Turnkey Solutions**: Easy-to-deploy, built-in solutions for popular deployment scenarios - **Turnkey Solutions**: Easy-to-deploy, built-in solutions for popular deployment scenarios
- **Llama First**: Explicit focus on Meta's Llama models and partnering ecosystem
With Llama Stack, you can focus on building your application while we handle the infrastructure complexity, essential capabilities, and provider integrations. With Llama Stack, you can focus on building your application while we handle the infrastructure complexity, essential capabilities, and provider integrations.

View file

@ -118,6 +118,7 @@ Playground can also be started in a docker image:
export LLAMA_STACK_URL=http://localhost:11434 export LLAMA_STACK_URL=http://localhost:11434
docker run \ docker run \
--pull always \
-p 8501:8501 \ -p 8501:8501 \
-e LLAMA_STACK_ENDPOINT=$LLAMA_STACK_URL \ -e LLAMA_STACK_ENDPOINT=$LLAMA_STACK_URL \
quay.io/jland/llama-stack-playground quay.io/jland/llama-stack-playground

View file

@ -48,7 +48,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'" "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
] ]
}, },
@ -369,6 +369,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "7da25939-a2a3-463c-958e-9cdfd710d158",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -386,7 +389,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -43,7 +43,7 @@
"source": [ "source": [
"#### 2. Set Up Local and Cloud Clients\n", "#### 2. Set Up Local and Cloud Clients\n",
"\n", "\n",
"Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:5001`.\n" "Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:8322`.\n"
] ]
}, },
{ {
@ -236,6 +236,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "e11939ac-dfbc-4a1c-83be-e494c7f803b8",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -253,7 +256,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -47,7 +47,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'" "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
] ]
}, },
@ -281,6 +281,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "b1b93b6e-22a2-4c24-8cb0-161fdafff29a",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "base", "display_name": "base",
"language": "python", "language": "python",
@ -298,7 +301,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.12.2" "version": "3.12.2"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -45,7 +45,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"CLOUD_PORT = 5001 # Replace with your cloud distro port\n", "CLOUD_PORT = 8321 # Replace with your cloud distro port\n",
"MODEL_NAME='Llama3.2-11B-Vision-Instruct'" "MODEL_NAME='Llama3.2-11B-Vision-Instruct'"
] ]
}, },
@ -180,6 +180,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "37bbbfda-8e42-446c-89c7-59dd49e2d339",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "base", "display_name": "base",
"language": "python", "language": "python",
@ -197,7 +200,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.12.2" "version": "3.12.2"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -46,7 +46,7 @@
"nest_asyncio.apply()\n", "nest_asyncio.apply()\n",
"\n", "\n",
"HOST = \"localhost\"\n", "HOST = \"localhost\"\n",
"PORT = 5001\n", "PORT = 8321\n",
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n" "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
] ]
}, },
@ -296,7 +296,7 @@
"\n", "\n",
" # Create an agent instance with the client and configuration\n", " # Create an agent instance with the client and configuration\n",
" agent = Agent(\n", " agent = Agent(\n",
" client, \n", " client,\n",
" model=MODEL_NAME,\n", " model=MODEL_NAME,\n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n", " instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n", " sampling_params={\n",
@ -335,6 +335,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "f0abbf6d-ed52-40ad-afb4-f5ec99130249",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -352,7 +355,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 5
} }

View file

@ -45,7 +45,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'\n", "MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'\n",
"MEMORY_BANK_ID=\"tutorial_bank\"" "MEMORY_BANK_ID=\"tutorial_bank\""
] ]
@ -378,6 +378,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "73bc3357-0e5e-42ff-95b1-40b916d24c4f",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -395,7 +398,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 4
} }

View file

@ -49,7 +49,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"SHEILD_NAME=\"meta-llama/Llama-Guard-3-1B\"" "SHEILD_NAME=\"meta-llama/Llama-Guard-3-1B\""
] ]
}, },
@ -112,6 +112,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "9afaddb7-c2fb-4309-8fa0-761697de53f0",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -129,7 +132,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.10" "version": "3.11.10"
} }
}, }
"nbformat": 4,
"nbformat_minor": 4
} }

View file

@ -50,7 +50,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"HOST = \"localhost\" # Replace with your host\n", "HOST = \"localhost\" # Replace with your host\n",
"PORT = 5001 # Replace with your port\n", "PORT = 8321 # Replace with your port\n",
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n" "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
] ]
}, },
@ -115,7 +115,7 @@
"async def agent_example():\n", "async def agent_example():\n",
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n", " client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
" agent = Agent(\n", " agent = Agent(\n",
" client, \n", " client,\n",
" model=MODEL_NAME,\n", " model=MODEL_NAME,\n",
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n", " instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
" sampling_params={\n", " sampling_params={\n",
@ -168,6 +168,9 @@
} }
], ],
"metadata": { "metadata": {
"fileHeader": "",
"fileUid": "8de24775-c4a0-49c7-904e-608264f69292",
"isAdHoc": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
@ -185,7 +188,5 @@
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.15" "version": "3.10.15"
} }
}, }
"nbformat": 4,
"nbformat_minor": 4
} }

View file

@ -96,7 +96,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
3. **Set the ENV variables by exporting them to the terminal**: 3. **Set the ENV variables by exporting them to the terminal**:
```bash ```bash
export OLLAMA_URL="http://localhost:11434" export OLLAMA_URL="http://localhost:11434"
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=8321
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B" export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
``` ```
@ -112,7 +112,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
``` ```
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model. Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
The server will start and listen on `http://localhost:5001`. The server will start and listen on `http://localhost:8321`.
--- ---
## Test with `llama-stack-client` CLI ## Test with `llama-stack-client` CLI
@ -120,11 +120,11 @@ After setting up the server, open a new terminal window and configure the llama-
1. Configure the CLI to point to the llama-stack server. 1. Configure the CLI to point to the llama-stack server.
```bash ```bash
llama-stack-client configure --endpoint http://localhost:5001 llama-stack-client configure --endpoint http://localhost:8321
``` ```
**Expected Output:** **Expected Output:**
```bash ```bash
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5001 Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
``` ```
2. Test the CLI by running inference: 2. Test the CLI by running inference:
```bash ```bash
@ -218,7 +218,7 @@ if INFERENCE_MODEL is None:
raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.") raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.")
# Initialize the client # Initialize the client
client = LlamaStackClient(base_url="http://localhost:5001") client = LlamaStackClient(base_url="http://localhost:8321")
# Create a chat completion request # Create a chat completion request
response = client.inference.chat_completion( response = client.inference.chat_completion(
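# (Hedged sketch) A complete call typically supplies the model id and a list of
# messages; the argument names follow the inference examples elsewhere in these
# docs, so double-check them against your client version.
response = client.inference.chat_completion(
    model_id=INFERENCE_MODEL,
    messages=[
        {"role": "system", "content": "You are a friendly assistant."},
        {"role": "user", "content": "Write a two-sentence poem about a llama."},
    ],
)
print(response.completion_message.content)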

View file

@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
) )
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -189,13 +188,11 @@ class AgentToolGroupWithArgs(BaseModel):
args: Dict[str, Any] args: Dict[str, Any]
AgentToolGroup = register_schema( AgentToolGroup = Union[
Union[
str, str,
AgentToolGroupWithArgs, AgentToolGroupWithArgs,
], ]
name="AgentTool", register_schema(AgentToolGroup, name="AgentTool")
)
class AgentConfigCommon(BaseModel): class AgentConfigCommon(BaseModel):
@ -312,8 +309,7 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
turn: Turn turn: Turn
AgentTurnResponseEventPayload = register_schema( AgentTurnResponseEventPayload = Annotated[
Annotated[
Union[ Union[
AgentTurnResponseStepStartPayload, AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload, AgentTurnResponseStepProgressPayload,
@ -323,9 +319,8 @@ AgentTurnResponseEventPayload = register_schema(
AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnAwaitingInputPayload,
], ],
Field(discriminator="event_type"), Field(discriminator="event_type"),
], ]
name="AgentTurnResponseEventPayload", register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
)
@json_schema_type @json_schema_type
@ -387,7 +382,6 @@ class AgentStepResponse(BaseModel):
@runtime_checkable @runtime_checkable
@trace_protocol
class Agents(Protocol): class Agents(Protocol):
"""Agents API for creating and interacting with agentic systems. """Agents API for creating and interacting with agentic systems.
@ -399,7 +393,7 @@ class Agents(Protocol):
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
""" """
@webmethod(route="/agents", method="POST") @webmethod(route="/agents", method="POST", descriptive_name="create_agent")
async def create_agent( async def create_agent(
self, self,
agent_config: AgentConfig, agent_config: AgentConfig,
@ -411,7 +405,9 @@ class Agents(Protocol):
""" """
... ...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST") @webmethod(
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
)
async def create_agent_turn( async def create_agent_turn(
self, self,
agent_id: str, agent_id: str,
@ -443,6 +439,7 @@ class Agents(Protocol):
@webmethod( @webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST", method="POST",
descriptive_name="resume_agent_turn",
) )
async def resume_agent_turn( async def resume_agent_turn(
self, self,
@ -505,7 +502,7 @@ class Agents(Protocol):
""" """
... ...
@webmethod(route="/agents/{agent_id}/session", method="POST") @webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
async def create_agent_session( async def create_agent_session(
self, self,
agent_id: str, agent_id: str,
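For context on the `descriptive_name` argument threaded through the decorators above: it gives the OpenAPI generator a stable operation name when the route alone would produce an awkward one. A minimal sketch of the pattern, using a hypothetical route purely for illustration:
```python
from typing import Protocol

from llama_stack.schema_utils import webmethod


class Example(Protocol):
    # Without descriptive_name, the generated operation id is derived from the
    # route; supplying it keeps generated client method names readable.
    @webmethod(route="/examples/{example_id}/run", method="POST", descriptive_name="run_example")
    async def run_example(self, example_id: str) -> None: ...
```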

View file

@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
# other modalities can be added here # other modalities can be added here
InterleavedContentItem = register_schema( InterleavedContentItem = Annotated[
Annotated[
Union[ImageContentItem, TextContentItem], Union[ImageContentItem, TextContentItem],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="InterleavedContentItem", register_schema(InterleavedContentItem, name="InterleavedContentItem")
)
# accept a single "str" as a special case since it is common # accept a single "str" as a special case since it is common
InterleavedContent = register_schema( InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
Union[str, InterleavedContentItem, List[InterleavedContentItem]], register_schema(InterleavedContent, name="InterleavedContent")
name="InterleavedContent",
)
@json_schema_type @json_schema_type
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
# streaming completions send a stream of ContentDeltas # streaming completions send a stream of ContentDeltas
ContentDelta = register_schema( ContentDelta = Annotated[
Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta], Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="ContentDelta", register_schema(ContentDelta, name="ContentDelta")
)
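The refactor above keeps the `Annotated` union as a plain type alias and registers it in a separate statement, which plays better with static type checkers. A minimal sketch of the resulting pattern, using a made-up pair of variants:
```python
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_stack.schema_utils import register_schema


class CircleShape(BaseModel):
    type: Literal["circle"] = "circle"
    radius: float


class SquareShape(BaseModel):
    type: Literal["square"] = "square"
    side: float


# The alias itself is an ordinary Annotated union that type checkers understand;
# schema registration is now an explicit side effect rather than the assignment.
Shape = Annotated[Union[CircleShape, SquareShape], Field(discriminator="type")]
register_schema(Shape, name="Shape")
```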

View file

@ -72,8 +72,7 @@ class DialogType(BaseModel):
type: Literal["dialog"] = "dialog" type: Literal["dialog"] = "dialog"
ParamType = register_schema( ParamType = Annotated[
Annotated[
Union[ Union[
StringType, StringType,
NumberType, NumberType,
@ -87,9 +86,8 @@ ParamType = register_schema(
AgentTurnInputType, AgentTurnInputType,
], ],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="ParamType", register_schema(ParamType, name="ParamType")
)
""" """
# TODO: recursive definition of ParamType in these containers # TODO: recursive definition of ParamType in these containers

View file

@ -84,13 +84,11 @@ class RowsDataSource(BaseModel):
rows: List[Dict[str, Any]] rows: List[Dict[str, Any]]
DataSource = register_schema( DataSource = Annotated[
Annotated[
Union[URIDataSource, RowsDataSource], Union[URIDataSource, RowsDataSource],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="DataSource", register_schema(DataSource, name="DataSource")
)
class CommonDatasetFields(BaseModel): class CommonDatasetFields(BaseModel):
@ -121,8 +119,6 @@ class Dataset(CommonDatasetFields, Resource):
class DatasetInput(CommonDatasetFields, BaseModel): class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str dataset_id: str
provider_id: Optional[str] = None
provider_dataset_id: Optional[str] = None
class ListDatasetsResponse(BaseModel): class ListDatasetsResponse(BaseModel):

View file

@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
system_message: Optional[SystemMessage] = None
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
num_examples: Optional[int] = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
# we could optionally add any specific dataset config here
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluation job.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""

View file

@ -144,8 +144,7 @@ class CompletionMessage(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list) tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
Message = register_schema( Message = Annotated[
Annotated[
Union[ Union[
UserMessage, UserMessage,
SystemMessage, SystemMessage,
@ -153,9 +152,8 @@ Message = register_schema(
CompletionMessage, CompletionMessage,
], ],
Field(discriminator="role"), Field(discriminator="role"),
], ]
name="Message", register_schema(Message, name="Message")
)
@json_schema_type @json_schema_type
@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
bnf: Dict[str, Any] bnf: Dict[str, Any]
ResponseFormat = register_schema( ResponseFormat = Annotated[
Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat], Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="ResponseFormat", register_schema(ResponseFormat, name="ResponseFormat")
)
# This is an internally used class # This is an internally used class

View file

@ -24,17 +24,6 @@ class HealthInfo(BaseModel):
# TODO: add a provider level status # TODO: add a provider level status
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@json_schema_type @json_schema_type
class VersionInfo(BaseModel): class VersionInfo(BaseModel):
version: str version: str
@ -46,9 +35,6 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Inspect(Protocol): class Inspect(Protocol):
@webmethod(route="/inspect/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/inspect/routes", method="GET") @webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ... async def list_routes(self) -> ListRoutesResponse: ...
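With provider listing moved out of Inspect, the remaining routes can still be hit directly over HTTP. A rough sketch, assuming a stack server on `localhost:8321` and that routes are mounted under a `/v1` prefix:
```python
import requests

# List the routes the running stack exposes; the response body mirrors
# ListRoutesResponse above.
resp = requests.get("http://localhost:8321/v1/inspect/routes")
resp.raise_for_status()
print(resp.json())
```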

View file

@ -6,7 +6,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
group_size: int group_size: int
AlgorithmConfig = register_schema( AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")], register_schema(AlgorithmConfig, name="AlgorithmConfig")
name="AlgorithmConfig",
)
@json_schema_type @json_schema_type
@ -184,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`", description="Model descriptor from `llama model list`",
), ),
checkpoint_dir: Optional[str] = None, checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None, algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST") @webmethod(route="/post-training/preference-optimize", method="POST")

View file

@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(Enum):
average = "average"
weighted_average = "weighted_average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
judge_model: str
prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
parsing_regexes: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: Optional[ScoringFnParams] = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str:
return self.provider_resource_id
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: Optional[str] = None
provider_scoring_fn_id: Optional[str] = None
class ListScoringFunctionsResponse(BaseModel):
data: List[ScoringFn]
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None: ...
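A hedged sketch of constructing scoring-function params and validating them back through the discriminated union defined above. The import path is an assumption (the file path is not shown in this view), and the judge model id, template, and regex are placeholders.

```python
# Sketch: build LLM-as-judge params and round-trip them through ScoringFnParams.
from pydantic import TypeAdapter

from llama_stack.apis.scoring_functions import (  # assumed import path
    AggregationFunctionType,
    LLMAsJudgeScoringFnParams,
    ScoringFnParams,
)

params = LLMAsJudgeScoringFnParams(
    judge_model="example-judge-model",                       # placeholder
    prompt_template="Rate the answer from 0 to 10: {answer}",
    judge_score_regexes=[r"(\d+)"],
    aggregation_functions=[AggregationFunctionType.average],
)

# The union discriminates on `type`, so the dumped dict (type="llm_as_judge")
# validates back into the same concrete model.
restored = TypeAdapter(ScoringFnParams).validate_python(params.model_dump())
assert isinstance(restored, LLMAsJudgeScoringFnParams)
```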

View file

@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel):
status: SpanStatus status: SpanStatus
StructuredLogPayload = register_schema( StructuredLogPayload = Annotated[
Annotated[
Union[ Union[
SpanStartPayload, SpanStartPayload,
SpanEndPayload, SpanEndPayload,
], ],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="StructuredLogPayload", register_schema(StructuredLogPayload, name="StructuredLogPayload")
)
@json_schema_type @json_schema_type
@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon):
payload: StructuredLogPayload payload: StructuredLogPayload
Event = register_schema( Event = Annotated[
Annotated[
Union[ Union[
UnstructuredLogEvent, UnstructuredLogEvent,
MetricEvent, MetricEvent,
StructuredLogEvent, StructuredLogEvent,
], ],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="Event", register_schema(Event, name="Event")
)
@json_schema_type @json_schema_type

View file

@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
@json_schema_type @json_schema_type
class RAGDocument(BaseModel): class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
:param document_id: The unique identifier for the document.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
:param metadata: Additional metadata for the document.
"""
document_id: str document_id: str
content: InterleavedContent | URL content: InterleavedContent | URL
mime_type: str | None = None mime_type: str | None = None
@ -49,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
template: str template: str
RAGQueryGeneratorConfig = register_schema( RAGQueryGeneratorConfig = Annotated[
Annotated[
Union[ Union[
DefaultRAGQueryGeneratorConfig, DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig, LLMRAGQueryGeneratorConfig,
], ],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="RAGQueryGeneratorConfig", register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
)
@json_schema_type @json_schema_type
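For reference, a minimal construction matching the new RAGDocument docstring. The import path is an assumption (the file path is not visible here) and the field values are placeholders; mime_type and metadata are optional.

```python
# Sketch: building a RAGDocument for ingestion by the RAG tool.
from llama_stack.apis.tools import RAGDocument  # assumed import path

doc = RAGDocument(
    document_id="getting-started-guide",
    content="Llama Stack standardizes the core building blocks of AI apps.",
    mime_type="text/plain",
    metadata={"source": "example"},
)
print(doc.model_dump())
```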

View file

@ -69,7 +69,7 @@ class ToolGroup(Resource):
@json_schema_type @json_schema_type
class ToolInvocationResult(BaseModel): class ToolInvocationResult(BaseModel):
content: InterleavedContent content: Optional[InterleavedContent] = None
error_message: Optional[str] = None error_message: Optional[str] = None
error_code: Optional[int] = None error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None
@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class ToolRuntime(Protocol): class ToolRuntime(Protocol):
tool_store: ToolStore tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET") @webmethod(route="/tool-runtime/list-tools", method="GET")

View file

@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class VectorIO(Protocol): class VectorIO(Protocol):
vector_db_store: VectorDBStore vector_db_store: VectorDBStore | None = None
# this will just block now until chunks are inserted, but it should # this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion # probably return a Job instance which can be polled for completion

View file

@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
d = json.load(f) d = json.load(f)
manifest = Manifest(**d) manifest = Manifest(**d)
if datetime.now(timezone.utc) > manifest.expires_on: if datetime.now(timezone.utc) > manifest.expires_on.astimezone(timezone.utc):
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}") raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console() console = Console()

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(
obj_identifier: str,
obj_attributes: Optional[AccessAttributes],
user_attributes: Optional[Dict[str, Any]] = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
3. For each attribute category in the resource's access_attributes:
a. If the user lacks that category, access is DENIED
b. If the user has the category but none of the required values, access is DENIED
c. If the user has at least one matching value in each required category, access is GRANTED
Example:
# Resource requires:
access_attributes = AccessAttributes(
roles=["admin", "data-scientist"],
teams=["ml-team"]
)
# User has:
user_attributes = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "infra-team"],
"projects": ["llama-3"]
}
# Result: Access GRANTED
# - User has the "data-scientist" role (matches one of the required roles)
# - AND user is part of the "ml-team" (matches the required team)
# - The extra "projects" attribute is ignored
Args:
obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not obj_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not dict_attribs:
return True
# Check each attribute category (requires ALL categories to match)
# TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj_identifier}")
return True
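The docstring example above translates directly into code. A runnable version, using the import paths that appear elsewhere in this commit:

```python
# Runnable version of the docstring example: AND across attribute categories,
# OR within a category, extra user attributes ignored.
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes

resource_attrs = AccessAttributes(roles=["admin", "data-scientist"], teams=["ml-team"])
user_attrs = {
    "roles": ["data-scientist", "engineer"],
    "teams": ["ml-team", "infra-team"],
    "projects": ["llama-3"],
}

# Both required categories match (roles via "data-scientist", teams via
# "ml-team"), so access is granted; the extra "projects" attribute is ignored.
assert check_access("private-model", resource_attrs, user_attrs)

# Dropping the team membership fails the AND-across-categories rule.
assert not check_access("private-model", resource_attrs, {"roles": ["data-scientist"]})
```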

View file

@ -90,6 +90,7 @@ RUN apt-get update && apt-get install -y \
procps psmisc lsof \ procps psmisc lsof \
traceroute \ traceroute \
bubblewrap \ bubblewrap \
gcc \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
ENV UV_SYSTEM_PYTHON=1 ENV UV_SYSTEM_PYTHON=1
@ -235,7 +236,7 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture # Detect platform architecture
ARCH=$(uname -m) ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then if [ -n "$BUILD_PLATFORM" ]; then
CLI_ARGS+=("--platform $BUILD_PLATFORM") CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
CLI_ARGS+=("--platform" "linux/arm64") CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then elif [ "$ARCH" = "x86_64" ]; then

View file

@ -13,6 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shield, ShieldInput from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
@ -28,6 +29,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]] RoutingKey = Union[str, List[str]]
class AccessAttributes(BaseModel):
"""Structured representation of user attributes for access control.
This model defines a structured approach to representing user attributes
with common standard categories for access control.
Standard attribute categories include:
- roles: Role-based attributes (e.g., admin, data-scientist)
- teams: Team-based attributes (e.g., ml-team, infra-team)
- projects: Project access attributes (e.g., llama-3, customer-insights)
- namespaces: Namespace-based access control for resource isolation
"""
# Standard attribute categories - the minimal set we need now
roles: Optional[List[str]] = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: Optional[List[str]] = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: Optional[List[str]] = Field(
default=None, description="Namespace-based access control for resource isolation"
)
class ResourceWithACL(Resource):
"""Extension of Resource that adds attribute-based access control capabilities.
This class adds an optional access_attributes field that allows fine-grained control
over which users can access each resource. When attributes are defined, a user must have
matching attributes to access the resource.
Attribute Matching Algorithm:
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
4. Within each category, ANY value match is sufficient (OR relationship within a category)
Examples:
# Resource visible to everyone (no access control)
model = Model(identifier="llama-2", ...)
# Resource visible only to admins
model = Model(
identifier="gpt-4",
access_attributes=AccessAttributes(roles=["admin"])
)
# Resource visible to data scientists on the ML team
model = Model(
identifier="private-model",
access_attributes=AccessAttributes(
roles=["data-scientist", "researcher"],
teams=["ml-team"]
)
)
# ^ User must have at least one of the roles AND be on the ml-team
# Resource visible to users with specific project access
vector_db = VectorDB(
identifier="customer-embeddings",
access_attributes=AccessAttributes(
projects=["customer-insights"],
namespaces=["confidential"]
)
)
# ^ User must have access to the customer-insights project AND have confidential namespace
"""
access_attributes: Optional[AccessAttributes] = None
# Use the extended Resource for all routable objects
class ModelWithACL(Model, ResourceWithACL):
pass
class ShieldWithACL(Shield, ResourceWithACL):
pass
class VectorDBWithACL(VectorDB, ResourceWithACL):
pass
class DatasetWithACL(Dataset, ResourceWithACL):
pass
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
pass
class BenchmarkWithACL(Benchmark, ResourceWithACL):
pass
class ToolWithACL(Tool, ResourceWithACL):
pass
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
pass
RoutableObject = Union[ RoutableObject = Union[
Model, Model,
Shield, Shield,
@ -41,13 +151,14 @@ RoutableObject = Union[
RoutableObjectWithProvider = Annotated[ RoutableObjectWithProvider = Annotated[
Union[ Union[
Model, ModelWithACL,
Shield, ShieldWithACL,
VectorDB, VectorDBWithACL,
Dataset, DatasetWithACL,
Benchmark, ScoringFnWithACL,
Tool, BenchmarkWithACL,
ToolGroup, ToolWithACL,
ToolGroupWithACL,
], ],
Field(discriminator="type"), Field(discriminator="type"),
] ]

View file

@ -11,9 +11,7 @@ from pydantic import BaseModel
from llama_stack.apis.inspect import ( from llama_stack.apis.inspect import (
HealthInfo, HealthInfo,
Inspect, Inspect,
ListProvidersResponse,
ListRoutesResponse, ListRoutesResponse,
ProviderInfo,
RouteInfo, RouteInfo,
VersionInfo, VersionInfo,
) )
@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
ret = []
for api, providers in run_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
)
for p in providers
]
)
return ListProvidersResponse(data=ret)
async def list_routes(self) -> ListRoutesResponse: async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config run_config = self.config.run_config

View file

@ -9,7 +9,6 @@ import inspect
import json import json
import logging import logging
import os import os
import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
get_stack_run_config_from_template, get_stack_run_config_from_template,
@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump()) safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2)) console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints() self.endpoint_impls = initialize_endpoint_impls(self.impls)
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in self.impls:
continue
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
self.endpoint_impls = endpoint_impls
return True return True
async def request( async def request(
@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return response return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
Returns:
A tuple of (endpoint_function, path_params)
Raises:
ValueError: If no matching endpoint is found
"""
impls = self.endpoint_impls.get(method)
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, func in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params
raise ValueError(f"No endpoint found for {path}")
async def _call_non_streaming( async def _call_non_streaming(
self, self,
*, *,
@ -326,10 +278,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
matched_func, path_params = self._find_matching_endpoint(options.method, path) matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
await start_trace(options.url, {"__location__": "library_client"}) await start_trace(route, {"__location__": "library_client"})
try: try:
result = await matched_func(**body) result = await matched_func(**body)
finally: finally:
@ -371,13 +323,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func, path_params = self._find_matching_endpoint(options.method, path) func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
async def gen(): async def gen():
await start_trace(options.url, {"__location__": "library_client"}) await start_trace(route, {"__location__": "library_client"})
try: try:
async for chunk in await func(**body): async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk)) data = json.dumps(convert_pydantic_to_json_value(chunk))
@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body: if not body:
return {} return {}
func, _ = self._find_matching_endpoint(method, path) func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
sig = inspect.signature(func) sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature # Strip NOT_GIVENs to use the defaults in signature

View file

@ -7,21 +7,26 @@
import contextvars import contextvars
import json import json
import logging import logging
from typing import Any, ContextManager, Dict, Optional from typing import Any, ContextManager, Dict, List, Optional
from .utils.dynamic import instantiate_class_type from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Context variable for request provider data # Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager): class RequestProviderDataContext(ContextManager):
"""Context manager for request provider data""" """Context manager for request provider data"""
def __init__(self, provider_data: Optional[Dict[str, Any]] = None): def __init__(
self.provider_data = provider_data self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
):
self.provider_data = provider_data or {}
if auth_attributes:
self.provider_data["__auth_attributes"] = auth_attributes
self.token = None self.token = None
def __enter__(self): def __enter__(self):
@ -80,7 +85,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
return None return None
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager: def request_provider_data_context(
"""Context manager that sets request provider data from headers for the duration of the context""" headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
) -> ContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers) provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data) return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__auth_attributes")
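A short sketch of how the new auth-attribute plumbing is used end to end: attributes are attached for the duration of a request context and read back downstream with get_auth_attributes(). The empty headers dict and the attribute values are placeholders.

```python
# Sketch: attributes set by the context manager are visible to any code that
# runs inside the context, without threading them through call signatures.
from llama_stack.distribution.request_headers import (
    get_auth_attributes,
    request_provider_data_context,
)

auth_attributes = {"roles": ["admin"], "teams": ["ml-team"]}  # placeholders

with request_provider_data_context({}, auth_attributes):
    print(get_auth_attributes())  # {'roles': ['admin'], 'teams': ['ml-team']}

print(get_auth_attributes())      # None outside the request context
```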

View file

@ -19,6 +19,8 @@ from llama_stack.apis.datasets import (
DatasetType, DatasetType,
DataSource, DataSource,
ListDatasetsResponse, ListDatasetsResponse,
RowsDataSource,
URIDataSource,
) )
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
@ -32,11 +34,22 @@ from llama_stack.apis.tools import (
ToolHost, ToolHost,
) )
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
AccessAttributes,
BenchmarkWithACL,
DatasetWithACL,
ModelWithACL,
RoutableObject, RoutableObject,
RoutableObjectWithProvider, RoutableObjectWithProvider,
RoutedProtocol, RoutedProtocol,
ScoringFnWithACL,
ShieldWithACL,
ToolGroupWithACL,
ToolWithACL,
VectorDBWithACL,
) )
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.datatypes import Api, RoutingTable
@ -165,6 +178,11 @@ class CommonRoutingTableImpl(RoutingTable):
if not obj: if not obj:
return None return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
@ -181,6 +199,13 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id] p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p) registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object # TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value: if obj.type == ResourceType.model.value:
@ -193,7 +218,17 @@ class CommonRoutingTableImpl(RoutingTable):
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]: async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all() objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type] filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs
class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@ -230,7 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_type = ModelType.llm model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding: if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata") raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = Model( model = ModelWithACL(
identifier=model_id, identifier=model_id,
provider_resource_id=provider_model_id, provider_resource_id=provider_model_id,
provider_id=provider_id, provider_id=provider_id,
@ -276,7 +311,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
) )
if params is None: if params is None:
params = {} params = {}
shield = Shield( shield = ShieldWithACL(
identifier=shield_id, identifier=shield_id,
provider_resource_id=provider_shield_id, provider_resource_id=provider_shield_id,
provider_id=provider_id, provider_id=provider_id,
@ -330,7 +365,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"embedding_model": embedding_model, "embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"], "embedding_dimension": model.metadata["embedding_dimension"],
} }
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data) vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
await self.register_object(vector_db) await self.register_object(vector_db)
return vector_db return vector_db
@ -358,6 +393,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None, dataset_id: Optional[str] = None,
) -> Dataset: ) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id: if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}" dataset_id = f"dataset-{str(uuid.uuid4())}"
@ -378,7 +419,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
if metadata is None: if metadata is None:
metadata = {} metadata = {}
dataset = Dataset( dataset = DatasetWithACL(
identifier=dataset_id, identifier=dataset_id,
provider_resource_id=provider_dataset_id, provider_resource_id=provider_dataset_id,
provider_id=provider_id, provider_id=provider_id,
@ -429,7 +470,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
raise ValueError("No evaluation providers available. Please configure an evaluation provider.") raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
benchmark = Benchmark( benchmark = BenchmarkWithACL(
identifier=benchmark_id, identifier=benchmark_id,
dataset_id=dataset_id, dataset_id=dataset_id,
grader_ids=grader_ids, grader_ids=grader_ids,
@ -473,7 +514,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for tool_def in tool_defs: for tool_def in tool_defs:
tools.append( tools.append(
Tool( ToolWithACL(
identifier=tool_def.name, identifier=tool_def.name,
toolgroup_id=toolgroup_id, toolgroup_id=toolgroup_id,
description=tool_def.description or "", description=tool_def.description or "",
@ -498,7 +539,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self.register_object(tool) await self.register_object(tool)
await self.dist_registry.register( await self.dist_registry.register(
ToolGroup( ToolGroupWithACL(
identifier=toolgroup_id, identifier=toolgroup_id,
provider_id=provider_id, provider_id=provider_id,
provider_resource_id=toolgroup_id, provider_resource_id=toolgroup_id,
@ -511,7 +552,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_group = await self.get_tool_group(toolgroup_id) tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None: if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found") raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id).data tools = (await self.list_tools(toolgroup_id)).data
for tool in tools: for tool in tools:
await self.unregister_object(tool) await self.unregister_object(tool)
await self.unregister_object(tool_group) await self.unregister_object(tool_group)
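The one-line change to list_tools above is an operator-precedence fix: attribute access binds tighter than await, so the old expression looked up .data on the coroutine object. A self-contained sketch with stand-in types showing the difference:

```python
# Why the parentheses matter in `(await self.list_tools(...)).data`: attribute
# access binds tighter than `await`, so the old spelling read `.data` off the
# coroutine object before anything was awaited.
import asyncio
from dataclasses import dataclass


@dataclass
class ListToolsResponse:
    data: list


async def list_tools() -> ListToolsResponse:
    return ListToolsResponse(data=["web_search"])


async def main():
    print((await list_tools()).data)  # correct: await first, then .data
    try:
        await list_tools().data       # AttributeError on the coroutine object
    except AttributeError as exc:     # (also leaves that coroutine un-awaited)
        print(exc)


asyncio.run(main())
```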

View file

@ -5,16 +5,118 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Dict, List, Optional
from urllib.parse import parse_qs from urllib.parse import parse_qs
import httpx import httpx
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") logger = get_logger(name=__name__, category="auth")
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
message: Optional[str] = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthenticationMiddleware: class AuthenticationMiddleware:
"""Middleware that authenticates requests using an external auth endpoint.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Sends it to the configured auth endpoint along with request details
3. Validates the response and extracts user attributes
4. Makes these attributes available to the route handlers for access control
Authentication Request Format:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
"request": {
"path": "/models/list",
"headers": {
"content-type": "application/json",
"user-agent": "..."
// All headers except Authorization
},
"params": {
"limit": ["100"],
"offset": ["0"]
// Query parameters as key -> list of values
}
}
}
```
Expected Auth Endpoint Response Format:
```json
{
"access_attributes": { // Structured attribute format
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
"namespaces": ["research"]
},
"message": "Optional message about auth result"
}
```
Attribute-Based Access Control:
The attributes returned by the auth endpoint are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth endpoint doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_endpoint): def __init__(self, app, auth_endpoint):
self.app = app self.app = app
self.auth_endpoint = auth_endpoint self.auth_endpoint = auth_endpoint
@ -32,25 +134,57 @@ class AuthenticationMiddleware:
path = scope.get("path", "") path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()} request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode() query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string) params = parse_qs(query_string)
auth_data = { # Build the auth request model
"api_key": api_key, auth_request = AuthRequest(
"request": { api_key=api_key,
"path": path, request=AuthRequestContext(
"headers": request_headers, path=path,
"params": params, headers=request_headers,
}, params=params,
} ),
)
# Validate with authentication endpoint # Validate with authentication endpoint
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(self.auth_endpoint, json=auth_data) response = await client.post(
self.auth_endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200: if response.status_code != 200:
logger.warning(f"Authentication failed: {response.status_code}") logger.warning(f"Authentication failed: {response.status_code}")
return await self._send_auth_error(send, "Authentication failed") return await self._send_auth_error(send, "Authentication failed")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [api_key],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
except Exception:
logger.exception("Error parsing authentication response")
return await self._send_auth_error(send, "Invalid authentication response format")
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except Exception: except Exception:
logger.exception("Error during authentication") logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error") return await self._send_auth_error(send, "Authentication service error")
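The middleware expects an external service that accepts the AuthRequest payload and answers in the AuthResponse shape documented above. A minimal FastAPI sketch of such a service; the /validate path, the token table, and the attribute values are all placeholders.

```python
# Minimal sketch of an external auth endpoint compatible with the request and
# response shapes documented above.
from typing import Dict, List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()


class AuthRequestContext(BaseModel):
    path: str
    headers: Dict[str, str]
    params: Dict[str, List[str]]


class AuthRequest(BaseModel):
    api_key: str
    request: AuthRequestContext


# Stand-in token table; a real deployment would verify against an IdP.
KNOWN_KEYS = {
    "token-alice": {"roles": ["admin"], "teams": ["ml-team"]},
}


@app.post("/validate")
def validate(req: AuthRequest):
    attrs = KNOWN_KEYS.get(req.api_key)
    if attrs is None:
        # Any non-200 response makes the middleware reject the request.
        raise HTTPException(status_code=401, detail="unknown API key")
    return {"access_attributes": attrs, "message": "ok"}
```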

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import inspect import inspect
import re
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel from pydantic import BaseModel
@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
route: str route: str
method: str method: str
name: str name: str
descriptive_name: str | None = None
def toolgroup_protocol_map(): def toolgroup_protocol_map():
@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
method = "delete" method = "delete"
else: else:
method = "post" method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name)) endpoints.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
)
apis[api] = endpoints apis[api] = endpoints
return apis return apis
def initialize_endpoint_impls(impls):
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in impls:
continue
for endpoint in api_endpoints:
impl = impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
func,
endpoint.descriptive_name or endpoint.route,
)
return endpoint_impls
def find_matching_endpoint(method, path, endpoint_impls):
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
Raises:
ValueError: If no matching endpoint is found
"""
impls = endpoint_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, descriptive_name
raise ValueError(f"No endpoint found for {path}")

View file

@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
redact_sensitive_fields, redact_sensitive_fields,
@ -179,8 +183,11 @@ async def sse_generator(event_gen):
def create_dynamic_typed_route(func: Any, method: str, route: str): def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
# Use context manager for request provider data # Get auth attributes from the request scope
with request_provider_data_context(request.headers): user_attributes = request.scope.get("user_attributes", {})
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try: try:
@ -219,14 +226,30 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
class TracingMiddleware: class TracingMiddleware:
def __init__(self, app): def __init__(self, app, impls):
self.app = app self.app = app
self.impls = impls
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
path = scope.get("path", "") if scope.get("type") == "lifespan":
await start_trace(path, {"__location__": "server"})
try:
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
path = scope.get("path", "")
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally: finally:
await end_trace() await end_trace()
@ -348,7 +371,6 @@ def main():
logger.info(yaml.dump(safe_config, indent=2)) logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware) app.add_middleware(ClientVersionMiddleware)
@ -366,7 +388,7 @@ def main():
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
else: else:
setup_logger(TelemetryAdapter(TelemetryConfig())) setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_endpoints()
@ -412,6 +434,7 @@ def main():
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
import uvicorn import uvicorn
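With the reworked middleware, responses carry the trace id of the request in an x-trace-id header. A hedged client-side sketch of reading it back; the URL and /v1 prefix are assumptions about a local deployment.

```python
# Sketch: correlate client logs with server-side telemetry via x-trace-id.
import httpx

resp = httpx.get("http://localhost:8321/v1/inspect/routes", timeout=10.0)
print(resp.status_code, resp.headers.get("x-trace-id"))
```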

View file

@ -12,9 +12,12 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
class DistributionRegistry(Protocol): class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ... async def get_all(self) -> List[RoutableObjectWithProvider]: ...
@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects.""" """Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = [] all_objects = []
for value in values: for value in values:
try:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj) all_objects.append(obj)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
continue
return all_objects return all_objects
@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str: if not json_str:
return None return None
try:
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
return None
async def update(self, obj: RoutableObjectWithProvider) -> None: async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set( await self.kvstore.set(

View file

@ -5,9 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import streamlit as st import streamlit as st
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.shared.document import Document
from llama_stack.distribution.ui.modules.api import llama_stack_api from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file from llama_stack.distribution.ui.modules.utils import data_url_from_file
@ -35,7 +33,7 @@ def rag_chat_page():
) )
if st.button("Create Vector Database"): if st.button("Create Vector Database"):
documents = [ documents = [
Document( RAGDocument(
document_id=uploaded_file.name, document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file), content=data_url_from_file(uploaded_file),
) )
@ -167,7 +165,7 @@ def rag_chat_page():
message_placeholder = st.empty() message_placeholder = st.empty()
full_response = "" full_response = ""
retrieval_response = "" retrieval_response = ""
for log in EventLogger().log(response): for log in AgentEventLogger().log(response):
log.print() log.print()
if log.role == "tool_execution": if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip() retrieval_response += log.content.replace("====", "").strip()

View file

@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel): class ToolCall(BaseModel):
call_id: str call_id: str
tool_name: Union[BuiltinTool, str] tool_name: Union[BuiltinTool, str]
arguments: Dict[str, RecursiveType] # Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
arguments: Union[str, Dict[str, RecursiveType]]
arguments_json: Optional[str] = None
@field_validator("tool_name", mode="before") @field_validator("tool_name", mode="before")
@classmethod @classmethod
@ -179,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
top_k: int = Field(..., ge=1) top_k: int = Field(..., ge=1)
SamplingStrategy = register_schema( SamplingStrategy = Annotated[
Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"), Field(discriminator="type"),
], ]
name="SamplingStrategy", register_schema(SamplingStrategy, name="SamplingStrategy")
)
@json_schema_type @json_schema_type
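Per the deprecation note in the ToolCall change above, clients can start reading arguments_json while the Dict form still exists. A small sketch of populating both fields, mirroring what ChatFormat now does; the tool name and arguments are placeholders.

```python
# Sketch: keep `arguments` and the new `arguments_json` mirror in sync so
# callers can migrate to parsing the JSON string before the Dict is removed.
import json

from llama_stack.models.llama.datatypes import ToolCall

args = {"city": "Seattle", "days": 3}  # placeholder arguments
call = ToolCall(
    call_id="call-1",
    tool_name="get_weather",           # hypothetical custom tool name
    arguments=args,
    arguments_json=json.dumps(args),
)
print(call.arguments_json)
```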

View file

@ -12,6 +12,7 @@
# the top-level of this source tree. # the top-level of this source tree.
import io import io
import json
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -203,6 +204,7 @@ class ChatFormat:
# This code tries to handle that case # This code tries to handle that case
if tool_name in BuiltinTool.__members__: if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name] tool_name = BuiltinTool[tool_name]
if isinstance(tool_arguments, dict):
tool_arguments = { tool_arguments = {
"query": list(tool_arguments.values())[0], "query": list(tool_arguments.values())[0],
} }
@ -229,6 +231,7 @@ class ChatFormat:
call_id=call_id, call_id=call_id,
tool_name=tool_name, tool_name=tool_name,
arguments=tool_arguments, arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
) )
) )
content = "" content = ""

View file

@ -244,6 +244,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
You SHOULD NOT include any other text in the response. You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke. Here is a list of functions in JSON format that you can invoke.

View file

@ -11,11 +11,8 @@
# top-level folder for each specific model found within the models/ directory at # top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree. # the top-level of this source tree.
from llama_stack.models.llama.datatypes import (
BuiltinTool, from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
StopReason,
ToolCall,
)
from .prompt_templates import ( from .prompt_templates import (
BuiltinToolGenerator, BuiltinToolGenerator,

View file

@ -15,8 +15,11 @@ import json
import re import re
from typing import Optional, Tuple from typing import Optional, Tuple
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)' BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})") CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@ -92,7 +95,15 @@ def parse_python_list_for_function_calls(input_string):
# Extract keyword arguments # Extract keyword arguments
for keyword in node.keywords: for keyword in node.keywords:
try:
function_args[keyword.arg] = ast.literal_eval(keyword.value) function_args[keyword.arg] = ast.literal_eval(keyword.value)
except ValueError as e:
logger.error(
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
)
raise ValueError(
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
) from e
result.append((function_name, function_args)) result.append((function_name, function_args))

View file

@ -180,23 +180,27 @@ class ChatAgent(ShieldRunnerMixin):
return messages return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups) span = tracing.get_current_span()
async with tracing.span("create_and_execute_turn") as span: if span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4()) turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id) span.set_attribute("turn_id", turn_id)
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id): async for chunk in self._run_turn(request, turn_id):
yield chunk yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
await self._initialize_tools() span = tracing.get_current_span()
async with tracing.span("resume_turn") as span: if span:
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
await self._initialize_tools()
async for chunk in self._run_turn(request): async for chunk in self._run_turn(request):
yield chunk yield chunk
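The turn methods now annotate whatever span the caller is already running under instead of opening their own nested span, so attributes such as session_id land on the request-level trace. A minimal sketch of that pattern using the OpenTelemetry API directly for illustration (the stack's own tracing helpers differ in detail):

from opentelemetry import trace

def annotate_current_span(**attributes):
    # Attach attributes to the span already in flight; the is_recording() check
    # makes this a no-op when tracing is not configured.
    span = trace.get_current_span()
    if span is not None and span.is_recording():
        for key, value in attributes.items():
            span.set_attribute(key, value)

annotate_current_span(session_id="sess-123", agent_id="agent-abc")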

View file

@ -13,6 +13,9 @@ from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.agents import ToolExecutionStep, Turn from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
# TODO: is this used anywhere? # TODO: is this used anywhere?
vector_db_id: Optional[str] = None vector_db_id: Optional[str] = None
started_at: datetime started_at: datetime
access_attributes: Optional[AccessAttributes] = None
class AgentPersistence: class AgentPersistence:
@ -33,11 +37,18 @@ class AgentPersistence:
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes()
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
session_info = AgentSessionInfo( session_info = AgentSessionInfo(
session_id=session_id, session_id=session_id,
session_name=name, session_name=name,
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
) )
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(), value=session_info.model_dump_json(),
@ -51,12 +62,34 @@ class AgentPersistence:
if not value: if not value:
return None return None
return AgentSessionInfo(**json.loads(value)) session_info = AgentSessionInfo(**json.loads(value))
# Check access to session
if not self._check_session_access(session_info):
return None
return session_info
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes"):
return True
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id)
if not session_info:
return None
return session_info
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str): async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
session_info = await self.get_session_info(session_id) session_info = await self.get_session_if_accessible(session_id)
if session_info is None: if session_info is None:
raise ValueError(f"Session {session_id} not found") raise ValueError(f"Session {session_id} not found or access denied")
session_info.vector_db_id = vector_db_id session_info.vector_db_id = vector_db_id
await self.kvstore.set( await self.kvstore.set(
@ -65,12 +98,18 @@ class AgentPersistence:
) )
async def add_turn_to_session(self, session_id: str, turn: Turn): async def add_turn_to_session(self, session_id: str, turn: Turn):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.model_dump_json(), value=turn.model_dump_json(),
) )
async def get_session_turns(self, session_id: str) -> List[Turn]: async def get_session_turns(self, session_id: str) -> List[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range( values = await self.kvstore.range(
start_key=f"session:{self.agent_id}:{session_id}:", start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
@ -87,6 +126,9 @@ class AgentPersistence:
return turns return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]: async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}", key=f"session:{self.agent_id}:{session_id}:{turn_id}",
) )
@ -95,24 +137,36 @@ class AgentPersistence:
return Turn(**json.loads(value)) return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set( await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(), value=step.model_dump_json(),
) )
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
) )
return ToolExecutionStep(**json.loads(value)) if value else None return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int): async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
await self.kvstore.set( await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters), value=str(num_infer_iters),
) )
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]: async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
) )

View file

@ -35,12 +35,12 @@ class PandasDataframeDataset:
else: else:
return self.df.iloc[idx].to_dict() return self.df.iloc[idx].to_dict()
def load(self) -> None: async def load(self) -> None:
if self.df is not None: if self.df is not None:
return return
if self.dataset_def.source.type == "uri": if self.dataset_def.source.type == "uri":
self.df = get_dataframe_from_uri(self.dataset_def.source.uri) self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows": elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows) self.df = pandas.DataFrame(self.dataset_def.source.rows)
else: else:
@ -95,7 +95,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
) -> IterrowsResponse: ) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id] dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def) dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load() await dataset_impl.load()
start_index = start_index or 0 start_index = start_index or 0
@ -114,7 +114,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id] dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def) dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load() await dataset_impl.load()
new_rows_df = pandas.DataFrame(rows) new_rows_df = pandas.DataFrame(rows)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
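load() becomes a coroutine because fetching a dataset from a URI now happens asynchronously, and both iterrows and append_rows await it before touching the dataframe. A toy version of the lazy-load pattern (names are illustrative, not the provider's actual classes):

import asyncio
from typing import Any, Dict, List, Optional
import pandas

class LazyDataset:
    def __init__(self, uri: Optional[str] = None, rows: Optional[List[Dict[str, Any]]] = None):
        self.uri = uri
        self.rows = rows
        self.df: Optional[pandas.DataFrame] = None

    async def load(self) -> None:
        # Materialize the dataframe on first use only.
        if self.df is not None:
            return
        if self.uri is not None:
            # pandas.read_csv blocks, so push it onto a worker thread.
            self.df = await asyncio.to_thread(pandas.read_csv, self.uri)
        else:
            self.df = pandas.DataFrame(self.rows)

async def main() -> None:
    ds = LazyDataset(rows=[{"q": "2+2?", "a": "4"}])
    await ds.load()
    print(ds.df.iloc[0].to_dict())

asyncio.run(main())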

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List
from tqdm import tqdm from tqdm import tqdm
@ -20,8 +20,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
from llama_stack.providers.utils.common.data_schema_validator import ColumnName from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:" EVAL_TASKS_PREFIX = "benchmarks:"
@ -101,7 +101,7 @@ class MetaReferenceEvalImpl(
# need job scheduler queue (ray/celery) w/ jobs api # need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs)) job_id = str(len(self.jobs))
self.jobs[job_id] = res self.jobs[job_id] = res
return Job(job_id=job_id) return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation( async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@ -215,17 +215,18 @@ class MetaReferenceEvalImpl(
return EvaluateResponse(generations=generations, scores=score_response.results) return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs: if job_id in self.jobs:
return JobStatus.completed return Job(job_id=job_id, status=JobStatus.completed)
return None raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None: async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet") raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(benchmark_id, job_id) job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed: if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}") raise ValueError(f"Job is not completed, Status: {status.value}")

View file

@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
tool_name=t.function.name, tool_name=t.function.name,
# vLLM function args come back as a string. Llama Stack expects JSON. # vLLM function args come back as a string. Llama Stack expects JSON.
arguments=json.loads(t.function.arguments), arguments=json.loads(t.function.arguments),
arguments_json=t.function.arguments,
) )
for t in vllm_message.tool_calls for t in vllm_message.tool_calls
], ],
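vLLM hands function arguments back as a JSON string, so the adapter now keeps both the parsed dict (what Llama Stack consumes) and the raw string in arguments_json. A short illustration with a made-up tool call:

import json

raw_arguments = '{"city": "Paris", "days": 3}'
tool_call = {
    "tool_name": "get_forecast",
    "arguments": json.loads(raw_arguments),   # parsed form used by the stack
    "arguments_json": raw_arguments,          # original string preserved
}
print(tool_call["arguments"]["city"], tool_call["arguments_json"])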

View file

@ -23,7 +23,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import ( from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn, RegexParserMathResponseScoringFn,
) )
@ -36,6 +38,8 @@ FIXED_FNS = [
RegexParserScoringFn, RegexParserScoringFn,
RegexParserMathResponseScoringFn, RegexParserMathResponseScoringFn,
BFCLScoringFn, BFCLScoringFn,
IfEvalScoringFn,
DocVQAScoringFn,
] ]

View file

@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.docvqa import docvqa
CONTRACTIONS = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
"1st": "first",
"2nd": "second",
"3rd": "third",
}
NUMBERS = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
ARTICLES = [
"a",
"an",
"the",
"to",
"in",
"from",
"by",
] # Contains a bit more than just articles; we also strip these words so they do not skew the accuracy
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
PUNCTUATION = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def normalize_answer(s: str) -> str:
# process punctuation
for p in PUNCTUATION:
if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
s = s.replace(p, "")
else:
s = s.replace(p, " ")
s = PERIOD_STRIP.sub("", s)
# process digits and articles
temp_text = s.lower().split()
out_text = []
for word in temp_text:
word = NUMBERS.setdefault(word, word)
if word not in ARTICLES:
out_text.append(word)
# standardize contractions
for word_id, word in enumerate(out_text):
if word in CONTRACTIONS:
out_text[word_id] = CONTRACTIONS[word]
return " ".join(out_text)
class DocVQAScoringFn(RegisteredBaseScoringFn):
"""
DocVQA scoring matches the generated answer against a set of allowed
reference answers; both sides are normalized first so that trivial
formatting differences are not penalized
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
docvqa.identifier: docvqa,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "docvqa",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
expected_answers = json.loads(input_row["expected_answer"])
generated_answer = input_row["generated_answer"]
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
return {
"score": score,
}
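A quick illustration of the scoring path, assuming the normalize_answer defined above is in scope and using made-up answers rather than real dataset rows: punctuation is stripped, number words become digits, and articles drop out before the exact-match comparison.

expected_answers = ["3 pages", "three pages"]
generated_answer = "Three pages."

# normalize_answer("Three pages.") == "3 pages", which matches the first choice.
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
print(score)  # 1.0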

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
docvqa = ScoringFn(
identifier="basic::docvqa",
description="DocVQA Visual Question & Answer scoring function",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="docvqa",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
ifeval = ScoringFn(
identifier="basic::ifeval",
description="Eval intruction follow capacity by checkping how many instructions can be followed in each example",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="ifeval",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.weighted_average],
),
)

View file

@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.ifeval import (
ifeval,
)
class IfEvalScoringFn(RegisteredBaseScoringFn):
"""
A scoring function for the Instruction-Following Eval (IFEval) benchmark
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
ifeval.identifier: ifeval,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
instruction_list = input_row["instruction_id_list"]
generated_answer = input_row["generated_answer"].strip()
is_following_list = []
results = dict(
{k + "_correct": 0.0 for k in INSTRUCTION_LIST},
**{k + "_total": 0.0 for k in INSTRUCTION_LIST},
)
for index, instruction_id in enumerate(instruction_list):
instruction_cls = INSTRUCTION_DICT[instruction_id]
instruction = instruction_cls(instruction_id)
results[instruction_id + "_total"] += 1.0
results[instruction_id.split(":")[0] + "_total"] += 1.0
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
print(clean_input_row)
instruction.build_description(**clean_input_row)
args = instruction.get_instruction_args()
if args and "prompt" in args:
instruction.build_description(prompt=input_row["prompt"])
if generated_answer and instruction.check_following(generated_answer):
is_following_list.append(True)
results[instruction_id + "_correct"] += 1.0
results[instruction_id.split(":")[0] + "_correct"] += 1.0
else:
is_following_list.append(False)
if len(is_following_list) == 0:
return {
"score": 0.0,
"weight": 0.0,
}
return {
"score": float(sum(is_following_list)) / float(len(is_following_list)),
"weight": float(len(is_following_list)),
}
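Each row reports both the fraction of its instructions that were followed and how many there were, and the fn_def above selects weighted_average aggregation so rows with more instructions count proportionally more. The aggregation itself lives in the shared scoring utilities; the arithmetic it implies, on hypothetical rows, is just:

rows = [
    {"score": 1.0, "weight": 2.0},   # 2 of 2 instructions followed
    {"score": 0.5, "weight": 4.0},   # 2 of 4 instructions followed
    {"score": 0.0, "weight": 0.0},   # row with no checkable instructions
]

total_weight = sum(r["weight"] for r in rows)
weighted_average = (
    sum(r["score"] * r["weight"] for r in rows) / total_weight if total_weight else 0.0
)
print(weighted_average)  # 0.666..., i.e. 4 followed out of 6 instructions overall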

File diff suppressed because it is too large

View file

@ -6,12 +6,14 @@
from typing import Any, Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import TelemetryConfig, TelemetrySink from .config import TelemetryConfig, TelemetrySink
__all__ = ["TelemetryConfig", "TelemetrySink"] __all__ = ["TelemetryConfig", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]): async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
from .telemetry import TelemetryAdapter from .telemetry import TelemetryAdapter
impl = TelemetryAdapter(config, deps) impl = TelemetryAdapter(config, deps)

View file

@ -13,19 +13,20 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum): class TelemetrySink(str, Enum):
OTEL = "otel" OTEL_TRACE = "otel_trace"
OTEL_METRIC = "otel_metric"
SQLITE = "sqlite" SQLITE = "sqlite"
CONSOLE = "console" CONSOLE = "console"
class TelemetryConfig(BaseModel): class TelemetryConfig(BaseModel):
otel_endpoint: str = Field( otel_trace_endpoint: str = Field(
default="http://localhost:4318/v1/traces", default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL", description="The OpenTelemetry collector endpoint URL for traces",
) )
service_name: str = Field( otel_metric_endpoint: str = Field(
default="llama-stack", default="http://localhost:4318/v1/metrics",
description="The service name to use for telemetry", description="The OpenTelemetry collector endpoint URL for metrics",
) )
sinks: List[TelemetrySink] = Field( sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE], default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
@ -46,7 +47,6 @@ class TelemetryConfig(BaseModel):
@classmethod @classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]: def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return { return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}", "sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}", "sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
} }
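With the sink split, traces and metrics are exported to separate OTLP endpoints and the sample run config no longer sets a service name here. A minimal sketch of constructing the config, assuming the classes above are imported and the remaining fields keep their defaults (endpoint values shown are the defaults, for illustration only):

config = TelemetryConfig(
    otel_trace_endpoint="http://localhost:4318/v1/traces",
    otel_metric_endpoint="http://localhost:4318/v1/metrics",
    sinks=[TelemetrySink.OTEL_TRACE, TelemetrySink.OTEL_METRIC, TelemetrySink.SQLITE],
)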

Some files were not shown because too many files have changed in this diff