forked from phoenix-oss/llama-stack-mirror
commit a54d757ade ("merge")
197 changed files with 9392 additions and 3089 deletions
.github/TRIAGERS.md (new file, +2)
@@ -0,0 +1,2 @@
+# This file documents Triage members in the Llama Stack community
+@franciscojavierarceo @leseb

.github/workflows/integration-tests.yml (5 changes)
@@ -14,6 +14,10 @@ on:
       - 'requirements.txt'
       - '.github/workflows/integration-tests.yml' # This workflow
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   test-matrix:
     runs-on: ubuntu-latest
@@ -52,6 +56,7 @@ jobs:
           # always test against the latest version of the client
           uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
           uv pip install -e .
+          llama stack build --template ollama --image-type venv
 
       - name: Wait for Ollama to start
         run: |
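
The same four-line concurrency block is added to each of the six workflows below. As a minimal sketch of what the merged result looks like in any one of them (triggers abbreviated): GitHub Actions keys runs by workflow name and git ref, so pushing a new commit to the same branch or PR cancels the still-running job for the previous commit.

```yaml
on:
  push:
    branches: [main]

# Runs are grouped per workflow and per ref; a newer run for the
# same ref cancels the older in-flight run.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - run: echo "only the newest run per ref survives"
```
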
.github/workflows/pre-commit.yml (4 changes)
@@ -5,6 +5,10 @@ on:
   push:
     branches: [main]
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   pre-commit:
     runs-on: ubuntu-latest
.github/workflows/providers-build.yml (4 changes)
@@ -18,6 +18,10 @@ on:
       - 'llama_stack/distribution/*.sh'
      - '.github/workflows/providers-build.yml'
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   generate-matrix:
     runs-on: ubuntu-latest
.github/workflows/semantic-pr.yml (4 changes)
@@ -8,6 +8,10 @@ on:
       - reopened
       - synchronize
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 permissions:
   contents: read
.github/workflows/unit-tests.yml (4 changes)
@@ -15,6 +15,10 @@ on:
       - '.github/workflows/unit-tests.yml' # This workflow
   workflow_dispatch:
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   unit-tests:
     runs-on: ubuntu-latest
.github/workflows/update-readthedocs.yml (4 changes)
@@ -22,6 +22,10 @@ on:
       - 'pyproject.toml'
       - '.github/workflows/update-readthedocs.yml'
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   update-readthedocs:
     runs-on: ubuntu-latest
(pre-commit configuration; file name not captured in this view)
@@ -89,10 +89,11 @@ repos:
       name: API Spec Codegen
       additional_dependencies:
         - uv==0.6.2
-      entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null 2>&1'
+      entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
       language: python
       pass_filenames: false
       require_serial: true
+      files: ^llama_stack/apis/|^docs/openapi_generator/
 
 ci:
     autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
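
The hook previously swallowed stderr (`2>&1`) and ran on every commit; it now surfaces errors and, via the new `files:` regex, only fires when the API definitions or the generator itself change. A minimal sketch of a local hook scoped the same way (the hook id is hypothetical, for illustration only):

```yaml
repos:
  - repo: local
    hooks:
      - id: api-spec-codegen   # hypothetical id for illustration
        name: API Spec Codegen
        entry: sh -c './docs/openapi_generator/run_openapi_generator.sh > /dev/null'
        language: python
        pass_filenames: false
        require_serial: true
        # Run only when staged files match this regex:
        files: ^llama_stack/apis/|^docs/openapi_generator/
```
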
(contributor documentation; file name not captured in this view)
@@ -135,9 +135,11 @@ uv sync
 
 ## Coding Style
 
+* Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter; the same goes for docstrings.
+* Prefer comments that clarify surprising behavior and/or relationships between parts of the code rather than explain what the next line of code does.
+* When catching exceptions, prefer a specific exception type rather than a broad catch-all like `Exception`.
+* Error messages should be prefixed with "Failed to ..."
 * 4 spaces for indentation rather than tabs
-* 80 character line length
-* ...
 
 ## Common Tasks
 
@@ -166,7 +168,7 @@ If you have made changes to a provider's configuration in any form (introducing
 If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.
 
 ```bash
-cd llama-stack/docs
+cd docs
 uv sync --extra docs
 
 # This rebuilds the documentation pages.

(distribution dependency lists, JSON; file header not captured in this view)

Every distribution's dependency list gains the same three packages: "emoji", "langdetect", and "pythainlp". The first pair of hunks is representative:

@@ -6,10 +6,12 @@
     "chardet",
     "chromadb-client",
     "datasets",
+    "emoji",
     "faiss-cpu",
     "fastapi",
     "fire",
     "httpx",
+    "langdetect",
     "matplotlib",
     "mcp",
     "nltk",
@@ -21,6 +23,7 @@
     "psycopg2-binary",
     "pymongo",
     "pypdf",
+    "pythainlp",
     "redis",
     "requests",
     "scikit-learn",

The same three insertions repeat, with surrounding context varying per distribution (some lists also carry "fireworks-ai", "huggingface_hub", "litellm", "fairscale", "fbgemm-gpu", or "lm-format-enforcer"), in the hunks from @@ -37,10 +40,12 @@ through @@ -640,6 +696,7 @@.

(docker-compose file for the Ollama distribution, per the OLLAMA_URL setting; file name not captured in this view)
@@ -51,14 +51,14 @@ services:
       - ~/local/llama-stack/:/app/llama-stack-source
       - ./run${SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
     ports:
-      - "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
+      - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
     environment:
       - INFERENCE_MODEL=${INFERENCE_MODEL}
       - SAFETY_MODEL=${SAFETY_MODEL:-}
       - OLLAMA_URL=http://ollama:11434
     entrypoint: >
       python -m llama_stack.distribution.server.server /root/my-run.yaml \
-        --port ${LLAMA_STACK_PORT:-5001}
+        --port ${LLAMA_STACK_PORT:-8321}
     deploy:
       restart_policy:
         condition: on-failure
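
These compose files move the default server port from 5001 to 8321. Because the `${VAR:-default}` parameter form only supplies 8321 when `LLAMA_STACK_PORT` is unset, the old port can still be selected per run without editing the file. A minimal sketch (the service name is illustrative):

```yaml
services:
  llamastack:   # illustrative service name
    ports:
      # Binds host 8321 to container 8321 unless LLAMA_STACK_PORT is set;
      # e.g. `LLAMA_STACK_PORT=5001 docker compose up` restores the old port.
      - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
```
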
distributions/ramalama/faiss_store.db (new binary file; contents not shown)

(docker-compose file for the remote-vllm distribution, per the SQLITE_STORE_DIR default; file name not captured in this view)
@@ -84,9 +84,9 @@ services:
       - SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
       - SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
     ports:
-      - "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
+      - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
     # Hack: wait for vLLM server to start before starting docker
-    entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 5001"
+    entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 8321"
     deploy:
       restart_policy:
         condition: on-failure

(docker-compose file for the TGI distribution, per TGI_SAFETY_MODEL; file name not captured in this view)
@@ -83,7 +83,7 @@ services:
       - ~/.llama:/root/.llama
       - ./run${TGI_SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
     ports:
-      - "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
+      - "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
     # Hack: wait for TGI server to start before starting docker
     entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
     restart_policy:

docs/_static/llama-stack-spec.html (598 changes)
@@ -2285,7 +2285,7 @@
             "content": {
               "application/json": {
                 "schema": {
-                  "$ref": "#/components/schemas/ListAgentSessionsResponse"
+                  "$ref": "#/components/schemas/Job"
                 }
               }
             }
@@ -2719,7 +2719,7 @@
         }
       },
       "tags": [
-        "Inspect"
+        "Providers"
       ],
       "description": "",
       "parameters": []
@@ -4108,6 +4108,11 @@
           ]
         },
         "arguments": {
+          "oneOf": [
+            {
+              "type": "string"
+            },
+            {
           "type": "object",
           "additionalProperties": {
             "oneOf": [
@@ -4173,6 +4178,11 @@
               ]
             }
           }
+          ]
+        },
+        "arguments_json": {
+          "type": "string"
+        }
       },
       "additionalProperties": false,
       "required": [
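
With this change a ToolCall's arguments may arrive either pre-parsed (an object) or as a raw string, and the new string-valued arguments_json can carry the original JSON text alongside the parsed form. Two illustrative instances (all ids and values are invented for illustration):

```yaml
# Object form: arguments already parsed into a map.
- call_id: call-1                  # illustrative
  tool_name: get_weather           # illustrative
  arguments:
    city: Seattle
    units: celsius
  arguments_json: '{"city": "Seattle", "units": "celsius"}'

# String form: arguments passed through unparsed.
- call_id: call-2
  tool_name: get_weather
  arguments: '{"city": "Seattle"}'
```
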
@@ -6182,6 +6192,382 @@
       "title": "EmbeddingsResponse",
       "description": "Response containing generated embeddings."
     },

This hunk inserts the JSON definitions of the new evaluation and scoring schemas: AgentCandidate, AggregationFunctionType (average, weighted_average, median, categorical_count, accuracy), BasicScoringFnParams, BenchmarkConfig, EvalCandidate (a oneOf over ModelCandidate and AgentCandidate, discriminated on "type"), LLMAsJudgeScoringFnParams, ModelCandidate, RegexParserScoringFnParams, ScoringFnParams (a oneOf over the three ScoringFnParams variants, discriminated on "type"), EvaluateRowsRequest, EvaluateResponse, and ScoringResult. The same definitions appear in YAML form in docs/_static/llama-stack-spec.yaml below (hunk @@ -4341,6 +4448,252 @@), which is reproduced in full there. The hunk's trailing context:

     "Agent": {
       "type": "object",
       "properties": {
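
Taken together, these schemas describe an eval configuration like the following sketch (model ids, scoring-function ids, and values are illustrative only):

```yaml
# A BenchmarkConfig instance with a ModelCandidate as the eval candidate.
eval_candidate:
  type: model                               # discriminator: selects ModelCandidate
  model: meta-llama/Llama-3.1-8B-Instruct   # illustrative model id
  sampling_params:
    temperature: 0.0                        # illustrative; SamplingParams fields are not shown in this diff
scoring_params:
  llm-as-judge::base:                       # illustrative scoring function id
    type: llm_as_judge                      # discriminator: selects LLMAsJudgeScoringFnParams
    judge_model: meta-llama/Llama-3.1-70B-Instruct
    judge_score_regexes: ["Score: (\\d+)"]
    aggregation_functions: [average]
num_examples: 10                            # optional: evaluate only the first 10 examples
```
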
@@ -7319,8 +7705,7 @@
           "completed",
           "in_progress",
           "failed",
-          "scheduled",
-          "cancelled"
+          "scheduled"
         ],
         "title": "JobStatus"
       },
@@ -7698,7 +8083,8 @@
       "type": "object",
       "properties": {
         "document_id": {
-          "type": "string"
+          "type": "string",
+          "description": "The unique identifier for the document."
         },
         "content": {
           "oneOf": [
@@ -7717,10 +8103,12 @@
             {
               "$ref": "#/components/schemas/URL"
             }
-          ]
+          ],
+          "description": "The content of the document."
         },
         "mime_type": {
-          "type": "string"
+          "type": "string",
+          "description": "The MIME type of the document."
         },
         "metadata": {
           "type": "object",
@@ -7745,7 +8133,8 @@
               "type": "object"
             }
           ]
-        }
+        },
+          "description": "Additional metadata for the document."
       }
     },
     "additionalProperties": false,
@@ -7754,7 +8143,8 @@
       "content",
       "metadata"
     ],
-    "title": "RAGDocument"
+    "title": "RAGDocument",
+    "description": "A document to be used for document ingestion in the RAG Tool."
     },
     "InsertRequest": {
       "type": "object",
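
The newly documented RAGDocument fields line up with an ingestion payload like this sketch (id and metadata invented for illustration; string content is shown, though content items or a URL are equally valid per the oneOf):

```yaml
# A RAGDocument instance using the string form of content.
document_id: doc-001          # illustrative id
content: "Llama Stack ships a RAG tool for document ingestion."
mime_type: text/plain
metadata:
  source: user-upload         # arbitrary key/value metadata
```
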
@@ -7964,9 +8354,6 @@
         }
       },
       "additionalProperties": false,
-      "required": [
-        "content"
-      ],
       "title": "ToolInvocationResult"
     },
     "IterrowsResponse": {
@@ -8013,6 +8400,30 @@
       "title": "IterrowsResponse",
       "description": "A paginated list of rows from a dataset."
     },
+    "Job": {
+      "type": "object",
+      "properties": {
+        "job_id": {
+          "type": "string"
+        },
+        "status": {
+          "type": "string",
+          "enum": [
+            "completed",
+            "in_progress",
+            "failed",
+            "scheduled"
+          ],
+          "title": "JobStatus"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "job_id",
+        "status"
+      ],
+      "title": "Job"
+    },
     "ListAgentSessionsResponse": {
       "type": "object",
       "properties": {
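
Job is the object the eval endpoints now return where the first hunk in this file previously pointed at ListAgentSessionsResponse, and its status enum matches the trimmed JobStatus above (no cancelled state). An instance is simply (illustrative id):

```yaml
job_id: "42"          # illustrative
status: in_progress   # one of: completed, in_progress, failed, scheduled
```
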
@@ -9596,21 +10007,16 @@
     "RunRequest": {
       "type": "object",
       "properties": {
-        "task": {
-          "$ref": "#/components/schemas/EvaluationTask",
-          "description": "The task to evaluate. To specify a task, one of the following must be provided: - `benchmark_id`: Run evaluation task against a benchmark_id - `dataset_id` and `grader_ids`: Run evaluation task against a dataset_id and a list of grader_ids - `data_source` and `grader_ids`: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids"
-        },
-        "candidate": {
-          "$ref": "#/components/schemas/EvaluationCandidate",
-          "description": "The candidate to evaluate."
-        }
+        "benchmark_config": {
+          "$ref": "#/components/schemas/BenchmarkConfig",
+          "description": "The configuration for the benchmark."
+        }
       },
       "additionalProperties": false,
       "required": [
-        "task",
-        "candidate"
+        "benchmark_config"
       ],
-      "title": "RunRequest"
+      "title": "RunEvalRequest"
     },
     "RunShieldRequest": {
       "type": "object",
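
The run-eval request body therefore collapses from a task-plus-candidate pair to a single benchmark_config. A request body sketch, reusing the illustrative BenchmarkConfig from above:

```yaml
benchmark_config:
  eval_candidate:
    type: model
    model: meta-llama/Llama-3.1-8B-Instruct   # illustrative
    sampling_params: {}                       # fields not shown in this diff
  scoring_params: {}                          # empty map is schema-valid; ids would go here
```
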
@@ -9717,6 +10123,145 @@
       ],
       "title": "SaveSpansToDatasetRequest"
     },

This hunk adds four scoring schemas and the AlgorithmConfig union. The scoring schemas are summarized here; the truncated YAML hunk at the end of docs/_static/llama-stack-spec.yaml carries the same definitions:

- ScoreRequest: input_rows (array of free-form row objects, required) and scoring_functions (map from scoring function id to ScoringFnParams or null, required).
- ScoreResponse: results (map of scoring function name to ScoringResult, required); description "The response from scoring."
- ScoreBatchRequest: dataset_id (string, required), scoring_functions (same map shape, required), save_results_dataset (boolean, required).
- ScoreBatchResponse: dataset_id (string) and results (map of name to ScoringResult, required).
- AlgorithmConfig, reproduced in full:

+    "AlgorithmConfig": {
+      "oneOf": [
+        { "$ref": "#/components/schemas/LoraFinetuningConfig" },
+        { "$ref": "#/components/schemas/QATFinetuningConfig" }
+      ],
+      "discriminator": {
+        "propertyName": "type",
+        "mapping": {
+          "LoRA": "#/components/schemas/LoraFinetuningConfig",
+          "QAT": "#/components/schemas/QATFinetuningConfig"
+        }
+      }
+    },
     "LoraFinetuningConfig": {
       "type": "object",
       "properties": {
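
A ScoreRequest instance under these definitions might look like this sketch (row columns and scoring-function ids are illustrative; a null value means "use the function's default params"):

```yaml
input_rows:
  - input_query: "What is the capital of France?"   # illustrative columns
    generated_answer: "Paris"
    expected_answer: "Paris"
scoring_functions:
  basic::equality: null                             # illustrative id; null = defaults
  llm-as-judge::base:                               # illustrative id with explicit params
    type: llm_as_judge
    judge_model: meta-llama/Llama-3.1-70B-Instruct  # illustrative
```
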
@@ -9852,14 +10397,7 @@
           "type": "string"
         },
         "algorithm_config": {
-          "oneOf": [
-            {
-              "$ref": "#/components/schemas/LoraFinetuningConfig"
-            },
-            {
-              "$ref": "#/components/schemas/QATFinetuningConfig"
-            }
-          ]
+          "$ref": "#/components/schemas/AlgorithmConfig"
         }
       },
       "additionalProperties": false,

docs/_static/llama-stack-spec.yaml (477 changes)
@@ -1562,6 +1562,109 @@ paths:
       required: false
       schema:
         type: integer
+  /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}:
+    get:
+      responses:
+        '200':
+          description: The status of the evaluation job.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/Job'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: '#/components/responses/TooManyRequests429'
+        '500':
+          $ref: '#/components/responses/InternalServerError500'
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Eval
+      description: Get the status of a job.
+      parameters:
+        - name: benchmark_id
+          in: path
+          description: >-
+            The ID of the benchmark to run the evaluation on.
+          required: true
+          schema:
+            type: string
+        - name: job_id
+          in: path
+          description: The ID of the job to get the status of.
+          required: true
+          schema:
+            type: string
+    delete:
+      responses:
+        '200':
+          description: OK
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: '#/components/responses/TooManyRequests429'
+        '500':
+          $ref: '#/components/responses/InternalServerError500'
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Eval
+      description: Cancel a job.
+      parameters:
+        - name: benchmark_id
+          in: path
+          description: >-
+            The ID of the benchmark to run the evaluation on.
+          required: true
+          schema:
+            type: string
+        - name: job_id
+          in: path
+          description: The ID of the job to cancel.
+          required: true
+          schema:
+            type: string
+  /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result:
+    get:
+      responses:
+        '200':
+          description: The result of the job.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/EvaluateResponse'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: '#/components/responses/TooManyRequests429'
+        '500':
+          $ref: '#/components/responses/InternalServerError500'
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Eval
+      description: Get the result of a job.
+      parameters:
+        - name: benchmark_id
+          in: path
+          description: >-
+            The ID of the benchmark to run the evaluation on.
+          required: true
+          schema:
+            type: string
+        - name: job_id
+          in: path
+          description: The ID of the job to get the result of.
+          required: true
+          schema:
+            type: string
   /v1/agents/{agent_id}/sessions:
     get:
       responses:
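
These operations give the eval API an asynchronous job lifecycle: kick off a run, poll the status endpoint until the Job leaves in_progress, then fetch the EvaluateResponse (or DELETE the job to cancel it). Sketched response bodies, with illustrative ids and values:

```yaml
# GET /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}
job_id: "42"                 # illustrative
status: completed            # poll until no longer scheduled / in_progress

# GET /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result
generations:
  - generated_answer: Paris  # illustrative row
scores:
  basic::equality:           # illustrative scoring function id
    score_rows:
      - score: 1.0
    aggregated_results:
      accuracy: 1.0
```
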
@@ -1820,7 +1923,7 @@ paths:
         default:
           $ref: '#/components/responses/DefaultError'
       tags:
-        - Models
+        - Providers
       description: ''
       parameters: []
     post:
@@ -2841,7 +2944,9 @@ components:
               title: BuiltinTool
             - type: string
         arguments:
-          type: object
+          oneOf:
+            - type: string
+            - type: object
           additionalProperties:
             oneOf:
               - type: string
@@ -2865,6 +2970,8 @@ components:
                     - type: number
                     - type: boolean
                     - type: 'null'
+        arguments_json:
+          type: string
         additionalProperties: false
         required:
           - call_id
@@ -4341,6 +4448,252 @@ components:
       title: EmbeddingsResponse
       description: >-
         Response containing generated embeddings.
+    AgentCandidate:
+      type: object
+      properties:
+        type:
+          type: string
+          const: agent
+          default: agent
+        config:
+          $ref: '#/components/schemas/AgentConfig'
+          description: >-
+            The configuration for the agent candidate.
+      additionalProperties: false
+      required:
+        - type
+        - config
+      title: AgentCandidate
+      description: An agent candidate for evaluation.
+    AggregationFunctionType:
+      type: string
+      enum:
+        - average
+        - weighted_average
+        - median
+        - categorical_count
+        - accuracy
+      title: AggregationFunctionType
+    BasicScoringFnParams:
+      type: object
+      properties:
+        type:
+          type: string
+          const: basic
+          default: basic
+        aggregation_functions:
+          type: array
+          items:
+            $ref: '#/components/schemas/AggregationFunctionType'
+      additionalProperties: false
+      required:
+        - type
+      title: BasicScoringFnParams
+    BenchmarkConfig:
+      type: object
+      properties:
+        eval_candidate:
+          $ref: '#/components/schemas/EvalCandidate'
+          description: The candidate to evaluate.
+        scoring_params:
+          type: object
+          additionalProperties:
+            $ref: '#/components/schemas/ScoringFnParams'
+          description: >-
+            Map between scoring function id and parameters for each scoring function
+            you want to run
+        num_examples:
+          type: integer
+          description: >-
+            (Optional) The number of examples to evaluate. If not provided, all examples
+            in the dataset will be evaluated
+      additionalProperties: false
+      required:
+        - eval_candidate
+        - scoring_params
+      title: BenchmarkConfig
+      description: >-
+        A benchmark configuration for evaluation.
+    EvalCandidate:
+      oneOf:
+        - $ref: '#/components/schemas/ModelCandidate'
+        - $ref: '#/components/schemas/AgentCandidate'
+      discriminator:
+        propertyName: type
+        mapping:
+          model: '#/components/schemas/ModelCandidate'
+          agent: '#/components/schemas/AgentCandidate'
+    LLMAsJudgeScoringFnParams:
+      type: object
+      properties:
+        type:
+          type: string
+          const: llm_as_judge
+          default: llm_as_judge
+        judge_model:
+          type: string
+        prompt_template:
+          type: string
+        judge_score_regexes:
+          type: array
+          items:
+            type: string
+        aggregation_functions:
+          type: array
+          items:
+            $ref: '#/components/schemas/AggregationFunctionType'
+      additionalProperties: false
+      required:
+        - type
+        - judge_model
+      title: LLMAsJudgeScoringFnParams
+    ModelCandidate:
+      type: object
+      properties:
+        type:
+          type: string
+          const: model
+          default: model
+        model:
+          type: string
+          description: The model ID to evaluate.
+        sampling_params:
+          $ref: '#/components/schemas/SamplingParams'
+          description: The sampling parameters for the model.
+        system_message:
+          $ref: '#/components/schemas/SystemMessage'
+          description: >-
+            (Optional) The system message providing instructions or context to the
+            model.
+      additionalProperties: false
+      required:
+        - type
+        - model
+        - sampling_params
+      title: ModelCandidate
+      description: A model candidate for evaluation.
+    RegexParserScoringFnParams:
+      type: object
+      properties:
+        type:
+          type: string
+          const: regex_parser
+          default: regex_parser
+        parsing_regexes:
+          type: array
+          items:
+            type: string
+        aggregation_functions:
+          type: array
+          items:
+            $ref: '#/components/schemas/AggregationFunctionType'
+      additionalProperties: false
+      required:
+        - type
+      title: RegexParserScoringFnParams
+    ScoringFnParams:
+      oneOf:
+        - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
+        - $ref: '#/components/schemas/RegexParserScoringFnParams'
+        - $ref: '#/components/schemas/BasicScoringFnParams'
+      discriminator:
+        propertyName: type
+        mapping:
+          llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams'
+          regex_parser: '#/components/schemas/RegexParserScoringFnParams'
+          basic: '#/components/schemas/BasicScoringFnParams'
+    EvaluateRowsRequest:
+      type: object
+      properties:
+        input_rows:
+          type: array
+          items:
+            type: object
+            additionalProperties:
+              oneOf:
+                - type: 'null'
+                - type: boolean
+                - type: number
+                - type: string
+                - type: array
+                - type: object
+          description: The rows to evaluate.
+        scoring_functions:
+          type: array
+          items:
+            type: string
+          description: >-
+            The scoring functions to use for the evaluation.
+        benchmark_config:
+          $ref: '#/components/schemas/BenchmarkConfig'
+          description: The configuration for the benchmark.
+      additionalProperties: false
+      required:
+        - input_rows
+        - scoring_functions
+        - benchmark_config
+      title: EvaluateRowsRequest
+    EvaluateResponse:
+      type: object
+      properties:
+        generations:
+          type: array
+          items:
+            type: object
+            additionalProperties:
+              oneOf:
+                - type: 'null'
+                - type: boolean
+                - type: number
+                - type: string
+                - type: array
+                - type: object
+          description: The generations from the evaluation.
+        scores:
+          type: object
+          additionalProperties:
+            $ref: '#/components/schemas/ScoringResult'
+          description: The scores from the evaluation.
+      additionalProperties: false
+      required:
+        - generations
+        - scores
+      title: EvaluateResponse
+      description: The response from an evaluation.
+    ScoringResult:
+      type: object
+      properties:
+        score_rows:
+          type: array
+          items:
+            type: object
+            additionalProperties:
+              oneOf:
+                - type: 'null'
+                - type: boolean
+                - type: number
+                - type: string
+                - type: array
+                - type: object
+          description: >-
+            The scoring result for each row. Each row is a map of column name to value.
+        aggregated_results:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+          description: Map of metric name to aggregated value
+      additionalProperties: false
+      required:
+        - score_rows
+        - aggregated_results
+      title: ScoringResult
+      description: A scoring result for a single row.
     Agent:
       type: object
       properties:
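
For the agent side of the EvalCandidate union, the discriminator works the same way as in the ModelCandidate sketch shown earlier: "type: agent" routes validation to AgentCandidate. A minimal fragment (the AgentConfig body is elided, since its fields are not part of this diff):

```yaml
eval_candidate:
  type: agent   # discriminator: selects AgentCandidate
  config: {}    # an AgentConfig instance goes here; fields not shown in this diff
```
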
@@ -5098,7 +5451,6 @@ components:
           - in_progress
           - failed
           - scheduled
-          - cancelled
         title: JobStatus
       scheduled_at:
         type: string
@@ -5373,6 +5725,7 @@ components:
       properties:
         document_id:
           type: string
+          description: The unique identifier for the document.
         content:
           oneOf:
             - type: string
@@ -5381,8 +5734,10 @@ components:
               items:
                 $ref: '#/components/schemas/InterleavedContentItem'
             - $ref: '#/components/schemas/URL'
+          description: The content of the document.
         mime_type:
           type: string
+          description: The MIME type of the document.
         metadata:
           type: object
           additionalProperties:
@@ -5393,12 +5748,15 @@ components:
             - type: string
             - type: array
             - type: object
+          description: Additional metadata for the document.
       additionalProperties: false
       required:
         - document_id
         - content
         - metadata
       title: RAGDocument
+      description: >-
+        A document to be used for document ingestion in the RAG Tool.
     InsertRequest:
       type: object
       properties:
@@ -5516,8 +5874,6 @@ components:
             - type: array
             - type: object
       additionalProperties: false
-      required:
-        - content
       title: ToolInvocationResult
     IterrowsResponse:
       type: object
@@ -5545,6 +5901,24 @@ components:
         - data
       title: IterrowsResponse
       description: A paginated list of rows from a dataset.
+    Job:
+      type: object
+      properties:
+        job_id:
+          type: string
+        status:
+          type: string
+          enum:
+            - completed
+            - in_progress
+            - failed
+            - scheduled
+          title: JobStatus
+      additionalProperties: false
+      required:
+        - job_id
+        - status
+      title: Job
     ListAgentSessionsResponse:
       type: object
       properties:
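An object satisfying the new `Job` schema carries just an id and one of the four statuses. As a sketch of how a caller might use it (the payload values and the `refresh_job_status` helper are hypothetical, standing in for whichever jobs endpoint the deployment exposes):

```python
# Hypothetical Job payload matching the schema above.
job = {"job_id": "eval-7f3a", "status": "scheduled"}

# Poll until the job leaves the two "still running" states.
while job["status"] in ("scheduled", "in_progress"):
    job = refresh_job_status(job["job_id"])  # placeholder for the jobs API

print("final status:", job["status"])  # "completed" or "failed"
```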
@@ -6610,9 +6984,8 @@ components:
           description: The candidate to evaluate.
       additionalProperties: false
       required:
-        - task
-        - candidate
-      title: RunRequest
+        - benchmark_config
+      title: RunEvalRequest
     RunShieldRequest:
       type: object
       properties:
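The rename from `RunRequest` to `RunEvalRequest`, with `benchmark_config` as the single required field, corresponds to a call like the following sketch. The benchmark id and model are placeholders, and the call returns a `Job` as defined earlier:

```python
# Sketch: kick off a full benchmark evaluation; placeholder ids throughout.
job = client.eval.run_eval(
    benchmark_id="meta-reference-mmlu",  # hypothetical benchmark id
    benchmark_config={
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.3-70B-Instruct",
            "sampling_params": {"max_tokens": 4096},
        }
    },
)
print(job.job_id, job.status)
```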
@@ -6685,6 +7058,90 @@ components:
         - attributes_to_save
         - dataset_id
       title: SaveSpansToDatasetRequest
+    ScoreRequest:
+      type: object
+      properties:
+        input_rows:
+          type: array
+          items:
+            type: object
+            additionalProperties:
+              oneOf:
+                - type: 'null'
+                - type: boolean
+                - type: number
+                - type: string
+                - type: array
+                - type: object
+          description: The rows to score.
+        scoring_functions:
+          type: object
+          additionalProperties:
+            oneOf:
+              - $ref: '#/components/schemas/ScoringFnParams'
+              - type: 'null'
+          description: >-
+            The scoring functions to use for the scoring.
+      additionalProperties: false
+      required:
+        - input_rows
+        - scoring_functions
+      title: ScoreRequest
+    ScoreResponse:
+      type: object
+      properties:
+        results:
+          type: object
+          additionalProperties:
+            $ref: '#/components/schemas/ScoringResult'
+          description: >-
+            A map of scoring function name to ScoringResult.
+      additionalProperties: false
+      required:
+        - results
+      title: ScoreResponse
+      description: The response from scoring.
+    ScoreBatchRequest:
+      type: object
+      properties:
+        dataset_id:
+          type: string
+        scoring_functions:
+          type: object
+          additionalProperties:
+            oneOf:
+              - $ref: '#/components/schemas/ScoringFnParams'
+              - type: 'null'
+        save_results_dataset:
+          type: boolean
+      additionalProperties: false
+      required:
+        - dataset_id
+        - scoring_functions
+        - save_results_dataset
+      title: ScoreBatchRequest
+    ScoreBatchResponse:
+      type: object
+      properties:
+        dataset_id:
+          type: string
+        results:
+          type: object
+          additionalProperties:
+            $ref: '#/components/schemas/ScoringResult'
+      additionalProperties: false
+      required:
+        - results
+      title: ScoreBatchResponse
+    AlgorithmConfig:
+      oneOf:
+        - $ref: '#/components/schemas/LoraFinetuningConfig'
+        - $ref: '#/components/schemas/QATFinetuningConfig'
+      discriminator:
+        propertyName: type
+        mapping:
+          LoRA: '#/components/schemas/LoraFinetuningConfig'
+          QAT: '#/components/schemas/QATFinetuningConfig'
     LoraFinetuningConfig:
       type: object
       properties:
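`ScoreRequest` takes rows inline plus a map of scoring function name to optional `ScoringFnParams` (null means "use the function's defaults"), while `ScoreBatchRequest` points at a registered dataset instead. A minimal sketch of the inline variant, with placeholder ids:

```python
# Sketch: score rows directly; shapes mirror ScoreRequest/ScoreResponse.
response = client.scoring.score(
    input_rows=[
        {"input_query": "What is 2+2?",
         "generated_answer": "4",
         "expected_answer": "4"},
    ],
    scoring_functions={"basic::equality": None},  # None -> default params
)

# `results` maps each scoring function to a ScoringResult
for fn_name, result in response.results.items():
    print(fn_name, result.aggregated_results)
```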
@@ -6768,9 +7225,7 @@ components:
         checkpoint_dir:
           type: string
         algorithm_config:
-          oneOf:
-            - $ref: '#/components/schemas/LoraFinetuningConfig'
-            - $ref: '#/components/schemas/QATFinetuningConfig'
+          $ref: '#/components/schemas/AlgorithmConfig'
       additionalProperties: false
       required:
         - job_uuid
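With the union factored out into `AlgorithmConfig`, a fine-tuning request picks the variant through the `type` discriminator. A sketch of a LoRA-flavored payload follows; the field values are illustrative, and the field names are assumed to follow `LoraFinetuningConfig` as defined elsewhere in the spec:

```python
# Hypothetical algorithm_config payload selecting the LoRA variant
# via the `type` discriminator defined above.
algorithm_config = {
    "type": "LoRA",
    "lora_attn_modules": ["q_proj", "v_proj"],
    "apply_lora_to_mlp": False,
    "apply_lora_to_output": False,
    "rank": 8,
    "alpha": 16,
}
```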
@@ -4,6 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import os
+import time
+
+
 def pytest_collection_modifyitems(items):
     for item in items:
         item.name = item.name.replace(' ', '_')
+
+
+def pytest_runtest_teardown(item):
+    interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
+    if interval_seconds:
+        time.sleep(float(interval_seconds))
+
+
+def pytest_configure(config):
+    config.option.tbstyle = "short"
+    config.option.disable_warnings = True
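The `pytest_runtest_teardown` hook gives test runs a way to pace themselves against rate-limited providers: setting, for example, `LLAMA_STACK_TEST_INTERVAL_SECONDS=0.5` in the environment sleeps half a second between tests, while leaving it unset keeps the previous behavior.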
@@ -123,6 +123,8 @@
 "outputs": [],
 "source": [
 "# NBVAL_SKIP\n",
+"!pip uninstall pandas numpy -y\n",
+"!pip install pandas numpy\n",
 "# This will build all the dependencies you will need\n",
 "!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv"
 ]
@@ -1203,7 +1205,7 @@
 }
 ],
 "source": [
-"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
+"from llama_stack_client import InferenceEventLogger\n",
 "\n",
 "message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
 "print(f'User> {message[\"content\"]}', \"green\")\n",
@@ -1215,7 +1217,7 @@
 ")\n",
 "\n",
 "# Print the tokens while they are received\n",
-"for log in EventLogger().log(response):\n",
+"for log in InferenceEventLogger().log(response):\n",
 "    log.print()\n"
 ]
 },
@@ -1632,8 +1634,7 @@
 }
 ],
 "source": [
-"from llama_stack_client.lib.agents.agent import Agent\n",
-"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
+"from llama_stack_client import Agent, AgentEventLogger\n",
 "from termcolor import cprint\n",
 "\n",
 "agent = Agent(\n",
@@ -1659,7 +1660,7 @@
 "    ],\n",
 "    session_id=session_id,\n",
 "    )\n",
-"    for log in EventLogger().log(response):\n",
+"    for log in AgentEventLogger().log(response):\n",
 "        log.print()\n"
 ]
 },
@@ -1808,14 +1809,12 @@
 ],
 "source": [
 "import uuid\n",
-"from llama_stack_client.lib.agents.agent import Agent\n",
-"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
+"from llama_stack_client import Agent, AgentEventLogger, RAGDocument\n",
 "from termcolor import cprint\n",
-"from llama_stack_client.types import Document\n",
 "\n",
 "urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n",
 "documents = [\n",
-"    Document(\n",
+"    RAGDocument(\n",
 "        document_id=f\"num-{i}\",\n",
 "        content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
 "        mime_type=\"text/plain\",\n",
@@ -1858,7 +1857,7 @@
 "        messages=[{\"role\": \"user\", \"content\": prompt}],\n",
 "        session_id=session_id,\n",
 "    )\n",
-"    for log in EventLogger().log(response):\n",
+"    for log in AgentEventLogger().log(response):\n",
 "        log.print()"
 ]
 },
@@ -1969,7 +1968,7 @@
 }
 ],
 "source": [
-"from llama_stack_client.types.agents.turn_create_params import Document\n",
+"from llama_stack_client import Document\n",
 "\n",
 "codex_agent = Agent(\n",
 "    client, \n",
@@ -2013,7 +2012,7 @@
 "    # for chunk in response:\n",
 "    #     print(chunk)\n",
 "\n",
-"    for log in EventLogger().log(response):\n",
+"    for log in AgentEventLogger().log(response):\n",
 "        log.print()\n"
 ]
 },
@@ -2891,8 +2890,7 @@
 ],
 "source": [
 "# NBVAL_SKIP\n",
-"from llama_stack_client.lib.agents.agent import Agent\n",
-"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
+"from llama_stack_client import Agent, AgentEventLogger\n",
 "from termcolor import cprint\n",
 "\n",
 "agent = Agent(\n",
@@ -2918,7 +2916,7 @@
 "    ],\n",
 "    session_id=session_id,\n",
 "    )\n",
-"    for log in EventLogger().log(response):\n",
+"    for log in AgentEventLogger().log(response):\n",
 "        log.print()\n"
 ]
 },
@@ -2993,8 +2991,7 @@
 }
 ],
 "source": [
-"from llama_stack_client.lib.agents.agent import Agent\n",
-"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
+"from llama_stack_client import Agent, AgentEventLogger\n",
 "\n",
 "agent = Agent(\n",
 "    client, \n",
@@ -3021,7 +3018,7 @@
 "    session_id=session_id,\n",
 "    )\n",
 "\n",
-"    for log in EventLogger().log(response):\n",
+"    for log in AgentEventLogger().log(response):\n",
 "        log.print()\n"
 ]
 },
@@ -4355,7 +4352,7 @@
 "    session_id=session_id,\n",
 ")\n",
 "\n",
-"for log in EventLogger().log(response):\n",
+"for log in AgentEventLogger().log(response):\n",
 "    log.print()\n",
 "    "
 ]
@@ -47,9 +47,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from llama_stack_client import LlamaStackClient\n",
+"from llama_stack_client import LlamaStackClient, Agent\n",
 "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
-"from llama_stack_client.lib.agents.agent import Agent\n",
 "from rich.pretty import pprint\n",
 "import json\n",
 "import uuid\n",

@@ -34,10 +34,8 @@
 }
 ],
 "source": [
-"from llama_stack_client import LlamaStackClient\n",
+"from llama_stack_client import LlamaStackClient, Agent\n",
 "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
-"from llama_stack_client.lib.agents.agent import Agent\n",
 "from rich.pretty import pprint\n",
 "import json\n",
 "import uuid\n",
@@ -14,7 +14,7 @@ Agents are configured using the `AgentConfig` class, which includes:
 - **Safety Shields**: Guardrails to ensure responsible AI behavior
 
 ```python
-from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client import Agent
 
 
 # Create the agent
@@ -44,14 +44,14 @@ Each interaction with an agent is called a "turn" and consists of:
 - **Output Message**: The agent's response
 
 ```python
-from llama_stack_client.lib.agents.event_logger import EventLogger
+from llama_stack_client import AgentEventLogger
 
 # Create a turn with streaming response
 turn_response = agent.create_turn(
     session_id=session_id,
     messages=[{"role": "user", "content": "Tell me about Llama models"}],
 )
-for log in EventLogger().log(turn_response):
+for log in AgentEventLogger().log(turn_response):
     log.print()
 ```
 ### Non-Streaming
@@ -67,9 +67,7 @@ sequenceDiagram
 Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
 
 ```python
-from llama_stack_client import LlamaStackClient
-from llama_stack_client.lib.agents.agent import Agent
-from llama_stack_client.lib.agents.event_logger import EventLogger
+from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
 from rich.pretty import pprint
 
 # Replace host and port
@@ -113,7 +111,7 @@ response = agent.create_turn(
 )
 
 # Monitor each step of execution
-for log in EventLogger().log(response):
+for log in AgentEventLogger().log(response):
     log.print()
 
 # Using non-streaming API, the response contains input, steps, and output.
@@ -23,9 +23,7 @@ In this example, we will show you how to:
 
 ##### Building a Search Agent
 ```python
-from llama_stack_client import LlamaStackClient
-from llama_stack_client.lib.agents.agent import Agent
-from llama_stack_client.lib.agents.event_logger import EventLogger
+from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
 
 client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
 
@@ -54,7 +52,7 @@ for prompt in user_prompts:
         session_id=session_id,
     )
 
-    for log in EventLogger().log(response):
+    for log in AgentEventLogger().log(response):
         log.print()
 ```
 
@@ -55,11 +55,11 @@ chunks_response = client.vector_io.query(
 A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces.
 
 ```python
-from llama_stack_client.types import Document
+from llama_stack_client import RAGDocument
 
 urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"]
 documents = [
-    Document(
+    RAGDocument(
         document_id=f"num-{i}",
         content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
         mime_type="text/plain",
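Once the `RAGDocument` list is built, ingestion is a single call. A brief sketch follows; the vector db id is a placeholder and would need to be registered separately:

```python
# Sketch: chunk and ingest the documents built above into a vector DB.
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id="my_documents",  # hypothetical, registered elsewhere
    chunk_size_in_tokens=512,
)
```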
@@ -86,7 +86,7 @@ results = client.tool_runtime.rag_tool.query(
 One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
 
 ```python
-from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client import Agent
 
 # Create agent with memory
 agent = Agent(
@@ -140,9 +140,9 @@ response = agent.create_turn(
 
 You can print the response with the logger shown below.
 ```python
-from llama_stack_client.lib.agents.event_logger import EventLogger
+from llama_stack_client import AgentEventLogger
 
-for log in EventLogger().log(response):
+for log in AgentEventLogger().log(response):
     log.print()
 ```
 
@@ -57,7 +57,7 @@ The `otel` sink works with any service compatible with the OpenTelemetry collector.
 Start a Jaeger instance with the OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686 using the following command:
 
 ```bash
-$ docker run --rm --name jaeger \
+$ docker run --pull always --rm --name jaeger \
   -p 16686:16686 -p 4318:4318 \
   jaegertracing/jaeger:2.1.0
 ```
@@ -110,10 +110,18 @@ MCP tools are special tools that can interact with llama stack over model context protocol.
 
 Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers.
 
+```shell
+# start your MCP server
+mkdir /tmp/content
+touch /tmp/content/foo
+touch /tmp/content/bar
+npx -y supergateway --port 8000 --stdio 'npx -y @modelcontextprotocol/server-filesystem /tmp/content'
+```
+
+Then register the MCP server as a tool group,
 ```python
-# Register MCP tools
 client.toolgroups.register(
-    toolgroup_id="builtin::filesystem",
+    toolgroup_id="mcp::filesystem",
     provider_id="model-context-protocol",
     mcp_endpoint=URL(uri="http://localhost:8000/sse"),
 )
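With the tool group registered, an agent can pick the MCP tools up by name. A short sketch (the model id and instructions are placeholders):

```python
# Sketch: hand the registered MCP tool group to an agent.
agent = Agent(
    client,
    model="meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
    instructions="You can list and read files using the filesystem tools.",
    tools=["mcp::filesystem"],
)
```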
@@ -181,7 +189,7 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
 ## Simple Example: Using an Agent with the Code-Interpreter Tool
 
 ```python
-from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client import Agent
 
 # Instantiate the AI agent with the given configuration
 agent = Agent(
@@ -55,7 +55,7 @@ llama stack run llama_stack/templates/open-benchmark/run.yaml
 There are 3 necessary inputs to run a benchmark eval
 - `list of benchmark_ids`: The list of benchmark ids to run evaluation on
 - `model-id`: The model id to evaluate on
-- `utput_dir`: Path to store the evaluate results
+- `output_dir`: Path to store the evaluation results
 ```
 llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
 --model_id <model id to evaluate on> \
@@ -69,7 +69,7 @@ llama-stack-client eval run-benchmark help
 to see the description of all the flags that eval run-benchmark has
 
 
-In the output log, you can find the file path that has your evaluation results. Open that file and you can see you aggrgate
+In the output log, you can find the file path that has your evaluation results. Open that file and you can see your aggregate
 evaluation results over there.
 
 
@@ -56,9 +56,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-nvidia \
@@ -72,7 +73,7 @@ docker run \
 ```bash
 llama stack build --template nvidia --image-type conda
 llama stack run ./run.yaml \
-  --port 5001 \
+  --port 8321 \
   --env NVIDIA_API_KEY=$NVIDIA_API_KEY
   --env INFERENCE_MODEL=$INFERENCE_MODEL
 ```
@@ -26,7 +26,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following providers:
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 
 ### Models
 
@@ -51,9 +51,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   llamastack/distribution-bedrock \
   --port $LLAMA_STACK_PORT \
@@ -18,7 +18,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following providers:
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
 
 ### Models
@@ -43,9 +43,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-cerebras \
@@ -59,6 +60,6 @@ docker run \
 ```bash
 llama stack build --template cerebras --image-type conda
 llama stack run ./run.yaml \
-  --port 5001 \
+  --port 8321 \
   --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
 ```
@@ -53,7 +53,7 @@ docker compose down
 
 #### Start Dell-TGI server locally
 ```
-docker run -it --shm-size 1g -p 80:80 --gpus 4 \
+docker run -it --pull always --shm-size 1g -p 80:80 --gpus 4 \
 -e NUM_SHARD=4
 -e MAX_BATCH_PREFILL_TOKENS=32768 \
 -e MAX_INPUT_TOKENS=8000 \
@@ -65,7 +65,7 @@ registry.dell.huggingface.co/enterprise-dell-inference-meta-llama-meta-llama-3.1
 #### Start Llama Stack server pointing to TGI server
 
 ```
-docker run --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml
+docker run --pull always --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml
 ```
 
 Make sure that in your `run.yaml` file, your inference provider is pointing to the correct TGI server endpoint. E.g.
@@ -55,6 +55,7 @@ export CUDA_VISIBLE_DEVICES=0
 export LLAMA_STACK_PORT=8321
 
 docker run --rm -it \
+  --pull always \
   --network host \
   -v $HOME/.cache/huggingface:/data \
   -e HF_TOKEN=$HF_TOKEN \
@@ -78,6 +79,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 export CUDA_VISIBLE_DEVICES=1
 
 docker run --rm -it \
+  --pull always \
   --network host \
   -v $HOME/.cache/huggingface:/data \
   -e HF_TOKEN=$HF_TOKEN \
@@ -120,6 +122,7 @@ This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
 docker run -it \
+  --pull always \
   --network host \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v $HOME/.llama:/root/.llama \
@@ -147,6 +150,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v $HOME/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
@@ -28,7 +28,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following providers:
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``)
 
 ### Models
@@ -61,9 +61,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   llamastack/distribution-fireworks \
   --port $LLAMA_STACK_PORT \
@@ -28,7 +28,7 @@ The `llamastack/distribution-groq` distribution consists of the following providers:
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `GROQ_API_KEY`: Groq API Key (default: ``)
 
 ### Models
@@ -56,9 +56,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   llamastack/distribution-groq \
   --port $LLAMA_STACK_PORT \
@@ -30,7 +30,7 @@ Note that you need access to nvidia GPUs to run this distribution.
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
 - `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
@@ -75,9 +75,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-meta-reference-gpu \
@@ -90,6 +91,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
 ```bash
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-meta-reference-gpu \
@@ -105,7 +107,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
 ```bash
 llama stack build --template meta-reference-gpu --image-type conda
 llama stack run distributions/meta-reference-gpu/run.yaml \
-  --port 5001 \
+  --port 8321 \
   --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 ```
 
@@ -113,7 +115,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
 
 ```bash
 llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \
-  --port 5001 \
+  --port 8321 \
   --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
   --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 ```
@@ -32,7 +32,7 @@ Note that you need access to nvidia GPUs to run this distribution.
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
 
@@ -75,9 +75,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-meta-reference-quantized-gpu \
@@ -90,6 +91,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
 ```bash
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-meta-reference-quantized-gpu \
@@ -15,7 +15,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following providers:
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
 
 ### Models
@@ -39,9 +39,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-nvidia \
@@ -55,6 +56,6 @@ docker run \
 ```bash
 llama stack build --template nvidia --image-type conda
 llama stack run ./run.yaml \
-  --port 5001 \
+  --port 8321 \
   --env NVIDIA_API_KEY=$NVIDIA_API_KEY
 ```
@@ -30,7 +30,7 @@ You should use this distribution if you have a regular desktop machine
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
 - `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
@@ -69,9 +69,10 @@ Now you are ready to run Llama Stack with Ollama as the inference provider.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-export LLAMA_STACK_PORT=5001
+export LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-ollama \
@@ -89,6 +90,7 @@ cd /path/to/llama-stack
 
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
@@ -105,7 +107,7 @@ docker run \
 Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
 
 ```bash
-export LLAMA_STACK_PORT=5001
+export LLAMA_STACK_PORT=8321
 
 llama stack build --template ollama --image-type conda
 llama stack run ./run.yaml \
@@ -28,7 +28,7 @@ The `llamastack/distribution-passthrough` distribution consists of the following providers:
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `PASSTHROUGH_API_KEY`: Passthrough API Key (default: ``)
 - `PASSTHROUGH_URL`: Passthrough URL (default: ``)
 
@@ -29,7 +29,7 @@ You can use this distribution if you have GPUs and want to run an independent vLLM server
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`)
 - `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`)
@@ -47,6 +47,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export CUDA_VISIBLE_DEVICES=0
 
 docker run \
+  --pull always \
   --runtime nvidia \
   --gpus $CUDA_VISIBLE_DEVICES \
   -v ~/.cache/huggingface:/root/.cache/huggingface \
@@ -59,6 +60,8 @@ docker run \
   --port $INFERENCE_PORT
 ```
 
+Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
+
 If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
 
 ```bash
@@ -67,6 +70,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 export CUDA_VISIBLE_DEVICES=1
 
 docker run \
+  --pull always \
   --runtime nvidia \
   --gpus $CUDA_VISIBLE_DEVICES \
   -v ~/.cache/huggingface:/root/.cache/huggingface \
@@ -90,10 +94,11 @@ This method allows you to get started quickly without having to build the distribution code.
 ```bash
 export INFERENCE_PORT=8000
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
-export LLAMA_STACK_PORT=5001
+export LLAMA_STACK_PORT=8321
 
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-remote-vllm \
@@ -115,6 +120,7 @@ cd /path/to/llama-stack
 
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
@@ -135,7 +141,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
 ```bash
 export INFERENCE_PORT=8000
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
-export LLAMA_STACK_PORT=5001
+export LLAMA_STACK_PORT=8321
 
 cd distributions/remote-vllm
 llama stack build --template remote-vllm --image-type conda
@@ -27,7 +27,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following providers:
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``)
 
 ### Models
@@ -59,9 +59,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   llamastack/distribution-sambanova \
   --port $LLAMA_STACK_PORT \
@@ -31,7 +31,7 @@ You can use this distribution if you have GPUs and want to run an independent TGI server
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`)
 - `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
@@ -48,6 +48,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export CUDA_VISIBLE_DEVICES=0
 
 docker run --rm -it \
+  --pull always \
   -v $HOME/.cache/huggingface:/data \
   -p $INFERENCE_PORT:$INFERENCE_PORT \
   --gpus $CUDA_VISIBLE_DEVICES \
@@ -68,6 +69,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 export CUDA_VISIBLE_DEVICES=1
 
 docker run --rm -it \
+  --pull always \
   -v $HOME/.cache/huggingface:/data \
   -p $SAFETY_PORT:$SAFETY_PORT \
   --gpus $CUDA_VISIBLE_DEVICES \
@@ -88,9 +90,10 @@ Now you are ready to run Llama Stack with TGI as the inference provider.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   llamastack/distribution-tgi \
   --port $LLAMA_STACK_PORT \
@@ -107,6 +110,7 @@ cd /path/to/llama-stack
 
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
@@ -28,7 +28,7 @@ The `llamastack/distribution-together` distribution consists of the following providers:
 
 The following environment variables can be configured:
 
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `TOGETHER_API_KEY`: Together.AI API Key (default: ``)
 
 ### Models
@@ -62,9 +62,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
 This method allows you to get started quickly without having to build the distribution code.
 
 ```bash
-LLAMA_STACK_PORT=5001
+LLAMA_STACK_PORT=8321
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   llamastack/distribution-together \
   --port $LLAMA_STACK_PORT \
@ -54,6 +54,7 @@ mkdir -p ~/.llama
|
||||||
Then you can start the server using the container tool of your choice. For example, if you are running Docker you can use the following command:
|
Then you can start the server using the container tool of your choice. For example, if you are running Docker you can use the following command:
|
||||||
```bash
|
```bash
|
||||||
docker run -it \
|
docker run -it \
|
||||||
|
--pull always \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
llamastack/distribution-ollama \
|
llamastack/distribution-ollama \
|
||||||
|
@ -74,6 +75,7 @@ Docker containers run in their own isolated network namespaces on Linux. To allo
|
||||||
Linux users having issues running the above command should instead try the following:
|
Linux users having issues running the above command should instead try the following:
|
||||||
```bash
|
```bash
|
||||||
docker run -it \
|
docker run -it \
|
||||||
|
--pull always \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
--network=host \
|
--network=host \
|
||||||
|
@ -197,9 +199,7 @@ import os
|
||||||
import uuid
|
import uuid
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack_client.lib.agents.agent import Agent
|
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
|
||||||
from llama_stack_client.types import Document
|
|
||||||
|
|
||||||
|
|
||||||
def create_http_client():
|
def create_http_client():
|
||||||
|
@ -225,7 +225,7 @@ client = (
|
||||||
# Documents to be used for RAG
|
# Documents to be used for RAG
|
||||||
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
|
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
|
||||||
documents = [
|
documents = [
|
||||||
Document(
|
RAGDocument(
|
||||||
document_id=f"num-{i}",
|
document_id=f"num-{i}",
|
||||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||||
mime_type="text/plain",
|
mime_type="text/plain",
|
||||||
|
@ -284,7 +284,7 @@ for prompt in user_prompts:
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
for log in EventLogger().log(response):
|
for log in AgentEventLogger().log(response):
|
||||||
log.print()
|
log.print()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
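The three hunks above form one migration: `Document` becomes `RAGDocument`, `EventLogger` becomes `AgentEventLogger`, and both move to the `llama_stack_client` top level. A minimal sketch of the updated pattern, assuming a server on `localhost:8321` and a registered model; the session name and prompt are illustrative:

```python
from llama_stack_client import Agent, AgentEventLogger, LlamaStackClient, RAGDocument

client = LlamaStackClient(base_url="http://localhost:8321")

# RAGDocument replaces llama_stack_client.types.Document
doc = RAGDocument(
    document_id="num-0",
    content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
    mime_type="text/plain",
    metadata={},
)

agent = Agent(client, model="meta-llama/Llama-3.2-3B-Instruct", instructions="You are a helpful assistant.")
session_id = agent.create_session("rag-demo")
response = agent.create_turn(
    messages=[{"role": "user", "content": "Summarize the chat tutorial."}],
    session_id=session_id,
)
# AgentEventLogger replaces llama_stack_client.lib.agents.event_logger.EventLogger
for log in AgentEventLogger().log(response):
    log.print()
```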
|
@ -15,8 +15,6 @@ Llama Stack defines and standardizes the core building blocks needed to bring ge
|
||||||
- **Multiple developer interfaces** like CLI and SDKs for Python, Node, iOS, and Android
|
- **Multiple developer interfaces** like CLI and SDKs for Python, Node, iOS, and Android
|
||||||
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack
|
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack
|
||||||
|
|
||||||
We focus on making it easy to build production applications with the Llama model family - from the latest Llama 3.3 to specialized models like Llama Guard for safety.
|
|
||||||
|
|
||||||
```{image} ../_static/llama-stack.png
|
```{image} ../_static/llama-stack.png
|
||||||
:alt: Llama Stack
|
:alt: Llama Stack
|
||||||
:width: 400px
|
:width: 400px
|
||||||
|
|
|
@ -48,7 +48,7 @@ Llama Stack addresses these challenges through a service-oriented, API-first app
|
||||||
|
|
||||||
**Robust Ecosystem**
|
**Robust Ecosystem**
|
||||||
- Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies).
|
- Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies).
|
||||||
- Ecosystem offers tailored infrastructure, software, and services for deploying Llama models.
|
- Ecosystem offers tailored infrastructure, software, and services for deploying a variety of models.
|
||||||
|
|
||||||
|
|
||||||
### Our Philosophy
|
### Our Philosophy
|
||||||
|
@ -57,7 +57,6 @@ Llama Stack addresses these challenges through a service-oriented, API-first app
|
||||||
- **Composability**: Every component is independent but works together seamlessly
|
- **Composability**: Every component is independent but works together seamlessly
|
||||||
- **Production Ready**: Built for real-world applications, not just demos
|
- **Production Ready**: Built for real-world applications, not just demos
|
||||||
- **Turnkey Solutions**: Easy to deploy built in solutions for popular deployment scenarios
|
- **Turnkey Solutions**: Easy to deploy built in solutions for popular deployment scenarios
|
||||||
- **Llama First**: Explicit focus on Meta's Llama models and partnering ecosystem
|
|
||||||
|
|
||||||
|
|
||||||
With Llama Stack, you can focus on building your application while we handle the infrastructure complexity, essential capabilities, and provider integrations.
|
With Llama Stack, you can focus on building your application while we handle the infrastructure complexity, essential capabilities, and provider integrations.
|
||||||
|
|
|
@ -118,6 +118,7 @@ Playground can also be started in a docker image:
|
||||||
export LLAMA_STACK_URL=http://localhost:11434
|
export LLAMA_STACK_URL=http://localhost:11434
|
||||||
|
|
||||||
docker run \
|
docker run \
|
||||||
|
--pull always \
|
||||||
-p 8501:8501 \
|
-p 8501:8501 \
|
||||||
-e LLAMA_STACK_ENDPOINT=$LLAMA_STACK_URL \
|
-e LLAMA_STACK_ENDPOINT=$LLAMA_STACK_URL \
|
||||||
quay.io/jland/llama-stack-playground
|
quay.io/jland/llama-stack-playground
|
||||||
|
|
|
@ -48,7 +48,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"HOST = \"localhost\" # Replace with your host\n",
|
"HOST = \"localhost\" # Replace with your host\n",
|
||||||
"PORT = 5001 # Replace with your port\n",
|
"PORT = 8321 # Replace with your port\n",
|
||||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -369,6 +369,9 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"fileHeader": "",
|
||||||
|
"fileUid": "7da25939-a2a3-463c-958e-9cdfd710d158",
|
||||||
|
"isAdHoc": false,
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "Python 3 (ipykernel)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
@ -386,7 +389,5 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"#### 2. Set Up Local and Cloud Clients\n",
|
"#### 2. Set Up Local and Cloud Clients\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:5001`.\n"
|
"Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:8322`.\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -236,6 +236,9 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"fileHeader": "",
|
||||||
|
"fileUid": "e11939ac-dfbc-4a1c-83be-e494c7f803b8",
|
||||||
|
"isAdHoc": false,
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "Python 3 (ipykernel)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
@ -253,7 +256,5 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"HOST = \"localhost\" # Replace with your host\n",
|
"HOST = \"localhost\" # Replace with your host\n",
|
||||||
"PORT = 5001 # Replace with your port\n",
|
"PORT = 8321 # Replace with your port\n",
|
||||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -281,6 +281,9 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"fileHeader": "",
|
||||||
|
"fileUid": "b1b93b6e-22a2-4c24-8cb0-161fdafff29a",
|
||||||
|
"isAdHoc": false,
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "base",
|
"display_name": "base",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
@ -298,7 +301,5 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.12.2"
|
"version": "3.12.2"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"HOST = \"localhost\" # Replace with your host\n",
|
"HOST = \"localhost\" # Replace with your host\n",
|
||||||
"CLOUD_PORT = 5001 # Replace with your cloud distro port\n",
|
"CLOUD_PORT = 8321 # Replace with your cloud distro port\n",
|
||||||
"MODEL_NAME='Llama3.2-11B-Vision-Instruct'"
|
"MODEL_NAME='Llama3.2-11B-Vision-Instruct'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -180,6 +180,9 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"fileHeader": "",
|
||||||
|
"fileUid": "37bbbfda-8e42-446c-89c7-59dd49e2d339",
|
||||||
|
"isAdHoc": false,
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "base",
|
"display_name": "base",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
@ -197,7 +200,5 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.12.2"
|
"version": "3.12.2"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,7 @@
|
||||||
"nest_asyncio.apply()\n",
|
"nest_asyncio.apply()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"HOST = \"localhost\"\n",
|
"HOST = \"localhost\"\n",
|
||||||
"PORT = 5001\n",
|
"PORT = 8321\n",
|
||||||
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -296,7 +296,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
" # Create an agent instance with the client and configuration\n",
|
" # Create an agent instance with the client and configuration\n",
|
||||||
" agent = Agent(\n",
|
" agent = Agent(\n",
|
||||||
" client, \n",
|
" client,\n",
|
||||||
" model=MODEL_NAME,\n",
|
" model=MODEL_NAME,\n",
|
||||||
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
|
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
|
||||||
" sampling_params={\n",
|
" sampling_params={\n",
|
||||||
|
@ -335,6 +335,9 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"fileHeader": "",
|
||||||
|
"fileUid": "f0abbf6d-ed52-40ad-afb4-f5ec99130249",
|
||||||
|
"isAdHoc": false,
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "Python 3 (ipykernel)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
@ -352,7 +355,5 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"HOST = \"localhost\" # Replace with your host\n",
|
"HOST = \"localhost\" # Replace with your host\n",
|
||||||
"PORT = 5001 # Replace with your port\n",
|
"PORT = 8321 # Replace with your port\n",
|
||||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'\n",
|
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'\n",
|
||||||
"MEMORY_BANK_ID=\"tutorial_bank\""
|
"MEMORY_BANK_ID=\"tutorial_bank\""
|
||||||
]
|
]
|
||||||
|
@ -378,6 +378,9 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"fileHeader": "",
|
||||||
|
"fileUid": "73bc3357-0e5e-42ff-95b1-40b916d24c4f",
|
||||||
|
"isAdHoc": false,
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "Python 3 (ipykernel)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
@ -395,7 +398,5 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 4
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"HOST = \"localhost\" # Replace with your host\n",
|
"HOST = \"localhost\" # Replace with your host\n",
|
||||||
"PORT = 5001 # Replace with your port\n",
|
"PORT = 8321 # Replace with your port\n",
|
||||||
"SHEILD_NAME=\"meta-llama/Llama-Guard-3-1B\""
|
"SHEILD_NAME=\"meta-llama/Llama-Guard-3-1B\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -112,6 +112,9 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"fileHeader": "",
|
||||||
|
"fileUid": "9afaddb7-c2fb-4309-8fa0-761697de53f0",
|
||||||
|
"isAdHoc": false,
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "Python 3 (ipykernel)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
@ -129,7 +132,5 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.10"
|
"version": "3.11.10"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 4
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,7 +50,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"HOST = \"localhost\" # Replace with your host\n",
|
"HOST = \"localhost\" # Replace with your host\n",
|
||||||
"PORT = 5001 # Replace with your port\n",
|
"PORT = 8321 # Replace with your port\n",
|
||||||
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -115,7 +115,7 @@
|
||||||
"async def agent_example():\n",
|
"async def agent_example():\n",
|
||||||
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
|
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
|
||||||
" agent = Agent(\n",
|
" agent = Agent(\n",
|
||||||
" client, \n",
|
" client,\n",
|
||||||
" model=MODEL_NAME,\n",
|
" model=MODEL_NAME,\n",
|
||||||
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
|
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
|
||||||
" sampling_params={\n",
|
" sampling_params={\n",
|
||||||
|
@ -168,6 +168,9 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"fileHeader": "",
|
||||||
|
"fileUid": "8de24775-c4a0-49c7-904e-608264f69292",
|
||||||
|
"isAdHoc": false,
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "Python 3 (ipykernel)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
@ -185,7 +188,5 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 4
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,7 +96,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
||||||
3. **Set the ENV variables by exporting them to the terminal**:
|
3. **Set the ENV variables by exporting them to the terminal**:
|
||||||
```bash
|
```bash
|
||||||
export OLLAMA_URL="http://localhost:11434"
|
export OLLAMA_URL="http://localhost:11434"
|
||||||
export LLAMA_STACK_PORT=5001
|
export LLAMA_STACK_PORT=8321
|
||||||
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
|
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
|
||||||
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
|
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
|
||||||
```
|
```
|
||||||
|
@ -112,7 +112,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
||||||
```
|
```
|
||||||
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
|
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
|
||||||
|
|
||||||
The server will start and listen on `http://localhost:5001`.
|
The server will start and listen on `http://localhost:8321`.
|
||||||
|
|
||||||
---
|
---
|
||||||
## Test with `llama-stack-client` CLI
|
## Test with `llama-stack-client` CLI
|
||||||
|
@ -120,11 +120,11 @@ After setting up the server, open a new terminal window and configure the llama-
|
||||||
|
|
||||||
1. Configure the CLI to point to the llama-stack server.
|
1. Configure the CLI to point to the llama-stack server.
|
||||||
```bash
|
```bash
|
||||||
llama-stack-client configure --endpoint http://localhost:5001
|
llama-stack-client configure --endpoint http://localhost:8321
|
||||||
```
|
```
|
||||||
**Expected Output:**
|
**Expected Output:**
|
||||||
```bash
|
```bash
|
||||||
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5001
|
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
|
||||||
```
|
```
|
||||||
2. Test the CLI by running inference:
|
2. Test the CLI by running inference:
|
||||||
```bash
|
```bash
|
||||||
|
@ -218,7 +218,7 @@ if INFERENCE_MODEL is None:
|
||||||
raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.")
|
raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.")
|
||||||
|
|
||||||
# Initialize the clien
|
# Initialize the clien
|
||||||
client = LlamaStackClient(base_url="http://localhost:5001")
|
client = LlamaStackClient(base_url="http://localhost:8321")
|
||||||
|
|
||||||
# Create a chat completion reques
|
# Create a chat completion reques
|
||||||
response = client.inference.chat_completion(
|
response = client.inference.chat_completion(
|
||||||
|
|
|
@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.safety import SafetyViolation
|
from llama_stack.apis.safety import SafetyViolation
|
||||||
from llama_stack.apis.tools import ToolDef
|
from llama_stack.apis.tools import ToolDef
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
|
||||||
|
@ -189,13 +188,11 @@ class AgentToolGroupWithArgs(BaseModel):
|
||||||
args: Dict[str, Any]
|
args: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
AgentToolGroup = register_schema(
|
AgentToolGroup = Union[
|
||||||
Union[
|
|
||||||
str,
|
str,
|
||||||
AgentToolGroupWithArgs,
|
AgentToolGroupWithArgs,
|
||||||
],
|
]
|
||||||
name="AgentTool",
|
register_schema(AgentToolGroup, name="AgentTool")
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigCommon(BaseModel):
|
class AgentConfigCommon(BaseModel):
|
||||||
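This is the first of many hunks in this commit applying the same mechanical refactor: `register_schema` no longer wraps the union and returns it; the alias is defined as a plain type first and registered as a side effect. A hedged sketch of the new convention with placeholder models (`Cat`/`Dog` are illustrative, not from the codebase):

```python
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class Cat(BaseModel):
    type: Literal["cat"] = "cat"


class Dog(BaseModel):
    type: Literal["dog"] = "dog"


# New convention: the alias is an ordinary type that static checkers understand...
Animal = Annotated[Union[Cat, Dog], Field(discriminator="type")]
# ...and schema registration is a separate call, as done throughout this commit:
# register_schema(Animal, name="Animal")
```

Keeping the alias a real `Annotated` union, rather than whatever `register_schema` returned, lets type checkers and pydantic's discriminated-union validation see through it.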
@@ -312,8 +309,7 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
     turn: Turn


-AgentTurnResponseEventPayload = register_schema(
-    Annotated[
+AgentTurnResponseEventPayload = Annotated[
     Union[
         AgentTurnResponseStepStartPayload,
         AgentTurnResponseStepProgressPayload,
@@ -323,9 +319,8 @@ AgentTurnResponseEventPayload = register_schema(
         AgentTurnResponseTurnAwaitingInputPayload,
     ],
     Field(discriminator="event_type"),
-    ],
-    name="AgentTurnResponseEventPayload",
-)
+]
+register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")


 @json_schema_type
@@ -387,7 +382,6 @@ class AgentStepResponse(BaseModel):


 @runtime_checkable
-@trace_protocol
 class Agents(Protocol):
     """Agents API for creating and interacting with agentic systems.

@@ -399,7 +393,7 @@ class Agents(Protocol):
     - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
     """

-    @webmethod(route="/agents", method="POST")
+    @webmethod(route="/agents", method="POST", descriptive_name="create_agent")
     async def create_agent(
         self,
         agent_config: AgentConfig,
@@ -411,7 +405,9 @@ class Agents(Protocol):
         """
         ...

-    @webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
+    @webmethod(
+        route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
+    )
     async def create_agent_turn(
         self,
         agent_id: str,
@@ -443,6 +439,7 @@ class Agents(Protocol):
     @webmethod(
         route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
         method="POST",
+        descriptive_name="resume_agent_turn",
     )
     async def resume_agent_turn(
         self,
@@ -505,7 +502,7 @@ class Agents(Protocol):
         """
         ...

-    @webmethod(route="/agents/{agent_id}/session", method="POST")
+    @webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
     async def create_agent_session(
         self,
         agent_id: str,
@@ -63,19 +63,15 @@ class TextContentItem(BaseModel):


 # other modalities can be added here
-InterleavedContentItem = register_schema(
-    Annotated[
+InterleavedContentItem = Annotated[
     Union[ImageContentItem, TextContentItem],
     Field(discriminator="type"),
-    ],
-    name="InterleavedContentItem",
-)
+]
+register_schema(InterleavedContentItem, name="InterleavedContentItem")

 # accept a single "str" as a special case since it is common
-InterleavedContent = register_schema(
-    Union[str, InterleavedContentItem, List[InterleavedContentItem]],
-    name="InterleavedContent",
-)
+InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
+register_schema(InterleavedContent, name="InterleavedContent")


 @json_schema_type
@@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):


 # streaming completions send a stream of ContentDeltas
-ContentDelta = register_schema(
-    Annotated[
+ContentDelta = Annotated[
     Union[TextDelta, ImageDelta, ToolCallDelta],
     Field(discriminator="type"),
-    ],
-    name="ContentDelta",
-)
+]
+register_schema(ContentDelta, name="ContentDelta")

@@ -72,8 +72,7 @@ class DialogType(BaseModel):
     type: Literal["dialog"] = "dialog"


-ParamType = register_schema(
-    Annotated[
+ParamType = Annotated[
     Union[
         StringType,
         NumberType,
@@ -87,9 +86,8 @@ ParamType = register_schema(
         AgentTurnInputType,
     ],
     Field(discriminator="type"),
-    ],
-    name="ParamType",
-)
+]
+register_schema(ParamType, name="ParamType")

 """
 # TODO: recursive definition of ParamType in these containers
@@ -84,13 +84,11 @@ class RowsDataSource(BaseModel):
     rows: List[Dict[str, Any]]


-DataSource = register_schema(
-    Annotated[
+DataSource = Annotated[
     Union[URIDataSource, RowsDataSource],
     Field(discriminator="type"),
-    ],
-    name="DataSource",
-)
+]
+register_schema(DataSource, name="DataSource")


 class CommonDatasetFields(BaseModel):
@@ -121,8 +119,6 @@ class Dataset(CommonDatasetFields, Resource):

 class DatasetInput(CommonDatasetFields, BaseModel):
     dataset_id: str
-    provider_id: Optional[str] = None
-    provider_dataset_id: Optional[str] = None


 class ListDatasetsResponse(BaseModel):

144 llama_stack/apis/eval/eval.py Normal file
@@ -0,0 +1,144 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+from llama_stack.apis.agents import AgentConfig
+from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.inference import SamplingParams, SystemMessage
+from llama_stack.apis.scoring import ScoringResult
+from llama_stack.apis.scoring_functions import ScoringFnParams
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+
+
+@json_schema_type
+class ModelCandidate(BaseModel):
+    """A model candidate for evaluation.
+
+    :param model: The model ID to evaluate.
+    :param sampling_params: The sampling parameters for the model.
+    :param system_message: (Optional) The system message providing instructions or context to the model.
+    """
+
+    type: Literal["model"] = "model"
+    model: str
+    sampling_params: SamplingParams
+    system_message: Optional[SystemMessage] = None
+
+
+@json_schema_type
+class AgentCandidate(BaseModel):
+    """An agent candidate for evaluation.
+
+    :param config: The configuration for the agent candidate.
+    """
+
+    type: Literal["agent"] = "agent"
+    config: AgentConfig
+
+
+EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
+register_schema(EvalCandidate, name="EvalCandidate")
+
+
+@json_schema_type
+class BenchmarkConfig(BaseModel):
+    """A benchmark configuration for evaluation.
+
+    :param eval_candidate: The candidate to evaluate.
+    :param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
+    :param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
+    """
+
+    eval_candidate: EvalCandidate
+    scoring_params: Dict[str, ScoringFnParams] = Field(
+        description="Map between scoring function id and parameters for each scoring function you want to run",
+        default_factory=dict,
+    )
+    num_examples: Optional[int] = Field(
+        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
+        default=None,
+    )
+    # we could optionally add any specific dataset config here
+
+
+@json_schema_type
+class EvaluateResponse(BaseModel):
+    """The response from an evaluation.
+
+    :param generations: The generations from the evaluation.
+    :param scores: The scores from the evaluation.
+    """
+
+    generations: List[Dict[str, Any]]
+    # each key in the dict is a scoring function name
+    scores: Dict[str, ScoringResult]
+
+
+class Eval(Protocol):
+    """Llama Stack Evaluation API for running evaluations on model and agent candidates."""
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
+    async def run_eval(
+        self,
+        benchmark_id: str,
+        benchmark_config: BenchmarkConfig,
+    ) -> Job:
+        """Run an evaluation on a benchmark.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param benchmark_config: The configuration for the benchmark.
+        :return: The job that was created to run the evaluation.
+        """
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
+    async def evaluate_rows(
+        self,
+        benchmark_id: str,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        benchmark_config: BenchmarkConfig,
+    ) -> EvaluateResponse:
+        """Evaluate a list of rows on a benchmark.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param input_rows: The rows to evaluate.
+        :param scoring_functions: The scoring functions to use for the evaluation.
+        :param benchmark_config: The configuration for the benchmark.
+        :return: EvaluateResponse object containing generations and scores
+        """
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
+        """Get the status of a job.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param job_id: The ID of the job to get the status of.
+        :return: The status of the evaluation job.
+        """
+        ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
+    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
+        """Cancel a job.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param job_id: The ID of the job to cancel.
+        """
+        ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
+    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
+        """Get the result of a job.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param job_id: The ID of the job to get the result of.
+        :return: The result of the job.
+        """
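The new `Eval` protocol pairs an `EvalCandidate` with per-scoring-function parameters and returns a `Job` to poll. A hedged sketch of driving it from the Python client; the `client.eval` accessor, benchmark id, and field spellings are assumptions based on the API surface above, not confirmed by this commit:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Assumes "my-benchmark" was registered beforehand; the candidate mirrors
# ModelCandidate from eval.py (discriminated by type="model").
job = client.eval.run_eval(
    benchmark_id="my-benchmark",
    benchmark_config={
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.2-3B-Instruct",
            "sampling_params": {},
        },
        "scoring_params": {},
    },
)
status = client.eval.job_status(benchmark_id="my-benchmark", job_id=job.job_id)
result = client.eval.job_result(benchmark_id="my-benchmark", job_id=job.job_id)
print(result.scores)  # one ScoringResult per scoring function name
```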
@@ -144,8 +144,7 @@ class CompletionMessage(BaseModel):
     tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)


-Message = register_schema(
-    Annotated[
+Message = Annotated[
     Union[
         UserMessage,
         SystemMessage,
@@ -153,9 +152,8 @@ Message = register_schema(
         CompletionMessage,
     ],
     Field(discriminator="role"),
-    ],
-    name="Message",
-)
+]
+register_schema(Message, name="Message")


 @json_schema_type
@@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
     bnf: Dict[str, Any]


-ResponseFormat = register_schema(
-    Annotated[
+ResponseFormat = Annotated[
     Union[JsonSchemaResponseFormat, GrammarResponseFormat],
     Field(discriminator="type"),
-    ],
-    name="ResponseFormat",
-)
+]
+register_schema(ResponseFormat, name="ResponseFormat")


 # This is an internally used class
@@ -24,17 +24,6 @@ class HealthInfo(BaseModel):
     # TODO: add a provider level status


-@json_schema_type
-class ProviderInfo(BaseModel):
-    api: str
-    provider_id: str
-    provider_type: str
-
-
-class ListProvidersResponse(BaseModel):
-    data: List[ProviderInfo]
-
-
 @json_schema_type
 class VersionInfo(BaseModel):
     version: str
@@ -46,9 +35,6 @@ class ListRoutesResponse(BaseModel):

 @runtime_checkable
 class Inspect(Protocol):
-    @webmethod(route="/inspect/providers", method="GET")
-    async def list_providers(self) -> ListProvidersResponse: ...
-
     @webmethod(route="/inspect/routes", method="GET")
     async def list_routes(self) -> ListRoutesResponse: ...

@@ -6,7 +6,7 @@

 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol
+from typing import Any, Dict, List, Literal, Optional, Protocol, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
@@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
     group_size: int


-AlgorithmConfig = register_schema(
-    Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
-    name="AlgorithmConfig",
-)
+AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
+register_schema(AlgorithmConfig, name="AlgorithmConfig")


 @json_schema_type
@@ -184,7 +182,7 @@ class PostTraining(Protocol):
             description="Model descriptor from `llama model list`",
         ),
         checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
+        algorithm_config: Optional[AlgorithmConfig] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize", method="POST")

149 llama_stack/apis/scoring_functions/scoring_functions.py Normal file
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    Union,
+    runtime_checkable,
+)
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+from llama_stack.apis.common.type_system import ParamType
+from llama_stack.apis.resource import Resource, ResourceType
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+
+
+# Perhaps more structure can be imposed on these functions. Maybe they could be associated
+# with standard metrics so they can be rolled up?
+@json_schema_type
+class ScoringFnParamsType(Enum):
+    llm_as_judge = "llm_as_judge"
+    regex_parser = "regex_parser"
+    basic = "basic"
+
+
+@json_schema_type
+class AggregationFunctionType(Enum):
+    average = "average"
+    weighted_average = "weighted_average"
+    median = "median"
+    categorical_count = "categorical_count"
+    accuracy = "accuracy"
+
+
+@json_schema_type
+class LLMAsJudgeScoringFnParams(BaseModel):
+    type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
+    judge_model: str
+    prompt_template: Optional[str] = None
+    judge_score_regexes: Optional[List[str]] = Field(
+        description="Regexes to extract the answer from generated response",
+        default_factory=list,
+    )
+    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+        description="Aggregation functions to apply to the scores of each row",
+        default_factory=list,
+    )
+
+
+@json_schema_type
+class RegexParserScoringFnParams(BaseModel):
+    type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
+    parsing_regexes: Optional[List[str]] = Field(
+        description="Regex to extract the answer from generated response",
+        default_factory=list,
+    )
+    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+        description="Aggregation functions to apply to the scores of each row",
+        default_factory=list,
+    )
+
+
+@json_schema_type
+class BasicScoringFnParams(BaseModel):
+    type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
+    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+        description="Aggregation functions to apply to the scores of each row",
+        default_factory=list,
+    )
+
+
+ScoringFnParams = Annotated[
+    Union[
+        LLMAsJudgeScoringFnParams,
+        RegexParserScoringFnParams,
+        BasicScoringFnParams,
+    ],
+    Field(discriminator="type"),
+]
+register_schema(ScoringFnParams, name="ScoringFnParams")
+
+
+class CommonScoringFnFields(BaseModel):
+    description: Optional[str] = None
+    metadata: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Any additional metadata for this definition",
+    )
+    return_type: ParamType = Field(
+        description="The return type of the deterministic function",
+    )
+    params: Optional[ScoringFnParams] = Field(
+        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
+        default=None,
+    )
+
+
+@json_schema_type
+class ScoringFn(CommonScoringFnFields, Resource):
+    type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
+
+    @property
+    def scoring_fn_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_scoring_fn_id(self) -> str:
+        return self.provider_resource_id
+
+
+class ScoringFnInput(CommonScoringFnFields, BaseModel):
+    scoring_fn_id: str
+    provider_id: Optional[str] = None
+    provider_scoring_fn_id: Optional[str] = None
+
+
+class ListScoringFunctionsResponse(BaseModel):
+    data: List[ScoringFn]
+
+
+@runtime_checkable
+class ScoringFunctions(Protocol):
+    @webmethod(route="/scoring-functions", method="GET")
+    async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
+
+    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
+    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
+
+    @webmethod(route="/scoring-functions", method="POST")
+    async def register_scoring_function(
+        self,
+        scoring_fn_id: str,
+        description: str,
+        return_type: ParamType,
+        provider_scoring_fn_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        params: Optional[ScoringFnParams] = None,
+    ) -> None: ...
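`ScoringFnParams` is a discriminated union over the three parameter styles defined above. A minimal sketch of constructing the LLM-as-judge variant; the judge model, template, and regex are illustrative values:

```python
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    LLMAsJudgeScoringFnParams,
)

params = LLMAsJudgeScoringFnParams(
    judge_model="meta-llama/Llama-3.3-70B-Instruct",
    prompt_template="Rate the answer from 0 to 5.\n{generated_answer}",
    judge_score_regexes=[r"(\d)"],
    aggregation_functions=[AggregationFunctionType.average],
)
# Serializes with type="llm_as_judge", so parsing a ScoringFnParams value
# picks the right branch of the union via the "type" discriminator.
print(params.model_dump()["type"])  # -> "llm_as_judge"
```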
@@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel):
     status: SpanStatus


-StructuredLogPayload = register_schema(
-    Annotated[
+StructuredLogPayload = Annotated[
     Union[
         SpanStartPayload,
         SpanEndPayload,
     ],
     Field(discriminator="type"),
-    ],
-    name="StructuredLogPayload",
-)
+]
+register_schema(StructuredLogPayload, name="StructuredLogPayload")


 @json_schema_type
@@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon):
     payload: StructuredLogPayload


-Event = register_schema(
-    Annotated[
+Event = Annotated[
     Union[
         UnstructuredLogEvent,
         MetricEvent,
         StructuredLogEvent,
     ],
     Field(discriminator="type"),
-    ],
-    name="Event",
-)
+]
+register_schema(Event, name="Event")


 @json_schema_type
@@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho

 @json_schema_type
 class RAGDocument(BaseModel):
+    """
+    A document to be used for document ingestion in the RAG Tool.
+
+    :param document_id: The unique identifier for the document.
+    :param content: The content of the document.
+    :param mime_type: The MIME type of the document.
+    :param metadata: Additional metadata for the document.
+    """
+
     document_id: str
     content: InterleavedContent | URL
     mime_type: str | None = None
@@ -49,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
     template: str


-RAGQueryGeneratorConfig = register_schema(
-    Annotated[
+RAGQueryGeneratorConfig = Annotated[
     Union[
         DefaultRAGQueryGeneratorConfig,
         LLMRAGQueryGeneratorConfig,
     ],
     Field(discriminator="type"),
-    ],
-    name="RAGQueryGeneratorConfig",
-)
+]
+register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")


 @json_schema_type
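With its docstring added above, `RAGDocument` is the ingestion unit for the RAG tool. A hedged sketch of ingesting a document from the client side, following the project's examples; the vector DB id and chunk size are illustrative, and `client.tool_runtime.rag_tool.insert` is assumed from the RAG tool runtime API:

```python
from llama_stack_client import LlamaStackClient, RAGDocument

client = LlamaStackClient(base_url="http://localhost:8321")

document = RAGDocument(
    document_id="doc-1",
    content="https://example.com/handbook.txt",  # URL or inline InterleavedContent
    mime_type="text/plain",
    metadata={"source": "handbook"},
)
client.tool_runtime.rag_tool.insert(
    documents=[document],
    vector_db_id="my-vector-db",  # assumes this vector DB was registered earlier
    chunk_size_in_tokens=512,
)
```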
@@ -69,7 +69,7 @@ class ToolGroup(Resource):

 @json_schema_type
 class ToolInvocationResult(BaseModel):
-    content: InterleavedContent
+    content: Optional[InterleavedContent] = None
     error_message: Optional[str] = None
     error_code: Optional[int] = None
     metadata: Optional[Dict[str, Any]] = None
@@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
-    tool_store: ToolStore
+    tool_store: ToolStore | None = None

-    rag_tool: RAGToolRuntime
+    rag_tool: RAGToolRuntime | None = None

     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
@@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
 @runtime_checkable
 @trace_protocol
 class VectorIO(Protocol):
-    vector_db_store: VectorDBStore
+    vector_db_store: VectorDBStore | None = None

     # this will just block now until chunks are inserted, but it should
     # probably return a Job instance which can be polled for completion
@@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
         d = json.load(f)
         manifest = Manifest(**d)

-    if datetime.now(timezone.utc) > manifest.expires_on:
+    if datetime.now(timezone.utc) > manifest.expires_on.astimezone(timezone.utc):
         raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")

     console = Console()
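The manifest-expiry change above normalizes the right-hand side before comparing. The likely failure it guards against (assuming `expires_on` can be parsed without an offset) is Python's refusal to compare naive and aware datetimes:

```python
from datetime import datetime, timezone

now = datetime.now(timezone.utc)          # timezone-aware
expires_on = datetime(2025, 1, 1, 12, 0)  # naive, e.g. parsed without an offset

# now > expires_on  # would raise: TypeError: can't compare offset-naive
#                   # and offset-aware datetimes

# .astimezone() attaches the local zone to a naive value (and converts an
# aware one to UTC here), so both operands end up timezone-aware:
print(now > expires_on.astimezone(timezone.utc))
```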
86 llama_stack/distribution/access_control.py Normal file
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+from llama_stack.distribution.datatypes import AccessAttributes
+from llama_stack.log import get_logger
+
+logger = get_logger(__name__, category="core")
+
+
+def check_access(
+    obj_identifier: str,
+    obj_attributes: Optional[AccessAttributes],
+    user_attributes: Optional[Dict[str, Any]] = None,
+) -> bool:
+    """Check if the current user has access to the given object, based on access attributes.
+
+    Access control algorithm:
+    1. If the resource has no access_attributes, access is GRANTED to all authenticated users
+    2. If the user has no attributes, access is DENIED to any object with access_attributes defined
+    3. For each attribute category in the resource's access_attributes:
+       a. If the user lacks that category, access is DENIED
+       b. If the user has the category but none of the required values, access is DENIED
+       c. If the user has at least one matching value in each required category, access is GRANTED
+
+    Example:
+        # Resource requires:
+        access_attributes = AccessAttributes(
+            roles=["admin", "data-scientist"],
+            teams=["ml-team"]
+        )
+
+        # User has:
+        user_attributes = {
+            "roles": ["data-scientist", "engineer"],
+            "teams": ["ml-team", "infra-team"],
+            "projects": ["llama-3"]
+        }
+
+        # Result: Access GRANTED
+        # - User has the "data-scientist" role (matches one of the required roles)
+        # - AND user is part of the "ml-team" (matches the required team)
+        # - The extra "projects" attribute is ignored
+
+    Args:
+        obj_identifier: The identifier of the resource object to check access for
+        obj_attributes: The access attributes of the resource object
+        user_attributes: The attributes of the current user
+
+    Returns:
+        bool: True if access is granted, False if denied
+    """
+    # If object has no access attributes, allow access by default
+    if not obj_attributes:
+        return True
+
+    # If no user attributes, deny access to objects with access control
+    if not user_attributes:
+        return False
+
+    dict_attribs = obj_attributes.model_dump(exclude_none=True)
+    if not dict_attribs:
+        return True
+
+    # Check each attribute category (requires ALL categories to match)
+    # TODO: formalize this into a proper ABAC policy
+    for attr_key, required_values in dict_attribs.items():
+        user_values = user_attributes.get(attr_key, [])
+
+        if not user_values:
+            logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
+            return False
+
+        if not any(val in user_values for val in required_values):
+            logger.debug(
+                f"Access denied to {obj_identifier}: "
+                f"no match for attribute '{attr_key}', required one of {required_values}"
+            )
+            return False
+
+    logger.debug(f"Access granted to {obj_identifier}")
+    return True
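A minimal sketch exercising `check_access` as defined above; the object identifiers and attribute values are illustrative:

```python
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes

obj_attrs = AccessAttributes(roles=["admin", "data-scientist"], teams=["ml-team"])

# Matching role AND matching team -> granted
assert check_access("model:demo", obj_attrs, {"roles": ["data-scientist"], "teams": ["ml-team"]})

# Right role but wrong team -> denied (ALL categories must match)
assert not check_access("model:demo", obj_attrs, {"roles": ["admin"], "teams": ["infra-team"]})

# No access_attributes on the object -> granted to any authenticated user
assert check_access("model:open", None, None)
```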
@ -90,6 +90,7 @@ RUN apt-get update && apt-get install -y \
|
||||||
procps psmisc lsof \
|
procps psmisc lsof \
|
||||||
traceroute \
|
traceroute \
|
||||||
bubblewrap \
|
bubblewrap \
|
||||||
|
gcc \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
ENV UV_SYSTEM_PYTHON=1
|
ENV UV_SYSTEM_PYTHON=1
|
||||||
|
@ -235,7 +236,7 @@ image_tag="$image_name:$version_tag"
|
||||||
# Detect platform architecture
|
# Detect platform architecture
|
||||||
ARCH=$(uname -m)
|
ARCH=$(uname -m)
|
||||||
if [ -n "$BUILD_PLATFORM" ]; then
|
if [ -n "$BUILD_PLATFORM" ]; then
|
||||||
CLI_ARGS+=("--platform $BUILD_PLATFORM")
|
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
|
||||||
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
|
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
|
||||||
CLI_ARGS+=("--platform" "linux/arm64")
|
CLI_ARGS+=("--platform" "linux/arm64")
|
||||||
elif [ "$ARCH" = "x86_64" ]; then
|
elif [ "$ARCH" = "x86_64" ]; then
|
||||||
|
|
|
@@ -13,6 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset, DatasetInput
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.models import Model, ModelInput
+from llama_stack.apis.resource import Resource
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.shields import Shield, ShieldInput
 from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
@@ -28,6 +29,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
 RoutingKey = Union[str, List[str]]
 
 
+class AccessAttributes(BaseModel):
+    """Structured representation of user attributes for access control.
+
+    This model defines a structured approach to representing user attributes
+    with common standard categories for access control.
+
+    Standard attribute categories include:
+    - roles: Role-based attributes (e.g., admin, data-scientist)
+    - teams: Team-based attributes (e.g., ml-team, infra-team)
+    - projects: Project access attributes (e.g., llama-3, customer-insights)
+    - namespaces: Namespace-based access control for resource isolation
+    """
+
+    # Standard attribute categories - the minimal set we need now
+    roles: Optional[List[str]] = Field(
+        default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
+    )
+
+    teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
+
+    projects: Optional[List[str]] = Field(
+        default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
+    )
+
+    namespaces: Optional[List[str]] = Field(
+        default=None, description="Namespace-based access control for resource isolation"
+    )
+
+
+class ResourceWithACL(Resource):
+    """Extension of Resource that adds attribute-based access control capabilities.
+
+    This class adds an optional access_attributes field that allows fine-grained control
+    over which users can access each resource. When attributes are defined, a user must have
+    matching attributes to access the resource.
+
+    Attribute Matching Algorithm:
+    1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
+    2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
+    3. The matching algorithm requires ALL categories to match (AND relationship between categories)
+    4. Within each category, ANY value match is sufficient (OR relationship within a category)
+
+    Examples:
+        # Resource visible to everyone (no access control)
+        model = Model(identifier="llama-2", ...)
+
+        # Resource visible only to admins
+        model = Model(
+            identifier="gpt-4",
+            access_attributes=AccessAttributes(roles=["admin"])
+        )
+
+        # Resource visible to data scientists on the ML team
+        model = Model(
+            identifier="private-model",
+            access_attributes=AccessAttributes(
+                roles=["data-scientist", "researcher"],
+                teams=["ml-team"]
+            )
+        )
+        # ^ User must have at least one of the roles AND be on the ml-team
+
+        # Resource visible to users with specific project access
+        vector_db = VectorDB(
+            identifier="customer-embeddings",
+            access_attributes=AccessAttributes(
+                projects=["customer-insights"],
+                namespaces=["confidential"]
+            )
+        )
+        # ^ User must have access to the customer-insights project AND have confidential namespace
+    """
+
+    access_attributes: Optional[AccessAttributes] = None
+
+
+# Use the extended Resource for all routable objects
+class ModelWithACL(Model, ResourceWithACL):
+    pass
+
+
+class ShieldWithACL(Shield, ResourceWithACL):
+    pass
+
+
+class VectorDBWithACL(VectorDB, ResourceWithACL):
+    pass
+
+
+class DatasetWithACL(Dataset, ResourceWithACL):
+    pass
+
+
+class ScoringFnWithACL(ScoringFn, ResourceWithACL):
+    pass
+
+
+class BenchmarkWithACL(Benchmark, ResourceWithACL):
+    pass
+
+
+class ToolWithACL(Tool, ResourceWithACL):
+    pass
+
+
+class ToolGroupWithACL(ToolGroup, ResourceWithACL):
+    pass
+
+
 RoutableObject = Union[
     Model,
     Shield,
@@ -41,13 +151,14 @@ RoutableObject = Union[
 
 RoutableObjectWithProvider = Annotated[
     Union[
-        Model,
-        Shield,
-        VectorDB,
-        Dataset,
-        Benchmark,
-        Tool,
-        ToolGroup,
+        ModelWithACL,
+        ShieldWithACL,
+        VectorDBWithACL,
+        DatasetWithACL,
+        ScoringFnWithACL,
+        BenchmarkWithACL,
+        ToolWithACL,
+        ToolGroupWithACL,
     ],
    Field(discriminator="type"),
 ]
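The `check_access` helper that enforces the matching algorithm documented above is imported later in this commit (from `llama_stack.distribution.access_control`) but its body is not part of this diff. A minimal self-contained sketch of the documented rules, assuming attributes are passed around as plain dicts of string lists, could look like this:

```python
from typing import Dict, List, Optional


def check_access_sketch(
    resource_attributes: Optional[Dict[str, List[str]]],
    user_attributes: Optional[Dict[str, List[str]]],
) -> bool:
    # Rule 1: no attributes on the resource -> visible to all authenticated users
    if not resource_attributes:
        return True
    # A user with no attributes can only see unrestricted resources
    if not user_attributes:
        return False
    # Rules 3 and 4: AND across categories, OR within a category
    for category, required_values in resource_attributes.items():
        user_values = user_attributes.get(category, [])
        if not any(v in user_values for v in required_values):
            return False
    return True


# Example: user must hold one of the listed roles AND be on the ml-team
assert check_access_sketch(
    {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]},
    {"roles": ["researcher"], "teams": ["ml-team"], "projects": ["llama-3"]},
)
```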
@@ -11,9 +11,7 @@ from pydantic import BaseModel
 from llama_stack.apis.inspect import (
     HealthInfo,
     Inspect,
-    ListProvidersResponse,
     ListRoutesResponse,
-    ProviderInfo,
     RouteInfo,
     VersionInfo,
 )
@@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
     async def initialize(self) -> None:
         pass
 
-    async def list_providers(self) -> ListProvidersResponse:
-        run_config = self.config.run_config
-
-        ret = []
-        for api, providers in run_config.providers.items():
-            ret.extend(
-                [
-                    ProviderInfo(
-                        api=api,
-                        provider_id=p.provider_id,
-                        provider_type=p.provider_type,
-                    )
-                    for p in providers
-                ]
-            )
-
-        return ListProvidersResponse(data=ret)
-
     async def list_routes(self) -> ListRoutesResponse:
         run_config = self.config.run_config
 
@@ -9,7 +9,6 @@ import inspect
 import json
 import logging
 import os
-import re
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
@@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import ProviderRegistry
-from llama_stack.distribution.server.endpoints import get_all_api_endpoints
+from llama_stack.distribution.server.endpoints import (
+    find_matching_endpoint,
+    initialize_endpoint_impls,
+)
 from llama_stack.distribution.stack import (
     construct_stack,
     get_stack_run_config_from_template,
@@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         safe_config = redact_sensitive_fields(self.config.model_dump())
         console.print(yaml.dump(safe_config, indent=2))
 
-        endpoints = get_all_api_endpoints()
-        endpoint_impls = {}
-
-        def _convert_path_to_regex(path: str) -> str:
-            # Convert {param} to named capture groups
-            # handle {param:path} as well which allows for forward slashes in the param value
-            pattern = re.sub(
-                r"{(\w+)(?::path)?}",
-                lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
-                path,
-            )
-
-            return f"^{pattern}$"
-
-        for api, api_endpoints in endpoints.items():
-            if api not in self.impls:
-                continue
-            for endpoint in api_endpoints:
-                impl = self.impls[api]
-                func = getattr(impl, endpoint.name)
-                if endpoint.method not in endpoint_impls:
-                    endpoint_impls[endpoint.method] = {}
-                endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
-
-        self.endpoint_impls = endpoint_impls
+        self.endpoint_impls = initialize_endpoint_impls(self.impls)
         return True
 
     async def request(
@@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return response
 
-    def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
-        """Find the matching endpoint implementation for a given method and path.
-
-        Args:
-            method: HTTP method (GET, POST, etc.)
-            path: URL path to match against
-
-        Returns:
-            A tuple of (endpoint_function, path_params)
-
-        Raises:
-            ValueError: If no matching endpoint is found
-        """
-        impls = self.endpoint_impls.get(method)
-        if not impls:
-            raise ValueError(f"No endpoint found for {path}")
-
-        for regex, func in impls.items():
-            match = re.match(regex, path)
-            if match:
-                # Extract named groups from the regex match
-                path_params = match.groupdict()
-                return func, path_params
-
-        raise ValueError(f"No endpoint found for {path}")
-
     async def _call_non_streaming(
         self,
         *,
@@ -326,10 +278,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         body = options.params or {}
         body |= options.json_data or {}
 
-        matched_func, path_params = self._find_matching_endpoint(options.method, path)
+        matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
         body |= path_params
         body = self._convert_body(path, options.method, body)
-        await start_trace(options.url, {"__location__": "library_client"})
+        await start_trace(route, {"__location__": "library_client"})
         try:
             result = await matched_func(**body)
         finally:
@@ -371,13 +323,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}
-        func, path_params = self._find_matching_endpoint(options.method, path)
+        func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
         body |= path_params
 
         body = self._convert_body(path, options.method, body)
 
         async def gen():
-            await start_trace(options.url, {"__location__": "library_client"})
+            await start_trace(route, {"__location__": "library_client"})
             try:
                 async for chunk in await func(**body):
                     data = json.dumps(convert_pydantic_to_json_value(chunk))
@@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not body:
             return {}
 
-        func, _ = self._find_matching_endpoint(method, path)
+        func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
         sig = inspect.signature(func)
 
         # Strip NOT_GIVENs to use the defaults in signature
@@ -7,21 +7,26 @@
 import contextvars
 import json
 import logging
-from typing import Any, ContextManager, Dict, Optional
+from typing import Any, ContextManager, Dict, List, Optional
 
 from .utils.dynamic import instantiate_class_type
 
 log = logging.getLogger(__name__)
 
-# Context variable for request provider data
+# Context variable for request provider data and auth attributes
 PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
 
 
 class RequestProviderDataContext(ContextManager):
     """Context manager for request provider data"""
 
-    def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
-        self.provider_data = provider_data
+    def __init__(
+        self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
+    ):
+        self.provider_data = provider_data or {}
+        if auth_attributes:
+            self.provider_data["__auth_attributes"] = auth_attributes
+
         self.token = None
 
     def __enter__(self):
@@ -80,7 +85,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
     return None
 
 
-def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
-    """Context manager that sets request provider data from headers for the duration of the context"""
+def request_provider_data_context(
+    headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
+) -> ContextManager:
+    """Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
     provider_data = parse_request_provider_data(headers)
-    return RequestProviderDataContext(provider_data)
+    return RequestProviderDataContext(provider_data, auth_attributes)
+
+
+def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
+    """Helper to retrieve auth attributes from the provider data context"""
+    provider_data = PROVIDER_DATA_VAR.get()
+    if not provider_data:
+        return None
+    return provider_data.get("__auth_attributes")
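The pattern above stores per-request auth attributes in a `contextvars.ContextVar` so any code on the call stack can read them without threading parameters through every signature. A minimal self-contained sketch of the same mechanism (all names here are illustrative, not the module's real API surface beyond what the diff shows):

```python
import contextvars
from typing import Any, Dict, Optional

_REQUEST_CTX: contextvars.ContextVar[Optional[Dict[str, Any]]] = contextvars.ContextVar(
    "request_ctx", default=None
)


class RequestContext:
    """Stores per-request data for the duration of a `with` block."""

    def __init__(self, data: Dict[str, Any]):
        self.data = data
        self.token = None

    def __enter__(self):
        self.token = _REQUEST_CTX.set(self.data)
        return self

    def __exit__(self, *exc):
        _REQUEST_CTX.reset(self.token)


def get_attr(key: str) -> Optional[Any]:
    data = _REQUEST_CTX.get()
    return data.get(key) if data else None


# The request handler wraps its work in the context manager; helpers deeper
# in the call stack can then read the request's attributes directly.
with RequestContext({"__auth_attributes": {"roles": ["admin"]}}):
    assert get_attr("__auth_attributes") == {"roles": ["admin"]}
assert get_attr("__auth_attributes") is None  # cleared once the request ends
```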
@@ -19,6 +19,8 @@ from llama_stack.apis.datasets import (
     DatasetType,
     DataSource,
     ListDatasetsResponse,
+    RowsDataSource,
+    URIDataSource,
 )
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
 from llama_stack.apis.resource import ResourceType
@@ -32,11 +34,22 @@ from llama_stack.apis.tools import (
     ToolHost,
 )
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
+from llama_stack.distribution.access_control import check_access
 from llama_stack.distribution.datatypes import (
+    AccessAttributes,
+    BenchmarkWithACL,
+    DatasetWithACL,
+    ModelWithACL,
     RoutableObject,
     RoutableObjectWithProvider,
     RoutedProtocol,
+    ScoringFnWithACL,
+    ShieldWithACL,
+    ToolGroupWithACL,
+    ToolWithACL,
+    VectorDBWithACL,
 )
+from llama_stack.distribution.request_headers import get_auth_attributes
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable
 
@@ -165,6 +178,11 @@ class CommonRoutingTableImpl(RoutingTable):
         if not obj:
             return None
 
+        # Check if user has permission to access this object
+        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
+            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
+            return None
+
         return obj
 
     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
@@ -181,6 +199,13 @@ class CommonRoutingTableImpl(RoutingTable):
 
         p = self.impls_by_provider_id[obj.provider_id]
 
+        # If object supports access control but no attributes set, use creator's attributes
+        if not obj.access_attributes:
+            creator_attributes = get_auth_attributes()
+            if creator_attributes:
+                obj.access_attributes = AccessAttributes(**creator_attributes)
+                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
+
         registered_obj = await register_object_with_provider(obj, p)
         # TODO: This needs to be fixed for all APIs once they return the registered object
         if obj.type == ResourceType.model.value:
@@ -193,7 +218,17 @@ class CommonRoutingTableImpl(RoutingTable):
 
     async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
         objs = await self.dist_registry.get_all()
-        return [obj for obj in objs if obj.type == type]
+        filtered_objs = [obj for obj in objs if obj.type == type]
+
+        # Apply attribute-based access control filtering
+        if filtered_objs:
+            filtered_objs = [
+                obj
+                for obj in filtered_objs
+                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
+            ]
+
+        return filtered_objs
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@@ -230,7 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
             raise ValueError("Embedding model must have an embedding dimension in its metadata")
-        model = Model(
+        model = ModelWithACL(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,
@@ -276,7 +311,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
             )
         if params is None:
             params = {}
-        shield = Shield(
+        shield = ShieldWithACL(
             identifier=shield_id,
             provider_resource_id=provider_shield_id,
             provider_id=provider_id,
@@ -330,7 +365,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             "embedding_model": embedding_model,
             "embedding_dimension": model.metadata["embedding_dimension"],
         }
-        vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
+        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
         await self.register_object(vector_db)
         return vector_db
 
@@ -358,6 +393,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         metadata: Optional[Dict[str, Any]] = None,
         dataset_id: Optional[str] = None,
     ) -> Dataset:
+        if isinstance(source, dict):
+            if source["type"] == "uri":
+                source = URIDataSource.parse_obj(source)
+            elif source["type"] == "rows":
+                source = RowsDataSource.parse_obj(source)
+
         if not dataset_id:
             dataset_id = f"dataset-{str(uuid.uuid4())}"
 
@@ -378,7 +419,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         if metadata is None:
             metadata = {}
 
-        dataset = Dataset(
+        dataset = DatasetWithACL(
             identifier=dataset_id,
             provider_resource_id=provider_dataset_id,
             provider_id=provider_id,
@@ -429,7 +470,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
             raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
         provider_id = list(self.impls_by_provider_id.keys())[0]
 
-        benchmark = Benchmark(
+        benchmark = BenchmarkWithACL(
             identifier=benchmark_id,
             dataset_id=dataset_id,
             grader_ids=grader_ids,
@@ -473,7 +514,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
 
         for tool_def in tool_defs:
             tools.append(
-                Tool(
+                ToolWithACL(
                     identifier=tool_def.name,
                     toolgroup_id=toolgroup_id,
                     description=tool_def.description or "",
@@ -498,7 +539,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
             await self.register_object(tool)
 
         await self.dist_registry.register(
-            ToolGroup(
+            ToolGroupWithACL(
                 identifier=toolgroup_id,
                 provider_id=provider_id,
                 provider_resource_id=toolgroup_id,
@@ -511,7 +552,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         tool_group = await self.get_tool_group(toolgroup_id)
         if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
-        tools = await self.list_tools(toolgroup_id).data
+        tools = (await self.list_tools(toolgroup_id)).data
        for tool in tools:
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)
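The parenthesization fix in the last hunk is a genuine bug fix, not a style change: attribute access binds tighter than `await`, so `await self.list_tools(...).data` looks up `.data` on the coroutine object before it is awaited. A small standalone demonstration:

```python
import asyncio


class ListToolsResponse:
    def __init__(self, data):
        self.data = data


async def list_tools():
    return ListToolsResponse(data=["tool-a", "tool-b"])


async def main():
    # `await list_tools().data` would raise AttributeError: the `.data`
    # lookup happens on the coroutine object, before the await runs.
    tools = (await list_tools()).data
    print(tools)


asyncio.run(main())
```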
@@ -5,16 +5,118 @@
 # the root directory of this source tree.
 
 import json
+from typing import Dict, List, Optional
 from urllib.parse import parse_qs
 
 import httpx
+from pydantic import BaseModel, Field
+
+from llama_stack.distribution.datatypes import AccessAttributes
 from llama_stack.log import get_logger
 
 logger = get_logger(name=__name__, category="auth")
 
 
+class AuthRequestContext(BaseModel):
+    path: str = Field(description="The path of the request being authenticated")
+
+    headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
+
+    params: Dict[str, List[str]] = Field(
+        description="Query parameters from the original request, parsed as dictionary of lists"
+    )
+
+
+class AuthRequest(BaseModel):
+    api_key: str = Field(description="The API key extracted from the Authorization header")
+
+    request: AuthRequestContext = Field(description="Context information about the request being authenticated")
+
+
+class AuthResponse(BaseModel):
+    """The format of the authentication response from the auth endpoint."""
+
+    access_attributes: Optional[AccessAttributes] = Field(
+        default=None,
+        description="""
+        Structured user attributes for attribute-based access control.
+
+        These attributes determine which resources the user can access.
+        The model provides standard categories like "roles", "teams", "projects", and "namespaces".
+        Each attribute category contains a list of values that the user has for that category.
+        During access control checks, these values are compared against resource requirements.
+
+        Example with standard categories:
+        ```json
+        {
+            "roles": ["admin", "data-scientist"],
+            "teams": ["ml-team"],
+            "projects": ["llama-3"],
+            "namespaces": ["research"]
+        }
+        ```
+        """,
+    )
+
+    message: Optional[str] = Field(
+        default=None, description="Optional message providing additional context about the authentication result."
+    )
+
+
 class AuthenticationMiddleware:
+    """Middleware that authenticates requests using an external auth endpoint.
+
+    This middleware:
+    1. Extracts the Bearer token from the Authorization header
+    2. Sends it to the configured auth endpoint along with request details
+    3. Validates the response and extracts user attributes
+    4. Makes these attributes available to the route handlers for access control
+
+    Authentication Request Format:
+    ```json
+    {
+        "api_key": "the-api-key-extracted-from-auth-header",
+        "request": {
+            "path": "/models/list",
+            "headers": {
+                "content-type": "application/json",
+                "user-agent": "..."
+                // All headers except Authorization
+            },
+            "params": {
+                "limit": ["100"],
+                "offset": ["0"]
+                // Query parameters as key -> list of values
+            }
+        }
+    }
+    ```
+
+    Expected Auth Endpoint Response Format:
+    ```json
+    {
+        "access_attributes": {    // Structured attribute format
+            "roles": ["admin", "user"],
+            "teams": ["ml-team", "nlp-team"],
+            "projects": ["llama-3", "project-x"],
+            "namespaces": ["research"]
+        },
+        "message": "Optional message about auth result"
+    }
+    ```
+
+    Attribute-Based Access Control:
+    The attributes returned by the auth endpoint are used to determine which
+    resources the user can access. Resources can specify required attributes
+    using the access_attributes field. For a user to access a resource:
+
+    1. All attribute categories specified in the resource must be present in the user's attributes
+    2. For each category, the user must have at least one matching value
+
+    If the auth endpoint doesn't return any attributes, the user will only be able to
+    access resources that don't have access_attributes defined.
+    """
+
     def __init__(self, app, auth_endpoint):
         self.app = app
         self.auth_endpoint = auth_endpoint
@@ -32,25 +134,57 @@ class AuthenticationMiddleware:
         path = scope.get("path", "")
         request_headers = {k.decode(): v.decode() for k, v in headers.items()}
 
+        # Remove sensitive headers
+        if "authorization" in request_headers:
+            del request_headers["authorization"]
+
         query_string = scope.get("query_string", b"").decode()
         params = parse_qs(query_string)
 
-        auth_data = {
-            "api_key": api_key,
-            "request": {
-                "path": path,
-                "headers": request_headers,
-                "params": params,
-            },
-        }
+        # Build the auth request model
+        auth_request = AuthRequest(
+            api_key=api_key,
+            request=AuthRequestContext(
+                path=path,
+                headers=request_headers,
+                params=params,
+            ),
+        )
 
         # Validate with authentication endpoint
         try:
             async with httpx.AsyncClient() as client:
-                response = await client.post(self.auth_endpoint, json=auth_data)
+                response = await client.post(
+                    self.auth_endpoint,
+                    json=auth_request.model_dump(),
+                    timeout=10.0,  # Add a reasonable timeout
+                )
                 if response.status_code != 200:
                     logger.warning(f"Authentication failed: {response.status_code}")
                     return await self._send_auth_error(send, "Authentication failed")
+
+                # Parse and validate the auth response
+                try:
+                    response_data = response.json()
+                    auth_response = AuthResponse(**response_data)
+
+                    # Store attributes in request scope for access control
+                    if auth_response.access_attributes:
+                        user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
+                    else:
+                        logger.warning("No access attributes, setting namespace to api_key by default")
+                        user_attributes = {
+                            "namespaces": [api_key],
+                        }
+
+                    scope["user_attributes"] = user_attributes
+                    logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
+                except Exception:
+                    logger.exception("Error parsing authentication response")
+                    return await self._send_auth_error(send, "Invalid authentication response format")
+        except httpx.TimeoutException:
+            logger.exception("Authentication request timed out")
+            return await self._send_auth_error(send, "Authentication service timeout")
         except Exception:
             logger.exception("Error during authentication")
             return await self._send_auth_error(send, "Authentication service error")
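The middleware defines the contract but leaves the auth endpoint itself up to the deployer. A toy validator satisfying the documented request/response format might look like this; FastAPI, the `/validate` route, and the in-memory token table are all assumptions for illustration, not part of this commit:

```python
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Hypothetical token table, for illustration only
_TOKENS: Dict[str, Dict[str, List[str]]] = {
    "secret-admin-key": {"roles": ["admin"], "teams": ["infra-team"]},
}


class AuthRequestIn(BaseModel):
    api_key: str
    request: dict  # path / headers / params; unused by this toy validator


@app.post("/validate")
def validate(body: AuthRequestIn):
    attributes: Optional[Dict[str, List[str]]] = _TOKENS.get(body.api_key)
    if attributes is None:
        # Any non-200 status makes the middleware reject the request
        raise HTTPException(status_code=401, detail="unknown api key")
    # A 200 response with access_attributes grants those attributes
    return {"access_attributes": attributes, "message": "ok"}
```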
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import inspect
+import re
 from typing import Dict, List
 
 from pydantic import BaseModel
@@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
     route: str
     method: str
     name: str
+    descriptive_name: str | None = None
 
 
 def toolgroup_protocol_map():
@@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
                 method = "delete"
             else:
                 method = "post"
-            endpoints.append(ApiEndpoint(route=route, method=method, name=name))
+            endpoints.append(
+                ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
+            )
 
         apis[api] = endpoints
 
     return apis
+
+
+def initialize_endpoint_impls(impls):
+    endpoints = get_all_api_endpoints()
+    endpoint_impls = {}
+
+    def _convert_path_to_regex(path: str) -> str:
+        # Convert {param} to named capture groups
+        # handle {param:path} as well which allows for forward slashes in the param value
+        pattern = re.sub(
+            r"{(\w+)(?::path)?}",
+            lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
+            path,
+        )
+
+        return f"^{pattern}$"
+
+    for api, api_endpoints in endpoints.items():
+        if api not in impls:
+            continue
+        for endpoint in api_endpoints:
+            impl = impls[api]
+            func = getattr(impl, endpoint.name)
+            if endpoint.method not in endpoint_impls:
+                endpoint_impls[endpoint.method] = {}
+            endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
+                func,
+                endpoint.descriptive_name or endpoint.route,
+            )
+
+    return endpoint_impls
+
+
+def find_matching_endpoint(method, path, endpoint_impls):
+    """Find the matching endpoint implementation for a given method and path.
+
+    Args:
+        method: HTTP method (GET, POST, etc.)
+        path: URL path to match against
+        endpoint_impls: A dictionary of endpoint implementations
+
+    Returns:
+        A tuple of (endpoint_function, path_params, descriptive_name)
+
+    Raises:
+        ValueError: If no matching endpoint is found
+    """
+    impls = endpoint_impls.get(method.lower())
+    if not impls:
+        raise ValueError(f"No endpoint found for {path}")
+
+    for regex, (func, descriptive_name) in impls.items():
+        match = re.match(regex, path)
+        if match:
+            # Extract named groups from the regex match
+            path_params = match.groupdict()
+            return func, path_params, descriptive_name
+
+    raise ValueError(f"No endpoint found for {path}")
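The route matching above hinges on the `{param}` / `{param:path}` conversion: plain parameters must match a single path segment while `:path` parameters may span slashes. A self-contained sketch of the intended conversion (written with an explicit check rather than copying the lambda verbatim):

```python
import re


def convert_path_to_regex(path: str) -> str:
    """Sketch of the conversion performed by initialize_endpoint_impls:
    {param} matches one path segment, {param:path} may contain slashes."""

    def repl(m: re.Match) -> str:
        segment = ".+" if ":path" in m.group(0) else "[^/]+"
        return f"(?P<{m.group(1)}>{segment})"

    return "^" + re.sub(r"{(\w+)(?::path)?}", repl, path) + "$"


m = re.match(convert_path_to_regex("/models/{model_id}"), "/models/llama-3")
assert m and m.groupdict() == {"model_id": "llama-3"}

# A :path parameter also captures embedded slashes
m = re.match(convert_path_to_regex("/files/{name:path}"), "/files/a/b/c.txt")
assert m and m.groupdict() == {"name": "a/b/c.txt"}
```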
@@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import InvalidProviderError
+from llama_stack.distribution.server.endpoints import (
+    find_matching_endpoint,
+    initialize_endpoint_impls,
+)
 from llama_stack.distribution.stack import (
     construct_stack,
     redact_sensitive_fields,
@@ -179,8 +183,11 @@ async def sse_generator(event_gen):
 
 def create_dynamic_typed_route(func: Any, method: str, route: str):
     async def endpoint(request: Request, **kwargs):
-        # Use context manager for request provider data
-        with request_provider_data_context(request.headers):
+        # Get auth attributes from the request scope
+        user_attributes = request.scope.get("user_attributes", {})
+
+        # Use context manager with both provider data and auth attributes
+        with request_provider_data_context(request.headers, user_attributes):
             is_streaming = is_streaming_request(func.__name__, request, **kwargs)
 
             try:
@@ -219,14 +226,30 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
 
 
 class TracingMiddleware:
-    def __init__(self, app):
+    def __init__(self, app, impls):
         self.app = app
+        self.impls = impls
 
     async def __call__(self, scope, receive, send):
-        path = scope.get("path", "")
-        await start_trace(path, {"__location__": "server"})
-        try:
+        if scope.get("type") == "lifespan":
             return await self.app(scope, receive, send)
+
+        path = scope.get("path", "")
+        if not hasattr(self, "endpoint_impls"):
+            self.endpoint_impls = initialize_endpoint_impls(self.impls)
+        _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
+
+        trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
+
+        async def send_with_trace_id(message):
+            if message["type"] == "http.response.start":
+                headers = message.get("headers", [])
+                headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
+                message["headers"] = headers
+            await send(message)
+
+        try:
+            return await self.app(scope, receive, send_with_trace_id)
         finally:
             await end_trace()
@@ -348,7 +371,6 @@ def main():
     logger.info(yaml.dump(safe_config, indent=2))
 
     app = FastAPI(lifespan=lifespan)
-    app.add_middleware(TracingMiddleware)
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
@@ -366,7 +388,7 @@ def main():
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
     else:
-        setup_logger(TelemetryAdapter(TelemetryConfig()))
+        setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
 
     all_endpoints = get_all_api_endpoints()
 
@@ -412,6 +434,7 @@ def main():
     app.exception_handler(Exception)(global_exception_handler)
 
     app.__llama_stack_impls__ = impls
+    app.add_middleware(TracingMiddleware, impls=impls)
 
     import uvicorn
 
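With the `send_with_trace_id` wrapper above, every traced response carries an `x-trace-id` header, so callers can correlate client-side failures with server-side traces. A minimal check, assuming a stack server listening locally (the URL and route here are assumptions for the demo):

```python
import httpx

# Any routed endpoint should return the header once tracing is active
resp = httpx.get("http://localhost:8321/v1/models")
trace_id = resp.headers.get("x-trace-id")
print(f"server trace id for this request: {trace_id}")
```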
@@ -12,9 +12,12 @@ import pydantic
 
 from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
 from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
+logger = get_logger(__name__, category="core")
+
 
 class DistributionRegistry(Protocol):
     async def get_all(self) -> List[RoutableObjectWithProvider]: ...
@@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
     """Utility function to parse registry values into RoutableObjectWithProvider objects."""
     all_objects = []
     for value in values:
-        obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
-        all_objects.append(obj)
+        try:
+            obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
+            all_objects.append(obj)
+        except pydantic.ValidationError as e:
+            logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
+            continue
+
     return all_objects
 
 
@@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
         if not json_str:
             return None
 
-        return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
+        try:
+            return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
+        except pydantic.ValidationError as e:
+            logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
+            return None
 
     async def update(self, obj: RoutableObjectWithProvider) -> None:
         await self.kvstore.set(
@@ -5,9 +5,7 @@
 # the root directory of this source tree.
 
 import streamlit as st
-from llama_stack_client.lib.agents.agent import Agent
-from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.shared.document import Document
+from llama_stack_client import Agent, AgentEventLogger, RAGDocument
 
 from llama_stack.distribution.ui.modules.api import llama_stack_api
 from llama_stack.distribution.ui.modules.utils import data_url_from_file
@@ -35,7 +33,7 @@ def rag_chat_page():
     )
     if st.button("Create Vector Database"):
         documents = [
-            Document(
+            RAGDocument(
                 document_id=uploaded_file.name,
                 content=data_url_from_file(uploaded_file),
             )
@@ -167,7 +165,7 @@ def rag_chat_page():
         message_placeholder = st.empty()
         full_response = ""
         retrieval_response = ""
-        for log in EventLogger().log(response):
+        for log in AgentEventLogger().log(response):
             log.print()
             if log.role == "tool_execution":
                 retrieval_response += log.content.replace("====", "").strip()
@@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
 class ToolCall(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
-    arguments: Dict[str, RecursiveType]
+    # Plan is to deprecate the Dict in favor of a JSON string
+    # that is parsed on the client side instead of trying to manage
+    # the recursive type here.
+    # Making this a union so that client side can start prepping for this change.
+    # Eventually, we will remove both the Dict and arguments_json field,
+    # and arguments will just be a str
+    arguments: Union[str, Dict[str, RecursiveType]]
+    arguments_json: Optional[str] = None
 
     @field_validator("tool_name", mode="before")
     @classmethod
@@ -179,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
     top_k: int = Field(..., ge=1)
 
 
-SamplingStrategy = register_schema(
-    Annotated[
-        Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
-        Field(discriminator="type"),
-    ],
-    name="SamplingStrategy",
-)
+SamplingStrategy = Annotated[
+    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    Field(discriminator="type"),
+]
+register_schema(SamplingStrategy, name="SamplingStrategy")
 
 
 @json_schema_type
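During the migration window described in the comments above, `ToolCall.arguments` may be either a dict or a JSON string. A small hedged sketch of the client-side normalization the comments anticipate (the helper name is an assumption, not an API in this commit):

```python
import json
from typing import Any, Dict, Union


def tool_arguments_as_dict(arguments: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Normalize a ToolCall's arguments during the deprecation window:
    older servers send a dict, newer ones may send a JSON string."""
    if isinstance(arguments, str):
        return json.loads(arguments)
    return arguments


assert tool_arguments_as_dict('{"query": "llama"}') == {"query": "llama"}
assert tool_arguments_as_dict({"query": "llama"}) == {"query": "llama"}
```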
@@ -12,6 +12,7 @@
 # the top-level of this source tree.
 
 import io
+import json
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -203,6 +204,7 @@ class ChatFormat:
             # This code tries to handle that case
             if tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
+                if isinstance(tool_arguments, dict):
                     tool_arguments = {
                         "query": list(tool_arguments.values())[0],
                     }
@@ -229,6 +231,7 @@ class ChatFormat:
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
+                    arguments_json=json.dumps(tool_arguments),
                 )
             )
         content = ""
@@ -244,6 +244,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
         template_str = textwrap.dedent(
             """
             If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
+            For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
             You SHOULD NOT include any other text in the response.
 
             Here is a list of functions in JSON format that you can invoke.
@@ -11,11 +11,8 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.
 
-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-)
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 
 from .prompt_templates import (
     BuiltinToolGenerator,
@@ -15,8 +15,11 @@ import json
 import re
 from typing import Optional, Tuple
 
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
 
+logger = get_logger(name=__name__, category="inference")
+
 BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
 CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@@ -92,7 +95,15 @@ def parse_python_list_for_function_calls(input_string):
 
         # Extract keyword arguments
         for keyword in node.keywords:
-            function_args[keyword.arg] = ast.literal_eval(keyword.value)
+            try:
+                function_args[keyword.arg] = ast.literal_eval(keyword.value)
+            except ValueError as e:
+                logger.error(
+                    f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
+                )
+                raise ValueError(
+                    f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
+                ) from e
 
         result.append((function_name, function_args))
 
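The try/except above exists because `ast.literal_eval` only accepts literal expressions; any model output that smuggles a call or a name into an argument raises `ValueError`, which now gets logged with the full input string before being re-raised. A quick demonstration of both paths:

```python
import ast

# Literals parse fine...
assert ast.literal_eval("{'k': [1, 2]}") == {"k": [1, 2]}

# ...but anything that is not a pure literal raises ValueError,
# which the hunk above now catches and re-raises with context.
try:
    ast.literal_eval("datetime.now()")
except ValueError as e:
    print(f"rejected non-literal argument: {e}")
```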
@@ -180,23 +180,27 @@ class ChatAgent(ShieldRunnerMixin):
         return messages
 
     async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
-        await self._initialize_tools(request.toolgroups)
-        async with tracing.span("create_and_execute_turn") as span:
+        span = tracing.get_current_span()
+        if span:
             span.set_attribute("session_id", request.session_id)
             span.set_attribute("agent_id", self.agent_id)
             span.set_attribute("request", request.model_dump_json())
             turn_id = str(uuid.uuid4())
             span.set_attribute("turn_id", turn_id)
-            async for chunk in self._run_turn(request, turn_id):
-                yield chunk
+
+        await self._initialize_tools(request.toolgroups)
+        async for chunk in self._run_turn(request, turn_id):
+            yield chunk
 
     async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
-        await self._initialize_tools()
-        async with tracing.span("resume_turn") as span:
+        span = tracing.get_current_span()
+        if span:
             span.set_attribute("agent_id", self.agent_id)
             span.set_attribute("session_id", request.session_id)
-            span.set_attribute("turn_id", request.turn_id)
             span.set_attribute("request", request.model_dump_json())
-            async for chunk in self._run_turn(request):
-                yield chunk
+            span.set_attribute("turn_id", request.turn_id)
+
+        await self._initialize_tools()
+        async for chunk in self._run_turn(request):
+            yield chunk
 
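The change above stops the turn methods from opening their own nested spans; instead they enrich the request-level span that `TracingMiddleware` already started, so all turn attributes land on one span per request. A self-contained sketch of the "current span" pattern this relies on (all names here are illustrative, not the tracing module's real API):

```python
import contextvars
from typing import Optional


class Span:
    def __init__(self, name: str):
        self.name = name
        self.attributes = {}

    def set_attribute(self, key, value):
        self.attributes[key] = value


_CURRENT_SPAN: contextvars.ContextVar[Optional[Span]] = contextvars.ContextVar(
    "current_span", default=None
)


def get_current_span() -> Optional[Span]:
    return _CURRENT_SPAN.get()


# Middleware-level code opens one span for the whole request...
token = _CURRENT_SPAN.set(Span("POST /agents/turn"))

# ...and deeper code enriches it instead of nesting a new span.
span = get_current_span()
if span:
    span.set_attribute("turn_id", "1234")

_CURRENT_SPAN.reset(token)
```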
@@ -13,6 +13,9 @@ from typing import List, Optional
 from pydantic import BaseModel
 
 from llama_stack.apis.agents import ToolExecutionStep, Turn
+from llama_stack.distribution.access_control import check_access
+from llama_stack.distribution.datatypes import AccessAttributes
+from llama_stack.distribution.request_headers import get_auth_attributes
 from llama_stack.providers.utils.kvstore import KVStore
 
 log = logging.getLogger(__name__)
@@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
     # TODO: is this used anywhere?
     vector_db_id: Optional[str] = None
     started_at: datetime
+    access_attributes: Optional[AccessAttributes] = None
 
 
 class AgentPersistence:
@@ -33,11 +37,18 @@ class AgentPersistence:
 
     async def create_session(self, name: str) -> str:
         session_id = str(uuid.uuid4())
+
+        # Get current user's auth attributes for new sessions
+        auth_attributes = get_auth_attributes()
+        access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
+
         session_info = AgentSessionInfo(
             session_id=session_id,
             session_name=name,
             started_at=datetime.now(timezone.utc),
+            access_attributes=access_attributes,
         )
 
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
             value=session_info.model_dump_json(),
@@ -51,12 +62,34 @@ class AgentPersistence:
         if not value:
             return None
 
-        return AgentSessionInfo(**json.loads(value))
+        session_info = AgentSessionInfo(**json.loads(value))
+
+        # Check access to session
+        if not self._check_session_access(session_info):
+            return None
+
+        return session_info
+
+    def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
+        """Check if current user has access to the session."""
+        # Handle backward compatibility for old sessions without access control
+        if not hasattr(session_info, "access_attributes"):
+            return True
+
+        return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
+
+    async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
+        """Get session info if the user has access to it. For internal use by sub-session methods."""
+        session_info = await self.get_session_info(session_id)
+        if not session_info:
+            return None
+
+        return session_info
 
     async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
-        session_info = await self.get_session_info(session_id)
+        session_info = await self.get_session_if_accessible(session_id)
         if session_info is None:
-            raise ValueError(f"Session {session_id} not found")
+            raise ValueError(f"Session {session_id} not found or access denied")
 
         session_info.vector_db_id = vector_db_id
         await self.kvstore.set(
@@ -65,12 +98,18 @@ class AgentPersistence:
         )
 
     async def add_turn_to_session(self, session_id: str, turn: Turn):
+        if not await self.get_session_if_accessible(session_id):
+            raise ValueError(f"Session {session_id} not found or access denied")
+
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
             value=turn.model_dump_json(),
         )
 
     async def get_session_turns(self, session_id: str) -> List[Turn]:
+        if not await self.get_session_if_accessible(session_id):
+            raise ValueError(f"Session {session_id} not found or access denied")
+
         values = await self.kvstore.range(
             start_key=f"session:{self.agent_id}:{session_id}:",
             end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
@@ -87,6 +126,9 @@ class AgentPersistence:
         return turns
 
     async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
+        if not await self.get_session_if_accessible(session_id):
+            raise ValueError(f"Session {session_id} not found or access denied")
+
         value = await self.kvstore.get(
             key=f"session:{self.agent_id}:{session_id}:{turn_id}",
         )
@@ -95,24 +137,36 @@ class AgentPersistence:
         return Turn(**json.loads(value))
 
     async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
+        if not await self.get_session_if_accessible(session_id):
+            raise ValueError(f"Session {session_id} not found or access denied")
+
         await self.kvstore.set(
             key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
             value=step.model_dump_json(),
        )
 
     async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
+        if not await self.get_session_if_accessible(session_id):
+            return None
+
         value = await self.kvstore.get(
             key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
         )
         return ToolExecutionStep(**json.loads(value)) if value else None
|
||||||
|
|
||||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||||
|
if not await self.get_session_if_accessible(session_id):
|
||||||
|
raise ValueError(f"Session {session_id} not found or access denied")
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
value=str(num_infer_iters),
|
value=str(num_infer_iters),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||||
|
if not await self.get_session_if_accessible(session_id):
|
||||||
|
return None
|
||||||
|
|
||||||
value = await self.kvstore.get(
|
value = await self.kvstore.get(
|
||||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
)
|
)
|
||||||
|
|
|
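Every read path above now funnels through _check_session_access / get_session_if_accessible. A minimal sketch of the attribute-overlap check this enables (illustrative only: the real check_access lives in llama_stack.distribution.access_control and its semantics are not part of this diff):

    # Hypothetical stand-in for check_access; the real implementation is not shown here.
    from typing import Dict, List, Optional

    def check_access(
        resource_id: str,
        stored: Optional[Dict[str, List[str]]],
        caller: Optional[Dict[str, List[str]]],
    ) -> bool:
        if stored is None:
            # sessions created before access control stay readable (backward compatibility)
            return True
        caller = caller or {}
        # every attribute category recorded on the session must overlap the caller's
        return all(set(values) & set(caller.get(key, [])) for key, values in stored.items())

    assert check_access("s1", None, {"roles": ["user"]})                      # legacy session
    assert check_access("s1", {"roles": ["admin"]}, {"roles": ["admin"]})     # attributes match
    assert not check_access("s1", {"roles": ["admin"]}, {"roles": ["user"]})  # access denied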
@@ -35,12 +35,12 @@ class PandasDataframeDataset:
         else:
             return self.df.iloc[idx].to_dict()

-    def load(self) -> None:
+    async def load(self) -> None:
         if self.df is not None:
             return

         if self.dataset_def.source.type == "uri":
-            self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
+            self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
         elif self.dataset_def.source.type == "rows":
             self.df = pandas.DataFrame(self.dataset_def.source.rows)
         else:
@@ -95,7 +95,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
     ) -> IterrowsResponse:
         dataset_def = self.dataset_infos[dataset_id]
         dataset_impl = PandasDataframeDataset(dataset_def)
-        dataset_impl.load()
+        await dataset_impl.load()

         start_index = start_index or 0

@@ -114,7 +114,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
         dataset_impl = PandasDataframeDataset(dataset_def)
-        dataset_impl.load()
+        await dataset_impl.load()

         new_rows_df = pandas.DataFrame(rows)
         dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
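Since load() is now a coroutine, every call site has to await it before touching .df. A small sketch of the lazy-load pattern (LazyRowsDataset is illustrative; only load() and df mirror the class above):

    import asyncio

    import pandas

    class LazyRowsDataset:
        # toy stand-in for PandasDataframeDataset's rows-backed branch
        def __init__(self, rows):
            self._rows = rows
            self.df = None

        async def load(self) -> None:
            if self.df is not None:  # idempotent: repeated awaits are no-ops
                return
            # the "uri" branch would instead do: self.df = await get_dataframe_from_uri(uri)
            self.df = pandas.DataFrame(self._rows)

    async def main():
        ds = LazyRowsDataset([{"question": "1+1?", "answer": "2"}])
        await ds.load()  # callers must now await load() before reading ds.df
        print(ds.df.iloc[0].to_dict())

    asyncio.run(main())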
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List

 from tqdm import tqdm

@@ -20,8 +20,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack.providers.utils.kvstore import kvstore_impl

-from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
+from .....apis.common.job_types import Job, JobStatus
+from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
 from .config import MetaReferenceEvalConfig

 EVAL_TASKS_PREFIX = "benchmarks:"
@@ -101,7 +101,7 @@ class MetaReferenceEvalImpl(
         # need job scheduler queue (ray/celery) w/ jobs api
         job_id = str(len(self.jobs))
         self.jobs[job_id] = res
-        return Job(job_id=job_id)
+        return Job(job_id=job_id, status=JobStatus.completed)

     async def _run_agent_generation(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@@ -215,17 +215,18 @@ class MetaReferenceEvalImpl(

         return EvaluateResponse(generations=generations, scores=score_response.results)

-    async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
         if job_id in self.jobs:
-            return JobStatus.completed
+            return Job(job_id=job_id, status=JobStatus.completed)

-        return None
+        raise ValueError(f"Job {job_id} not found")

     async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")

     async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        status = await self.job_status(benchmark_id, job_id)
+        job = await self.job_status(benchmark_id, job_id)
+        status = job.status
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")
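job_status now returns the Job itself and raises for unknown ids, instead of returning an Optional[JobStatus], so callers read .status off the returned object. A hedged sketch of the calling pattern (method names follow the Eval API used above; the eval_impl wiring is assumed):

    from llama_stack.apis.common.job_types import JobStatus

    async def fetch_eval_result(eval_impl, benchmark_id: str, benchmark_config):
        # run_eval completes synchronously in this in-memory impl and already
        # returns Job(job_id=..., status=JobStatus.completed)
        job = await eval_impl.run_eval(benchmark_id, benchmark_config)
        job = await eval_impl.job_status(benchmark_id, job.job_id)  # raises if unknown
        if job.status == JobStatus.completed:
            return await eval_impl.job_result(benchmark_id, job.job_id)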
@@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                     tool_name=t.function.name,
                     # vLLM function args come back as a string. Llama Stack expects JSON.
                     arguments=json.loads(t.function.arguments),
+                    arguments_json=t.function.arguments,
                 )
                 for t in vllm_message.tool_calls
             ],
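Keeping both fields avoids a lossy round trip: vLLM emits tool-call arguments as a JSON string, Llama Stack consumes the parsed form, and arguments_json preserves the original text verbatim. A toy illustration:

    import json

    raw = '{"city": "Paris", "units": "celsius"}'  # what vLLM returns in t.function.arguments
    arguments = json.loads(raw)                    # parsed form handed to Llama Stack
    arguments_json = raw                           # raw string kept alongside

    assert arguments["city"] == "Paris"
    assert json.loads(arguments_json) == arguments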
@@ -23,7 +23,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (

 from .config import BasicScoringConfig
 from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
+from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
+from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
 from .scoring_fn.regex_parser_math_response_scoring_fn import (
     RegexParserMathResponseScoringFn,
 )
@@ -36,6 +38,8 @@ FIXED_FNS = [
     RegexParserScoringFn,
     RegexParserMathResponseScoringFn,
     BFCLScoringFn,
+    IfEvalScoringFn,
+    DocVQAScoringFn,
 ]

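Once listed in FIXED_FNS, the two new functions register under the basic provider at startup and become addressable as basic::ifeval and basic::docvqa. A hypothetical client call (the client wiring and exact parameter shapes are assumptions, not part of this diff):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server
    response = client.scoring.score(
        input_rows=[{"expected_answer": '["3"]', "generated_answer": "Three!"}],
        scoring_functions={"basic::docvqa": None},  # None -> use the registered defaults
    )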
@@ -0,0 +1,240 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import re
+from typing import Any, Dict, Optional
+
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFnParams
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+
+from .fn_defs.docvqa import docvqa
+
+CONTRACTIONS = {
+    "aint": "ain't",
+    "arent": "aren't",
+    "cant": "can't",
+    "couldve": "could've",
+    "couldnt": "couldn't",
+    "couldn'tve": "couldn't've",
+    "couldnt've": "couldn't've",
+    "didnt": "didn't",
+    "doesnt": "doesn't",
+    "dont": "don't",
+    "hadnt": "hadn't",
+    "hadnt've": "hadn't've",
+    "hadn'tve": "hadn't've",
+    "hasnt": "hasn't",
+    "havent": "haven't",
+    "hed": "he'd",
+    "hed've": "he'd've",
+    "he'dve": "he'd've",
+    "hes": "he's",
+    "howd": "how'd",
+    "howll": "how'll",
+    "hows": "how's",
+    "Id've": "I'd've",
+    "I'dve": "I'd've",
+    "Im": "I'm",
+    "Ive": "I've",
+    "isnt": "isn't",
+    "itd": "it'd",
+    "itd've": "it'd've",
+    "it'dve": "it'd've",
+    "itll": "it'll",
+    "let's": "let's",
+    "maam": "ma'am",
+    "mightnt": "mightn't",
+    "mightnt've": "mightn't've",
+    "mightn'tve": "mightn't've",
+    "mightve": "might've",
+    "mustnt": "mustn't",
+    "mustve": "must've",
+    "neednt": "needn't",
+    "notve": "not've",
+    "oclock": "o'clock",
+    "oughtnt": "oughtn't",
+    "ow's'at": "'ow's'at",
+    "'ows'at": "'ow's'at",
+    "'ow'sat": "'ow's'at",
+    "shant": "shan't",
+    "shed've": "she'd've",
+    "she'dve": "she'd've",
+    "she's": "she's",
+    "shouldve": "should've",
+    "shouldnt": "shouldn't",
+    "shouldnt've": "shouldn't've",
+    "shouldn'tve": "shouldn't've",
+    "somebodyd": "somebody'd",
+    "somebodyd've": "somebody'd've",
+    "somebody'dve": "somebody'd've",
+    "somebodyll": "somebody'll",
+    "somebodys": "somebody's",
+    "someoned": "someone'd",
+    "someoned've": "someone'd've",
+    "someone'dve": "someone'd've",
+    "someonell": "someone'll",
+    "someones": "someone's",
+    "somethingd": "something'd",
+    "somethingd've": "something'd've",
+    "something'dve": "something'd've",
+    "somethingll": "something'll",
+    "thats": "that's",
+    "thered": "there'd",
+    "thered've": "there'd've",
+    "there'dve": "there'd've",
+    "therere": "there're",
+    "theres": "there's",
+    "theyd": "they'd",
+    "theyd've": "they'd've",
+    "they'dve": "they'd've",
+    "theyll": "they'll",
+    "theyre": "they're",
+    "theyve": "they've",
+    "twas": "'twas",
+    "wasnt": "wasn't",
+    "wed've": "we'd've",
+    "we'dve": "we'd've",
+    "weve": "we've",
+    "werent": "weren't",
+    "whatll": "what'll",
+    "whatre": "what're",
+    "whats": "what's",
+    "whatve": "what've",
+    "whens": "when's",
+    "whered": "where'd",
+    "wheres": "where's",
+    "whereve": "where've",
+    "whod": "who'd",
+    "whod've": "who'd've",
+    "who'dve": "who'd've",
+    "wholl": "who'll",
+    "whos": "who's",
+    "whove": "who've",
+    "whyll": "why'll",
+    "whyre": "why're",
+    "whys": "why's",
+    "wont": "won't",
+    "wouldve": "would've",
+    "wouldnt": "wouldn't",
+    "wouldnt've": "wouldn't've",
+    "wouldn'tve": "wouldn't've",
+    "yall": "y'all",
+    "yall'll": "y'all'll",
+    "y'allll": "y'all'll",
+    "yall'd've": "y'all'd've",
+    "y'alld've": "y'all'd've",
+    "y'all'dve": "y'all'd've",
+    "youd": "you'd",
+    "youd've": "you'd've",
+    "you'dve": "you'd've",
+    "youll": "you'll",
+    "youre": "you're",
+    "youve": "you've",
+    "1st": "first",
+    "2nd": "second",
+    "3rd": "third",
+}
+NUMBERS = {
+    "none": "0",
+    "zero": "0",
+    "one": "1",
+    "two": "2",
+    "three": "3",
+    "four": "4",
+    "five": "5",
+    "six": "6",
+    "seven": "7",
+    "eight": "8",
+    "nine": "9",
+    "ten": "10",
+}
+ARTICLES = [
+    "a",
+    "an",
+    "the",
+    "to",
+    "in",
+    "from",
+    "by",
+]  # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
+PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
+COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
+PUNCTUATION = [
+    ";",
+    r"/",
+    "[",
+    "]",
+    '"',
+    "{",
+    "}",
+    "(",
+    ")",
+    "=",
+    "+",
+    "\\",
+    "_",
+    "-",
+    ">",
+    "<",
+    "@",
+    "`",
+    ",",
+    "?",
+    "!",
+]
+
+
+def normalize_answer(s: str) -> str:
+    # process punctuation
+    for p in PUNCTUATION:
+        if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
+            s = s.replace(p, "")
+        else:
+            s = s.replace(p, " ")
+    s = PERIOD_STRIP.sub("", s, re.UNICODE)
+
+    # process digits and articles
+    temp_text = s.lower().split()
+    out_text = []
+    for word in temp_text:
+        word = NUMBERS.setdefault(word, word)
+        if word not in ARTICLES:
+            out_text.append(word)
+
+    # standardize contractions
+    for word_id, word in enumerate(out_text):
+        if word in CONTRACTIONS:
+            out_text[word_id] = CONTRACTIONS[word]
+    return " ".join(out_text)
+
+
+class DocVQAScoringFn(RegisteredBaseScoringFn):
+    """
+    docvqa basically matches the generated answer against several allowed
+    choices, but we need to normalize the answer to avoid penalizing
+    trivial differences
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {
+            docvqa.identifier: docvqa,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "docvqa",
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> ScoringResultRow:
+        expected_answers = json.loads(input_row["expected_answer"])
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
+        return {
+            "score": score,
+        }
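How the normalization behaves on concrete inputs, as a sketch (the module path is inferred from the basic-scoring imports earlier in this commit):

    from llama_stack.providers.inline.scoring.basic.scoring_fn.docvqa_scoring_fn import (
        normalize_answer,
    )

    # "!" becomes a space, lowercasing applies, and "three" maps to "3",
    # so a generated "Three!" matches an expected answer of "3" exactly:
    assert normalize_answer("Three!") == "3"
    assert normalize_answer("3") == "3"

    # articles and stray punctuation are dropped as well:
    assert normalize_answer("The answer is, Three!") == "answer is 3"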
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    ScoringFn,
+)
+
+docvqa = ScoringFn(
+    identifier="basic::docvqa",
+    description="DocVQA Visual Question & Answer scoring function",
+    return_type=NumberType(),
+    provider_id="basic",
+    provider_resource_id="docvqa",
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
+)
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    ScoringFn,
+)
+
+ifeval = ScoringFn(
+    identifier="basic::ifeval",
+    description="Eval instruction-following capacity by checking how many instructions can be followed in each example",
+    return_type=NumberType(),
+    provider_id="basic",
+    provider_resource_id="ifeval",
+    params=BasicScoringFnParams(
+        aggregation_functions=[AggregationFunctionType.weighted_average],
+    ),
+)
@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFnParams
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+
+from .fn_defs.ifeval import (
+    ifeval,
+)
+
+
+class IfEvalScoringFn(RegisteredBaseScoringFn):
+    """
+    A scoring function for the Instruction-Following Eval (IFEval) benchmark
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {
+            ifeval.identifier: ifeval,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> ScoringResultRow:
+        from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
+
+        assert scoring_fn_identifier is not None, "Scoring function identifier not found."
+        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        if scoring_params is not None:
+            fn_def.params = scoring_params
+
+        instruction_list = input_row["instruction_id_list"]
+        generated_answer = input_row["generated_answer"].strip()
+
+        is_following_list = []
+        results = dict(
+            {k + "_correct": 0.0 for k in INSTRUCTION_LIST},
+            **{k + "_total": 0.0 for k in INSTRUCTION_LIST},
+        )
+
+        for index, instruction_id in enumerate(instruction_list):
+            instruction_cls = INSTRUCTION_DICT[instruction_id]
+            instruction = instruction_cls(instruction_id)
+            results[instruction_id + "_total"] += 1.0
+            results[instruction_id.split(":")[0] + "_total"] += 1.0
+
+            clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
+            print(clean_input_row)
+            instruction.build_description(**clean_input_row)
+            args = instruction.get_instruction_args()
+            if args and "prompt" in args:
+                instruction.build_description(prompt=input_row["prompt"])
+
+            if generated_answer and instruction.check_following(generated_answer):
+                is_following_list.append(True)
+                results[instruction_id + "_correct"] += 1.0
+                results[instruction_id.split(":")[0] + "_correct"] += 1.0
+            else:
+                is_following_list.append(False)
+
+        if len(is_following_list) == 0:
+            return {
+                "score": 0.0,
+                "weight": 0.0,
+            }
+
+        return {
+            "score": float(sum(is_following_list)) / float(len(is_following_list)),
+            "weight": float(len(is_following_list)),
+        }
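The per-row contract of score_row, sketched with a hypothetical row (the instruction id and kwargs follow the IFEval registry in utils/ifeval_utils.py, whose diff is suppressed below, so treat them as assumptions):

    # hypothetical IFEval row; None-valued kwargs are filtered out by clean_input_row
    input_row = {
        "prompt": "Describe the sea in at most 20 words.",
        "generated_answer": "Blue water stretching quietly to the horizon.",
        "instruction_id_list": ["length_constraints:number_words"],
        "kwargs": [{"relation": "at most", "num_words": 20, "section_spliter": None}],
    }
    # per the code above, score_row(input_row) resolves to:
    #   {"score": <fraction of instructions followed>, "weight": <number of instructions>}
    # and basic::ifeval aggregates rows with the weighted_average function from its fn_def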
3319
llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
Normal file
File diff suppressed because it is too large
@@ -6,12 +6,14 @@

 from typing import Any, Dict

+from llama_stack.distribution.datatypes import Api
+
 from .config import TelemetryConfig, TelemetrySink

 __all__ = ["TelemetryConfig", "TelemetrySink"]


-async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
     from .telemetry import TelemetryAdapter

     impl = TelemetryAdapter(config, deps)
@@ -13,19 +13,20 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR


 class TelemetrySink(str, Enum):
-    OTEL = "otel"
+    OTEL_TRACE = "otel_trace"
+    OTEL_METRIC = "otel_metric"
     SQLITE = "sqlite"
     CONSOLE = "console"


 class TelemetryConfig(BaseModel):
-    otel_endpoint: str = Field(
+    otel_trace_endpoint: str = Field(
         default="http://localhost:4318/v1/traces",
-        description="The OpenTelemetry collector endpoint URL",
+        description="The OpenTelemetry collector endpoint URL for traces",
     )
-    service_name: str = Field(
-        default="llama-stack",
-        description="The service name to use for telemetry",
+    otel_metric_endpoint: str = Field(
+        default="http://localhost:4318/v1/metrics",
+        description="The OpenTelemetry collector endpoint URL for metrics",
     )
     sinks: List[TelemetrySink] = Field(
         default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
@@ -46,7 +47,6 @@ class TelemetryConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
         return {
-            "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
             "sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
             "sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
         }
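With the endpoint split, trace and metric exports are configured independently. A sketch of building the new config in code (the module path is assumed from the provider layout; field names and defaults come from the diff above):

    from llama_stack.providers.inline.telemetry.meta_reference.config import (
        TelemetryConfig,
        TelemetrySink,
    )

    config = TelemetryConfig(
        otel_trace_endpoint="http://localhost:4318/v1/traces",    # default for traces
        otel_metric_endpoint="http://localhost:4318/v1/metrics",  # default for metrics
        sinks=[TelemetrySink.OTEL_TRACE, TelemetrySink.OTEL_METRIC, TelemetrySink.SQLITE],
    )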
Some files were not shown because too many files have changed in this diff