mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
Merge branch 'main' into feat/litellm_sambanova_usage
This commit is contained in:
commit
8783dd8162
190 changed files with 8649 additions and 3304 deletions
5
.github/workflows/integration-tests.yml
vendored
5
.github/workflows/integration-tests.yml
vendored
|
@ -14,6 +14,10 @@ on:
|
|||
- 'requirements.txt'
|
||||
- '.github/workflows/integration-tests.yml' # This workflow
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test-matrix:
|
||||
runs-on: ubuntu-latest
|
||||
|
@ -52,6 +56,7 @@ jobs:
|
|||
# always test against the latest version of the client
|
||||
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
|
||||
uv pip install -e .
|
||||
llama stack build --template ollama --image-type venv
|
||||
|
||||
- name: Wait for Ollama to start
|
||||
run: |
|
||||
|
|
4
.github/workflows/pre-commit.yml
vendored
4
.github/workflows/pre-commit.yml
vendored
|
@ -5,6 +5,10 @@ on:
|
|||
push:
|
||||
branches: [main]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
|
|
4
.github/workflows/providers-build.yml
vendored
4
.github/workflows/providers-build.yml
vendored
|
@ -18,6 +18,10 @@ on:
|
|||
- 'llama_stack/distribution/*.sh'
|
||||
- '.github/workflows/providers-build.yml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
generate-matrix:
|
||||
runs-on: ubuntu-latest
|
||||
|
|
4
.github/workflows/semantic-pr.yml
vendored
4
.github/workflows/semantic-pr.yml
vendored
|
@ -8,6 +8,10 @@ on:
|
|||
- reopened
|
||||
- synchronize
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
|
|
4
.github/workflows/unit-tests.yml
vendored
4
.github/workflows/unit-tests.yml
vendored
|
@ -15,6 +15,10 @@ on:
|
|||
- '.github/workflows/unit-tests.yml' # This workflow
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
unit-tests:
|
||||
runs-on: ubuntu-latest
|
||||
|
|
4
.github/workflows/update-readthedocs.yml
vendored
4
.github/workflows/update-readthedocs.yml
vendored
|
@ -22,6 +22,10 @@ on:
|
|||
- 'pyproject.toml'
|
||||
- '.github/workflows/update-readthedocs.yml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
update-readthedocs:
|
||||
runs-on: ubuntu-latest
|
||||
|
|
|
@ -89,10 +89,11 @@ repos:
|
|||
name: API Spec Codegen
|
||||
additional_dependencies:
|
||||
- uv==0.6.2
|
||||
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null 2>&1'
|
||||
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
files: ^llama_stack/apis/|^docs/openapi_generator/
|
||||
|
||||
ci:
|
||||
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
|
||||
|
|
|
@ -135,9 +135,11 @@ uv sync
|
|||
|
||||
## Coding Style
|
||||
|
||||
* Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter, same goes for docstrings.
|
||||
* Prefer comments to clarify surprising behavior and/or relationships between parts of the code rather than explain what the next line of code does.
|
||||
* Catching exceptions, prefer using a specific exception type rather than a broad catch-all like `Exception`.
|
||||
* Error messages should be prefixed with "Failed to ..."
|
||||
* 4 spaces for indentation rather than tabs
|
||||
* 80 character line length
|
||||
* ...
|
||||
|
||||
## Common Tasks
|
||||
|
||||
|
@ -166,7 +168,7 @@ If you have made changes to a provider's configuration in any form (introducing
|
|||
If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.
|
||||
|
||||
```bash
|
||||
cd llama-stack/docs
|
||||
cd docs
|
||||
uv sync --extra docs
|
||||
|
||||
# This rebuilds the documentation pages.
|
||||
|
|
|
@ -7,10 +7,12 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -23,6 +25,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -41,10 +44,12 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
|
@ -56,6 +61,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -75,10 +81,12 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"fireworks-ai",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -91,6 +99,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -112,11 +121,13 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"huggingface_hub",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
|
@ -128,6 +139,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -147,10 +159,12 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"fireworks-ai",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"litellm",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
|
@ -164,6 +178,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -184,11 +199,13 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"fireworks-ai",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -201,6 +218,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -219,10 +237,12 @@
|
|||
"blobfile",
|
||||
"chardet",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"litellm",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
|
@ -235,6 +255,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -253,11 +274,13 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"huggingface_hub",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -270,6 +293,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -288,11 +312,13 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"huggingface_hub",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -305,6 +331,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -325,11 +352,13 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"fairscale",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"lm-format-enforcer",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
|
@ -343,6 +372,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -365,12 +395,14 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"fairscale",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fbgemm-gpu",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"lm-format-enforcer",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
|
@ -384,6 +416,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -403,10 +436,12 @@
|
|||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
|
@ -418,6 +453,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -436,10 +472,12 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -453,6 +491,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -470,9 +509,11 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"litellm",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
|
@ -486,6 +527,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -505,10 +547,12 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -521,6 +565,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -540,10 +585,12 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -556,6 +603,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -608,11 +656,13 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"huggingface_hub",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -625,6 +675,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -644,10 +695,12 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -660,6 +713,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
@ -680,10 +734,12 @@
|
|||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
|
@ -696,6 +752,7 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
|
|
@ -51,14 +51,14 @@ services:
|
|||
- ~/local/llama-stack/:/app/llama-stack-source
|
||||
- ./run${SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
|
||||
ports:
|
||||
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
|
||||
- "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
|
||||
environment:
|
||||
- INFERENCE_MODEL=${INFERENCE_MODEL}
|
||||
- SAFETY_MODEL=${SAFETY_MODEL:-}
|
||||
- OLLAMA_URL=http://ollama:11434
|
||||
entrypoint: >
|
||||
python -m llama_stack.distribution.server.server /root/my-run.yaml \
|
||||
--port ${LLAMA_STACK_PORT:-5001}
|
||||
--port ${LLAMA_STACK_PORT:-8321}
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
|
|
BIN
distributions/ramalama/faiss_store.db
Normal file
BIN
distributions/ramalama/faiss_store.db
Normal file
Binary file not shown.
|
@ -84,9 +84,9 @@ services:
|
|||
- SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
|
||||
- SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
|
||||
ports:
|
||||
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
|
||||
- "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
|
||||
# Hack: wait for vLLM server to start before starting docker
|
||||
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 5001"
|
||||
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 8321"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
|
|
|
@ -83,7 +83,7 @@ services:
|
|||
- ~/.llama:/root/.llama
|
||||
- ./run${TGI_SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
|
||||
ports:
|
||||
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
|
||||
- "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
|
||||
# Hack: wait for TGI server to start before starting docker
|
||||
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
|
||||
restart_policy:
|
||||
|
|
88
docs/_static/llama-stack-spec.html
vendored
88
docs/_static/llama-stack-spec.html
vendored
|
@ -2183,7 +2183,7 @@
|
|||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/JobStatus"
|
||||
"$ref": "#/components/schemas/Job"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2650,7 +2650,7 @@
|
|||
}
|
||||
},
|
||||
"tags": [
|
||||
"Inspect"
|
||||
"Providers"
|
||||
],
|
||||
"description": "",
|
||||
"parameters": []
|
||||
|
@ -6268,6 +6268,7 @@
|
|||
"type": "string",
|
||||
"enum": [
|
||||
"average",
|
||||
"weighted_average",
|
||||
"median",
|
||||
"categorical_count",
|
||||
"accuracy"
|
||||
|
@ -7647,7 +7648,13 @@
|
|||
"title": "PostTrainingJobArtifactsResponse",
|
||||
"description": "Artifacts of a finetuning job."
|
||||
},
|
||||
"JobStatus": {
|
||||
"PostTrainingJobStatusResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_uuid": {
|
||||
"type": "string"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"completed",
|
||||
|
@ -7657,15 +7664,6 @@
|
|||
],
|
||||
"title": "JobStatus"
|
||||
},
|
||||
"PostTrainingJobStatusResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_uuid": {
|
||||
"type": "string"
|
||||
},
|
||||
"status": {
|
||||
"$ref": "#/components/schemas/JobStatus"
|
||||
},
|
||||
"scheduled_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
|
@ -8068,9 +8066,6 @@
|
|||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"content"
|
||||
],
|
||||
"title": "ToolInvocationResult"
|
||||
},
|
||||
"IterrowsResponse": {
|
||||
|
@ -8117,6 +8112,30 @@
|
|||
"title": "IterrowsResponse",
|
||||
"description": "A paginated list of rows from a dataset."
|
||||
},
|
||||
"Job": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"completed",
|
||||
"in_progress",
|
||||
"failed",
|
||||
"scheduled"
|
||||
],
|
||||
"title": "JobStatus"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"job_id",
|
||||
"status"
|
||||
],
|
||||
"title": "Job"
|
||||
},
|
||||
"ListAgentSessionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -9641,19 +9660,6 @@
|
|||
],
|
||||
"title": "RunEvalRequest"
|
||||
},
|
||||
"Job": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_id": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"job_id"
|
||||
],
|
||||
"title": "Job"
|
||||
},
|
||||
"RunShieldRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -9862,6 +9868,23 @@
|
|||
],
|
||||
"title": "ScoreBatchResponse"
|
||||
},
|
||||
"AlgorithmConfig": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/LoraFinetuningConfig"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/QATFinetuningConfig"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"LoRA": "#/components/schemas/LoraFinetuningConfig",
|
||||
"QAT": "#/components/schemas/QATFinetuningConfig"
|
||||
}
|
||||
}
|
||||
},
|
||||
"LoraFinetuningConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -9997,14 +10020,7 @@
|
|||
"type": "string"
|
||||
},
|
||||
"algorithm_config": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/LoraFinetuningConfig"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/QATFinetuningConfig"
|
||||
}
|
||||
]
|
||||
"$ref": "#/components/schemas/AlgorithmConfig"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
|
61
docs/_static/llama-stack-spec.yaml
vendored
61
docs/_static/llama-stack-spec.yaml
vendored
|
@ -1491,7 +1491,7 @@ paths:
|
|||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/JobStatus'
|
||||
$ref: '#/components/schemas/Job'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
|
@ -1814,7 +1814,7 @@ paths:
|
|||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Inspect
|
||||
- Providers
|
||||
description: ''
|
||||
parameters: []
|
||||
/v1/inspect/routes:
|
||||
|
@ -4389,6 +4389,7 @@ components:
|
|||
type: string
|
||||
enum:
|
||||
- average
|
||||
- weighted_average
|
||||
- median
|
||||
- categorical_count
|
||||
- accuracy
|
||||
|
@ -5276,7 +5277,12 @@ components:
|
|||
- checkpoints
|
||||
title: PostTrainingJobArtifactsResponse
|
||||
description: Artifacts of a finetuning job.
|
||||
JobStatus:
|
||||
PostTrainingJobStatusResponse:
|
||||
type: object
|
||||
properties:
|
||||
job_uuid:
|
||||
type: string
|
||||
status:
|
||||
type: string
|
||||
enum:
|
||||
- completed
|
||||
|
@ -5284,13 +5290,6 @@ components:
|
|||
- failed
|
||||
- scheduled
|
||||
title: JobStatus
|
||||
PostTrainingJobStatusResponse:
|
||||
type: object
|
||||
properties:
|
||||
job_uuid:
|
||||
type: string
|
||||
status:
|
||||
$ref: '#/components/schemas/JobStatus'
|
||||
scheduled_at:
|
||||
type: string
|
||||
format: date-time
|
||||
|
@ -5528,8 +5527,6 @@ components:
|
|||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- content
|
||||
title: ToolInvocationResult
|
||||
IterrowsResponse:
|
||||
type: object
|
||||
|
@ -5557,6 +5554,24 @@ components:
|
|||
- data
|
||||
title: IterrowsResponse
|
||||
description: A paginated list of rows from a dataset.
|
||||
Job:
|
||||
type: object
|
||||
properties:
|
||||
job_id:
|
||||
type: string
|
||||
status:
|
||||
type: string
|
||||
enum:
|
||||
- completed
|
||||
- in_progress
|
||||
- failed
|
||||
- scheduled
|
||||
title: JobStatus
|
||||
additionalProperties: false
|
||||
required:
|
||||
- job_id
|
||||
- status
|
||||
title: Job
|
||||
ListAgentSessionsResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6551,15 +6566,6 @@ components:
|
|||
required:
|
||||
- benchmark_config
|
||||
title: RunEvalRequest
|
||||
Job:
|
||||
type: object
|
||||
properties:
|
||||
job_id:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- job_id
|
||||
title: Job
|
||||
RunShieldRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6688,6 +6694,15 @@ components:
|
|||
required:
|
||||
- results
|
||||
title: ScoreBatchResponse
|
||||
AlgorithmConfig:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/LoraFinetuningConfig'
|
||||
- $ref: '#/components/schemas/QATFinetuningConfig'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
LoRA: '#/components/schemas/LoraFinetuningConfig'
|
||||
QAT: '#/components/schemas/QATFinetuningConfig'
|
||||
LoraFinetuningConfig:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6771,9 +6786,7 @@ components:
|
|||
checkpoint_dir:
|
||||
type: string
|
||||
algorithm_config:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/LoraFinetuningConfig'
|
||||
- $ref: '#/components/schemas/QATFinetuningConfig'
|
||||
$ref: '#/components/schemas/AlgorithmConfig'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- job_uuid
|
||||
|
|
|
@ -4,6 +4,21 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
for item in items:
|
||||
item.name = item.name.replace(' ', '_')
|
||||
|
||||
|
||||
def pytest_runtest_teardown(item):
|
||||
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
|
||||
if interval_seconds:
|
||||
time.sleep(float(interval_seconds))
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.option.tbstyle = "short"
|
||||
config.option.disable_warnings = True
|
||||
|
|
|
@ -123,6 +123,8 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# NBVAL_SKIP\n",
|
||||
"!pip uninstall pandas numpy -y\n",
|
||||
"!pip install pandas numpy\n",
|
||||
"# This will build all the dependencies you will need\n",
|
||||
"!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv"
|
||||
]
|
||||
|
@ -1203,7 +1205,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import InferenceEventLogger\n",
|
||||
"\n",
|
||||
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
|
||||
"print(f'User> {message[\"content\"]}', \"green\")\n",
|
||||
|
@ -1215,7 +1217,7 @@
|
|||
")\n",
|
||||
"\n",
|
||||
"# Print the tokens while they are received\n",
|
||||
"for log in EventLogger().log(response):\n",
|
||||
"for log in InferenceEventLogger().log(response):\n",
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
|
@ -1632,8 +1634,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import Agent, AgentEventLogger\n",
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"agent = Agent(\n",
|
||||
|
@ -1659,7 +1660,7 @@
|
|||
" ],\n",
|
||||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
|
@ -1808,14 +1809,12 @@
|
|||
],
|
||||
"source": [
|
||||
"import uuid\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import Agent, AgentEventLogger, RAGDocument\n",
|
||||
"from termcolor import cprint\n",
|
||||
"from llama_stack_client.types import Document\n",
|
||||
"\n",
|
||||
"urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n",
|
||||
"documents = [\n",
|
||||
" Document(\n",
|
||||
" RAGDocument(\n",
|
||||
" document_id=f\"num-{i}\",\n",
|
||||
" content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
|
||||
" mime_type=\"text/plain\",\n",
|
||||
|
@ -1858,7 +1857,7 @@
|
|||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
]
|
||||
},
|
||||
|
@ -1969,7 +1968,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.types.agents.turn_create_params import Document\n",
|
||||
"from llama_stack_client import Document\n",
|
||||
"\n",
|
||||
"codex_agent = Agent(\n",
|
||||
" client, \n",
|
||||
|
@ -2013,7 +2012,7 @@
|
|||
" # for chunk in response:\n",
|
||||
" # print(chunk)\n",
|
||||
"\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
|
@ -2891,8 +2890,7 @@
|
|||
],
|
||||
"source": [
|
||||
"# NBVAL_SKIP\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import Agent, AgentEventLogger\n",
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"agent = Agent(\n",
|
||||
|
@ -2918,7 +2916,7 @@
|
|||
" ],\n",
|
||||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
|
@ -2993,8 +2991,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import Agent, AgentEventLogger\n",
|
||||
"\n",
|
||||
"agent = Agent(\n",
|
||||
" client, \n",
|
||||
|
@ -3021,7 +3018,7 @@
|
|||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
|
@ -4355,7 +4352,7 @@
|
|||
" session_id=session_id,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for log in EventLogger().log(response):\n",
|
||||
"for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()\n",
|
||||
" "
|
||||
]
|
||||
|
|
|
@ -47,9 +47,8 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"from llama_stack_client import LlamaStackClient, Agent\n",
|
||||
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from rich.pretty import pprint\n",
|
||||
"import json\n",
|
||||
"import uuid\n",
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -34,10 +34,8 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"from llama_stack_client import LlamaStackClient, Agent\n",
|
||||
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from rich.pretty import pprint\n",
|
||||
"import json\n",
|
||||
"import uuid\n",
|
||||
|
@ -70,7 +68,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1397,6 +1395,349 @@
|
|||
"pprint(session_response.turns)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 3.1 Improved RAG with Long Context\n",
|
||||
"\n",
|
||||
"- Instead of performing reteival tool, we send documents as attachments to the agent and let it use the entire document context. \n",
|
||||
"- Note how that the model is able to understand the entire context from documentation and answers the question with better factuality with improved retrieval. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> What precision formats does torchtune support?\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u001b[1;36mQuestion:\u001b[0m What precision formats does torchtune support?\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> Torchtune supports two precision formats: `fp32` <span style=\"font-weight: bold\">(</span>full-precision<span style=\"font-weight: bold\">)</span> and `bfloat16` <span style=\"font-weight: bold\">(</span>half-precision<span style=\"font-weight: bold\">)</span>. \n",
|
||||
"The `bfloat16` format uses <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> bytes per model parameter, which is half the memory of `fp32`, and also improves \n",
|
||||
"training speed.\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u001b[1;33mAgent Answer:\u001b[0m Torchtune supports two precision formats: `fp32` \u001b[1m(\u001b[0mfull-precision\u001b[1m)\u001b[0m and `bfloat16` \u001b[1m(\u001b[0mhalf-precision\u001b[1m)\u001b[0m. \n",
|
||||
"The `bfloat16` format uses \u001b[1;36m2\u001b[0m bytes per model parameter, which is half the memory of `fp32`, and also improves \n",
|
||||
"training speed.\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> What does DoRA stand for in torchtune?\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u001b[1;36mQuestion:\u001b[0m What does DoRA stand for in torchtune?\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> DoRA stands for Weight-Decomposed Low-Rank Adaptation. It is a variant of LoRA <span style=\"font-weight: bold\">(</span>Low-Rank Adaptation<span style=\"font-weight: bold\">)</span> \n",
|
||||
"that further decomposes the pre-trained weights into two components: magnitude and direction. The magnitude \n",
|
||||
"component is a scalar vector that adjusts the scale, while the direction component corresponds to the original LoRA\n",
|
||||
"decomposition and updates the orientation of weights. DoRA adds a small overhead to LoRA training due to the \n",
|
||||
"addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low \n",
|
||||
"ranks.\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u001b[1;33mAgent Answer:\u001b[0m DoRA stands for Weight-Decomposed Low-Rank Adaptation. It is a variant of LoRA \u001b[1m(\u001b[0mLow-Rank Adaptation\u001b[1m)\u001b[0m \n",
|
||||
"that further decomposes the pre-trained weights into two components: magnitude and direction. The magnitude \n",
|
||||
"component is a scalar vector that adjusts the scale, while the direction component corresponds to the original LoRA\n",
|
||||
"decomposition and updates the orientation of weights. DoRA adds a small overhead to LoRA training due to the \n",
|
||||
"addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low \n",
|
||||
"ranks.\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> How does the CPUOffloadOptimizer reduce GPU memory usage?\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u001b[1;36mQuestion:\u001b[0m How does the CPUOffloadOptimizer reduce GPU memory usage?\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> The CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the \n",
|
||||
"CPU, and performing optimizer steps on the CPU. This can significantly reduce GPU memory usage at the cost of CPU \n",
|
||||
"RAM and training speed.\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u001b[1;33mAgent Answer:\u001b[0m The CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the \n",
|
||||
"CPU, and performing optimizer steps on the CPU. This can significantly reduce GPU memory usage at the cost of CPU \n",
|
||||
"RAM and training speed.\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Question:</span> How do I ensure only LoRA parameters are trainable when fine-tuning?\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u001b[1;36mQuestion:\u001b[0m How do I ensure only LoRA parameters are trainable when fine-tuning?\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Agent Answer:</span> To ensure only LoRA parameters are trainable when fine-tuning, you can use the `set_trainable_params`\n",
|
||||
"function from `torchtune.modules.peft.peft_utils` to set the `requires_grad` attribute of the LoRA parameters to \n",
|
||||
"`<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-style: italic\">True</span>` and the `requires_grad` attribute of the other parameters to `<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>`.\n",
|
||||
"\n",
|
||||
"Here is an example:\n",
|
||||
"```python\n",
|
||||
"from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n",
|
||||
"\n",
|
||||
"# Get the LoRA parameters\n",
|
||||
"lora_params = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">get_adapter_params</span><span style=\"font-weight: bold\">(</span>model<span style=\"font-weight: bold\">)</span>\n",
|
||||
"\n",
|
||||
"# Set the LoRA parameters to trainable and the other parameters to non-trainable\n",
|
||||
"<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">set_trainable_params</span><span style=\"font-weight: bold\">(</span>model, lora_params<span style=\"font-weight: bold\">)</span>\n",
|
||||
"```\n",
|
||||
"This will ensure that only the LoRA parameters are updated during fine-tuning, while the other parameters remain \n",
|
||||
"frozen.\n",
|
||||
"\n",
|
||||
"Alternatively, you can also use the `lora_finetune` recipe in torchtune, which automatically sets the LoRA \n",
|
||||
"parameters to trainable and the other parameters to non-trainable. You can run the recipe using the following \n",
|
||||
"command:\n",
|
||||
"```bash\n",
|
||||
"tune run lora_finetune --config llama2/7B_lora\n",
|
||||
"```\n",
|
||||
"This will fine-tune the LoRA parameters of the Llama2 model using the default settings. You can modify the config \n",
|
||||
"file to change the hyperparameters or the model architecture.\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u001b[1;33mAgent Answer:\u001b[0m To ensure only LoRA parameters are trainable when fine-tuning, you can use the `set_trainable_params`\n",
|
||||
"function from `torchtune.modules.peft.peft_utils` to set the `requires_grad` attribute of the LoRA parameters to \n",
|
||||
"`\u001b[3;92mTrue\u001b[0m` and the `requires_grad` attribute of the other parameters to `\u001b[3;91mFalse\u001b[0m`.\n",
|
||||
"\n",
|
||||
"Here is an example:\n",
|
||||
"```python\n",
|
||||
"from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params\n",
|
||||
"\n",
|
||||
"# Get the LoRA parameters\n",
|
||||
"lora_params = \u001b[1;35mget_adapter_params\u001b[0m\u001b[1m(\u001b[0mmodel\u001b[1m)\u001b[0m\n",
|
||||
"\n",
|
||||
"# Set the LoRA parameters to trainable and the other parameters to non-trainable\n",
|
||||
"\u001b[1;35mset_trainable_params\u001b[0m\u001b[1m(\u001b[0mmodel, lora_params\u001b[1m)\u001b[0m\n",
|
||||
"```\n",
|
||||
"This will ensure that only the LoRA parameters are updated during fine-tuning, while the other parameters remain \n",
|
||||
"frozen.\n",
|
||||
"\n",
|
||||
"Alternatively, you can also use the `lora_finetune` recipe in torchtune, which automatically sets the LoRA \n",
|
||||
"parameters to trainable and the other parameters to non-trainable. You can run the recipe using the following \n",
|
||||
"command:\n",
|
||||
"```bash\n",
|
||||
"tune run lora_finetune --config llama2/7B_lora\n",
|
||||
"```\n",
|
||||
"This will fine-tune the LoRA parameters of the Llama2 model using the default settings. You can modify the config \n",
|
||||
"file to change the hyperparameters or the model architecture.\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"urls = [\n",
|
||||
" \"memory_optimizations.rst\",\n",
|
||||
" \"chat.rst\",\n",
|
||||
" \"llama3.rst\",\n",
|
||||
" \"datasets.rst\",\n",
|
||||
" \"qat_finetune.rst\",\n",
|
||||
" \"lora_finetune.rst\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"attachments = [\n",
|
||||
" {\n",
|
||||
" \"content\": {\n",
|
||||
" \"uri\": f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
|
||||
" },\n",
|
||||
" \"mime_type\": \"text/plain\",\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" for i, url in enumerate(urls)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"rag_attachment_agent = Agent(\n",
|
||||
" client,\n",
|
||||
" model=MODEL_ID,\n",
|
||||
" instructions=\"You are a helpful assistant that can answer questions about the Torchtune project. Use context from attached documentation for Torchtune to answer questions.\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for example in examples:\n",
|
||||
" session_id = rag_attachment_agent.create_session(session_name=f\"rag_attachment_session_{uuid.uuid4()}\")\n",
|
||||
" response = rag_attachment_agent.create_turn(\n",
|
||||
" messages=[\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": example[\"input_query\"]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" session_id=session_id,\n",
|
||||
" documents=attachments,\n",
|
||||
" stream=False\n",
|
||||
" )\n",
|
||||
" rich.print(f\"[bold cyan]Question:[/bold cyan] {example['input_query']}\")\n",
|
||||
" rich.print(f\"[bold yellow]Agent Answer:[/bold yellow] {response.output_message.content}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringScoreResponse</span><span style=\"font-weight: bold\">(</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">results</span>=<span style=\"font-weight: bold\">{</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'braintrust::factuality'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ScoringResult</span><span style=\"font-weight: bold\">(</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">aggregated_results</span>=<span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'average'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'average'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span><span style=\"font-weight: bold\">}}</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">score_rows</span>=<span style=\"font-weight: bold\">[</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. Both the expert and the submitted answers mention that Torchtune supports two precision formats: `fp32` (full-precision) and `bfloat16` (half-precision).\\n2. The expert answer specifies that `fp32` uses 4 bytes per model and optimizer parameter, while `bfloat16` uses 2 bytes per model and optimizer parameter.\\n3. The submitted answer also mentions that `bfloat16` uses 2 bytes per model parameter, which is consistent with the expert answer.\\n4. The submitted answer adds that `bfloat16` improves training speed, which is additional information not present in the expert answer.\\n5. There is no conflict between the submitted answer and the expert answer; the submitted answer simply provides more information.\\n\\nBased on this analysis, the submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. The expert answer provides the definition of DoRA as \"Weight-Decomposed Low-Rank Adaptation.\"\\n2. The submitted answer also states that DoRA stands for \"Weight-Decomposed Low-Rank Adaptation,\" which matches the expert answer.\\n3. The submitted answer includes additional information about DoRA, explaining that it is a variant of LoRA and describing how it decomposes pre-trained weights into magnitude and direction components.\\n4. The submitted answer further explains the role of the magnitude component and the direction component, and mentions the performance improvement and overhead associated with DoRA.\\n5. The additional details in the submitted answer do not contradict the expert answer; instead, they expand upon it.\\n6. Therefore, the submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'1. The expert answer states that the CPUOffloadOptimizer reduces GPU memory usage by keeping optimizer states on CPU and performing optimizer steps on CPU. It also mentions the optional offloading of gradients to CPU with the parameter offload_gradients=True.\\n\\n2. The submitted answer states that the CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the CPU, and performing optimizer steps on the CPU. It adds that this can significantly reduce GPU memory usage at the cost of CPU RAM and training speed.\\n\\n3. Comparing both answers:\\n - Both answers agree on offloading optimizer states to the CPU and performing optimizer steps on the CPU.\\n - Both mention the offloading of gradients to the CPU, but the expert answer specifies it as optional with a parameter, while the submission does not specify this detail.\\n - The submission adds additional information about the trade-off involving CPU RAM and training speed, which is not mentioned in the expert answer.\\n\\n4. The submitted answer includes all the details from the expert answer and adds more information about the trade-offs, making it a superset of the expert answer.\\n\\nTherefore, the correct choice is (B) The submitted answer is a superset of the expert answer and is fully consistent with it.'</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">{</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'score'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'metadata'</span>: <span style=\"font-weight: bold\">{</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'choice'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'B'</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ │ </span><span style=\"color: #008000; text-decoration-color: #008000\">'rationale'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"1. **Expert Answer Analysis**: The expert answer provides a method to ensure only LoRA parameters are trainable by using torchtune's utility functions. It mentions fetching LoRA parameters with `get_adapter_params(lora_model)` and setting them as trainable with `set_trainable_params(lora_model, lora_params)`. It also notes that the LoRA recipe handles this automatically.\\n\\n2. **Submitted Answer Analysis**: The submitted answer provides a similar method using `set_trainable_params` to set the `requires_grad` attribute of LoRA parameters to `True` and other parameters to `False`. It includes a code example demonstrating this process. Additionally, it mentions using the `lora_finetune` recipe in torchtune, which automatically sets the LoRA parameters to trainable.\\n\\n3. **Comparison**: The submitted answer includes all the details from the expert answer regarding the use of `get_adapter_params` and `set_trainable_params`. It also provides additional information about setting the `requires_grad` attribute and using the `lora_finetune` recipe, which is not mentioned in the expert answer.\\n\\n4. **Conclusion**: The submitted answer is a superset of the expert answer as it contains all the information from the expert answer and additional details. There is no conflict between the two answers, and the additional information in the submission is consistent with the expert's explanation.\\n\\nTherefore, the correct choice is (B) The submitted answer is a superset of the expert answer and is fully consistent with it.\"</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ │ </span><span style=\"font-weight: bold\">}</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ │ </span><span style=\"font-weight: bold\">]</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ │ </span><span style=\"font-weight: bold\">)</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"font-weight: bold\">}</span>\n",
|
||||
"<span style=\"font-weight: bold\">)</span>\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u001b[1;35mScoringScoreResponse\u001b[0m\u001b[1m(\u001b[0m\n",
|
||||
"\u001b[2;32m│ \u001b[0m\u001b[33mresults\u001b[0m=\u001b[1m{\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ \u001b[0m\u001b[32m'braintrust::factuality'\u001b[0m: \u001b[1;35mScoringResult\u001b[0m\u001b[1m(\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33maggregated_results\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'average'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'average'\u001b[0m: \u001b[1;36m0.6\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ \u001b[0m\u001b[33mscore_rows\u001b[0m=\u001b[1m[\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. Both the expert and the submitted answers mention that Torchtune supports two precision formats: `fp32` \u001b[0m\u001b[32m(\u001b[0m\u001b[32mfull-precision\u001b[0m\u001b[32m)\u001b[0m\u001b[32m and `bfloat16` \u001b[0m\u001b[32m(\u001b[0m\u001b[32mhalf-precision\u001b[0m\u001b[32m)\u001b[0m\u001b[32m.\\n2. The expert answer specifies that `fp32` uses 4 bytes per model and optimizer parameter, while `bfloat16` uses 2 bytes per model and optimizer parameter.\\n3. The submitted answer also mentions that `bfloat16` uses 2 bytes per model parameter, which is consistent with the expert answer.\\n4. The submitted answer adds that `bfloat16` improves training speed, which is additional information not present in the expert answer.\\n5. There is no conflict between the submitted answer and the expert answer; the submitted answer simply provides more information.\\n\\nBased on this analysis, the submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. The expert answer provides the definition of DoRA as \"Weight-Decomposed Low-Rank Adaptation.\"\\n2. The submitted answer also states that DoRA stands for \"Weight-Decomposed Low-Rank Adaptation,\" which matches the expert answer.\\n3. The submitted answer includes additional information about DoRA, explaining that it is a variant of LoRA and describing how it decomposes pre-trained weights into magnitude and direction components.\\n4. The submitted answer further explains the role of the magnitude component and the direction component, and mentions the performance improvement and overhead associated with DoRA.\\n5. The additional details in the submitted answer do not contradict the expert answer; instead, they expand upon it.\\n6. Therefore, the submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m'1. The expert answer states that the CPUOffloadOptimizer reduces GPU memory usage by keeping optimizer states on CPU and performing optimizer steps on CPU. It also mentions the optional offloading of gradients to CPU with the parameter \u001b[0m\u001b[32moffload_gradients\u001b[0m\u001b[32m=\u001b[0m\u001b[32mTrue\u001b[0m\u001b[32m.\\n\\n2. The submitted answer states that the CPUOffloadOptimizer reduces GPU memory usage by offloading optimizer states and gradients to the CPU, and performing optimizer steps on the CPU. It adds that this can significantly reduce GPU memory usage at the cost of CPU RAM and training speed.\\n\\n3. Comparing both answers:\\n - Both answers agree on offloading optimizer states to the CPU and performing optimizer steps on the CPU.\\n - Both mention the offloading of gradients to the CPU, but the expert answer specifies it as optional with a parameter, while the submission does not specify this detail.\\n - The submission adds additional information about the trade-off involving CPU RAM and training speed, which is not mentioned in the expert answer.\\n\\n4. The submitted answer includes all the details from the expert answer and adds more information about the trade-offs, making it a superset of the expert answer.\\n\\nTherefore, the correct choice is \u001b[0m\u001b[32m(\u001b[0m\u001b[32mB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m The submitted answer is a superset of the expert answer and is fully consistent with it.'\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m{\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'score'\u001b[0m: \u001b[1;36m0.6\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[32m'metadata'\u001b[0m: \u001b[1m{\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'choice'\u001b[0m: \u001b[32m'B'\u001b[0m,\n",
|
||||
"\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[32m'rationale'\u001b[0m: \u001b[32m\"1. **Expert Answer Analysis**: The expert answer provides a method to ensure only LoRA parameters are trainable by using torchtune's utility functions. It mentions fetching LoRA parameters with `get_adapter_params\u001b[0m\u001b[32m(\u001b[0m\u001b[32mlora_model\u001b[0m\u001b[32m)\u001b[0m\u001b[32m` and setting them as trainable with `set_trainable_params\u001b[0m\u001b[32m(\u001b[0m\u001b[32mlora_model, lora_params\u001b[0m\u001b[32m)\u001b[0m\u001b[32m`. It also notes that the LoRA recipe handles this automatically.\\n\\n2. **Submitted Answer Analysis**: The submitted answer provides a similar method using `set_trainable_params` to set the `requires_grad` attribute of LoRA parameters to `True` and other parameters to `False`. It includes a code example demonstrating this process. Additionally, it mentions using the `lora_finetune` recipe in torchtune, which automatically sets the LoRA parameters to trainable.\\n\\n3. **Comparison**: The submitted answer includes all the details from the expert answer regarding the use of `get_adapter_params` and `set_trainable_params`. It also provides additional information about setting the `requires_grad` attribute and using the `lora_finetune` recipe, which is not mentioned in the expert answer.\\n\\n4. **Conclusion**: The submitted answer is a superset of the expert answer as it contains all the information from the expert answer and additional details. There is no conflict between the two answers, and the additional information in the submission is consistent with the expert's explanation.\\n\\nTherefore, the correct choice is \u001b[0m\u001b[32m(\u001b[0m\u001b[32mB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m The submitted answer is a superset of the expert answer and is fully consistent with it.\"\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ │ \u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
"\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n",
|
||||
"\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
"\u001b[1m)\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"eval_rows = []\n",
|
||||
"for i, session_id in enumerate(rag_attachment_agent.sessions):\n",
|
||||
" session_response = client.agents.session.retrieve(agent_id=rag_attachment_agent.agent_id, session_id=session_id)\n",
|
||||
" for turn in session_response.turns:\n",
|
||||
" eval_rows.append({\n",
|
||||
" \"input_query\": examples[i][\"input_query\"],\n",
|
||||
" \"expected_answer\": examples[i][\"expected_answer\"],\n",
|
||||
" \"generated_answer\": turn.output_message.content,\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
"scoring_params = {\n",
|
||||
" \"braintrust::factuality\": None,\n",
|
||||
"}\n",
|
||||
"scoring_response = client.scoring.score(\n",
|
||||
" input_rows=eval_rows,\n",
|
||||
" scoring_functions=scoring_params,\n",
|
||||
")\n",
|
||||
"pprint(scoring_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
|
|
@ -14,7 +14,7 @@ Agents are configured using the `AgentConfig` class, which includes:
|
|||
- **Safety Shields**: Guardrails to ensure responsible AI behavior
|
||||
|
||||
```python
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client import Agent
|
||||
|
||||
|
||||
# Create the agent
|
||||
|
@ -44,14 +44,14 @@ Each interaction with an agent is called a "turn" and consists of:
|
|||
- **Output Message**: The agent's response
|
||||
|
||||
```python
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client import AgentEventLogger
|
||||
|
||||
# Create a turn with streaming response
|
||||
turn_response = agent.create_turn(
|
||||
session_id=session_id,
|
||||
messages=[{"role": "user", "content": "Tell me about Llama models"}],
|
||||
)
|
||||
for log in EventLogger().log(turn_response):
|
||||
for log in AgentEventLogger().log(turn_response):
|
||||
log.print()
|
||||
```
|
||||
### Non-Streaming
|
||||
|
|
|
@ -67,9 +67,7 @@ sequenceDiagram
|
|||
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
|
||||
|
||||
```python
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
|
||||
from rich.pretty import pprint
|
||||
|
||||
# Replace host and port
|
||||
|
@ -113,7 +111,7 @@ response = agent.create_turn(
|
|||
)
|
||||
|
||||
# Monitor each step of execution
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
|
||||
# Using non-streaming API, the response contains input, steps, and output.
|
||||
|
|
|
@ -23,9 +23,7 @@ In this example, we will show you how to:
|
|||
|
||||
##### Building a Search Agent
|
||||
```python
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
|
||||
|
||||
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
|
||||
|
||||
|
@ -54,7 +52,7 @@ for prompt in user_prompts:
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
```
|
||||
|
||||
|
|
|
@ -55,11 +55,11 @@ chunks_response = client.vector_io.query(
|
|||
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces.
|
||||
|
||||
```python
|
||||
from llama_stack_client.types import Document
|
||||
from llama_stack_client import RAGDocument
|
||||
|
||||
urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"]
|
||||
documents = [
|
||||
Document(
|
||||
RAGDocument(
|
||||
document_id=f"num-{i}",
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
|
@ -86,7 +86,7 @@ results = client.tool_runtime.rag_tool.query(
|
|||
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
|
||||
|
||||
```python
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client import Agent
|
||||
|
||||
# Create agent with memory
|
||||
agent = Agent(
|
||||
|
@ -140,9 +140,9 @@ response = agent.create_turn(
|
|||
|
||||
You can print the response with below.
|
||||
```python
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client import AgentEventLogger
|
||||
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
```
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ The `otel` sink works with any service compatible with the OpenTelemetry collect
|
|||
Start a Jaeger instance with the OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686 using the following command:
|
||||
|
||||
```bash
|
||||
$ docker run --rm --name jaeger \
|
||||
$ docker run --pull always --rm --name jaeger \
|
||||
-p 16686:16686 -p 4318:4318 \
|
||||
jaegertracing/jaeger:2.1.0
|
||||
```
|
||||
|
|
|
@ -110,10 +110,18 @@ MCP tools are special tools that can interact with llama stack over model contex
|
|||
|
||||
Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers.
|
||||
|
||||
```shell
|
||||
# start your MCP server
|
||||
mkdir /tmp/content
|
||||
touch /tmp/content/foo
|
||||
touch /tmp/content/bar
|
||||
npx -y supergateway --port 8000 --stdio 'npx -y @modelcontextprotocol/server-filesystem /tmp/content'
|
||||
```
|
||||
|
||||
Then register the MCP server as a tool group,
|
||||
```python
|
||||
# Register MCP tools
|
||||
client.toolgroups.register(
|
||||
toolgroup_id="builtin::filesystem",
|
||||
toolgroup_id="mcp::filesystem",
|
||||
provider_id="model-context-protocol",
|
||||
mcp_endpoint=URL(uri="http://localhost:8000/sse"),
|
||||
)
|
||||
|
@ -181,7 +189,7 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
|
|||
## Simple Example: Using an Agent with the Code-Interpreter Tool
|
||||
|
||||
```python
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client import Agent
|
||||
|
||||
# Instantiate the AI agent with the given configuration
|
||||
agent = Agent(
|
||||
|
|
|
@ -55,7 +55,7 @@ llama stack run llama_stack/templates/open-benchmark/run.yaml
|
|||
There are 3 necessary inputs to run a benchmark eval
|
||||
- `list of benchmark_ids`: The list of benchmark ids to run evaluation on
|
||||
- `model-id`: The model id to evaluate on
|
||||
- `utput_dir`: Path to store the evaluate results
|
||||
- `output_dir`: Path to store the evaluate results
|
||||
```
|
||||
llama-stack-client eval run-benchmark <benchmark_id_1> <benchmark_id_2> ... \
|
||||
--model_id <model id to evaluate on> \
|
||||
|
@ -69,7 +69,7 @@ llama-stack-client eval run-benchmark help
|
|||
to see the description of all the flags that eval run-benchmark has
|
||||
|
||||
|
||||
In the output log, you can find the file path that has your evaluation results. Open that file and you can see you aggrgate
|
||||
In the output log, you can find the file path that has your evaluation results. Open that file and you can see you aggregate
|
||||
evaluation results over there.
|
||||
|
||||
|
||||
|
|
|
@ -58,9 +58,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-nvidia \
|
||||
|
@ -74,7 +75,7 @@ docker run \
|
|||
```bash
|
||||
llama stack build --template nvidia --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 5001 \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
|
|
@ -28,7 +28,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
|
||||
### Models
|
||||
|
||||
|
@ -53,9 +53,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
llamastack/distribution-bedrock \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
|
|
|
@ -20,7 +20,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
|
||||
|
||||
### Models
|
||||
|
@ -45,9 +45,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-cerebras \
|
||||
|
@ -61,6 +62,6 @@ docker run \
|
|||
```bash
|
||||
llama stack build --template cerebras --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 5001 \
|
||||
--port 8321 \
|
||||
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
|
||||
```
|
||||
|
|
|
@ -53,7 +53,7 @@ docker compose down
|
|||
|
||||
#### Start Dell-TGI server locally
|
||||
```
|
||||
docker run -it --shm-size 1g -p 80:80 --gpus 4 \
|
||||
docker run -it --pull always --shm-size 1g -p 80:80 --gpus 4 \
|
||||
-e NUM_SHARD=4
|
||||
-e MAX_BATCH_PREFILL_TOKENS=32768 \
|
||||
-e MAX_INPUT_TOKENS=8000 \
|
||||
|
@ -65,7 +65,7 @@ registry.dell.huggingface.co/enterprise-dell-inference-meta-llama-meta-llama-3.1
|
|||
#### Start Llama Stack server pointing to TGI server
|
||||
|
||||
```
|
||||
docker run --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml
|
||||
docker run --pull always --network host -it -p 8321:8321 -v ./run.yaml:/root/my-run.yaml --gpus=all llamastack/distribution-tgi --yaml_config /root/my-run.yaml
|
||||
```
|
||||
|
||||
Make sure in you `run.yaml` file, you inference provider is pointing to the correct TGI server endpoint. E.g.
|
||||
|
|
|
@ -55,6 +55,7 @@ export CUDA_VISIBLE_DEVICES=0
|
|||
export LLAMA_STACK_PORT=8321
|
||||
|
||||
docker run --rm -it \
|
||||
--pull always \
|
||||
--network host \
|
||||
-v $HOME/.cache/huggingface:/data \
|
||||
-e HF_TOKEN=$HF_TOKEN \
|
||||
|
@ -78,6 +79,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
|||
export CUDA_VISIBLE_DEVICES=1
|
||||
|
||||
docker run --rm -it \
|
||||
--pull always \
|
||||
--network host \
|
||||
-v $HOME/.cache/huggingface:/data \
|
||||
-e HF_TOKEN=$HF_TOKEN \
|
||||
|
@ -120,6 +122,7 @@ This method allows you to get started quickly without having to build the distri
|
|||
|
||||
```bash
|
||||
docker run -it \
|
||||
--pull always \
|
||||
--network host \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
|
@ -147,6 +150,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
|||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||
|
|
|
@ -30,7 +30,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``)
|
||||
|
||||
### Models
|
||||
|
@ -63,9 +63,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
llamastack/distribution-fireworks \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
|
|
|
@ -30,7 +30,7 @@ The `llamastack/distribution-groq` distribution consists of the following provid
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `GROQ_API_KEY`: Groq API Key (default: ``)
|
||||
|
||||
### Models
|
||||
|
@ -58,9 +58,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
llamastack/distribution-groq \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
|
|
|
@ -32,7 +32,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
|
||||
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
|
||||
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
|
||||
|
@ -77,9 +77,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
llamastack/distribution-meta-reference-gpu \
|
||||
|
@ -92,6 +93,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
|
|||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
llamastack/distribution-meta-reference-gpu \
|
||||
|
@ -107,7 +109,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
|
|||
```bash
|
||||
llama stack build --template meta-reference-gpu --image-type conda
|
||||
llama stack run distributions/meta-reference-gpu/run.yaml \
|
||||
--port 5001 \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
```
|
||||
|
||||
|
@ -115,7 +117,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
|
|||
|
||||
```bash
|
||||
llama stack run distributions/meta-reference-gpu/run-with-safety.yaml \
|
||||
--port 5001 \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
```
|
||||
|
|
|
@ -34,7 +34,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
|
||||
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
|
||||
|
||||
|
@ -77,9 +77,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
llamastack/distribution-meta-reference-quantized-gpu \
|
||||
|
@ -92,6 +93,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
|
|||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
llamastack/distribution-meta-reference-quantized-gpu \
|
||||
|
|
|
@ -15,7 +15,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
||||
|
||||
### Models
|
||||
|
@ -39,9 +39,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-nvidia \
|
||||
|
@ -55,6 +56,6 @@ docker run \
|
|||
```bash
|
||||
llama stack build --template nvidia --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 5001 \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
```
|
||||
|
|
|
@ -32,7 +32,7 @@ You should use this distribution if you have a regular desktop machine without v
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
|
||||
- `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
|
||||
- `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
|
||||
|
@ -71,9 +71,10 @@ Now you are ready to run Llama Stack with Ollama as the inference provider. You
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
export LLAMA_STACK_PORT=5001
|
||||
export LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
llamastack/distribution-ollama \
|
||||
|
@ -91,6 +92,7 @@ cd /path/to/llama-stack
|
|||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
|
||||
|
@ -107,7 +109,7 @@ docker run \
|
|||
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
|
||||
|
||||
```bash
|
||||
export LLAMA_STACK_PORT=5001
|
||||
export LLAMA_STACK_PORT=8321
|
||||
|
||||
llama stack build --template ollama --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
|
|
|
@ -30,7 +30,7 @@ The `llamastack/distribution-passthrough` distribution consists of the following
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `PASSTHROUGH_API_KEY`: Passthrough API Key (default: ``)
|
||||
- `PASSTHROUGH_URL`: Passthrough URL (default: ``)
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ You can use this distribution if you have GPUs and want to run an independent vL
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`)
|
||||
- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`)
|
||||
- `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`)
|
||||
|
@ -49,6 +49,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
|||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--runtime nvidia \
|
||||
--gpus $CUDA_VISIBLE_DEVICES \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
|
@ -61,6 +62,8 @@ docker run \
|
|||
--port $INFERENCE_PORT
|
||||
```
|
||||
|
||||
Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
|
||||
|
||||
```bash
|
||||
|
@ -69,6 +72,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
|||
export CUDA_VISIBLE_DEVICES=1
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--runtime nvidia \
|
||||
--gpus $CUDA_VISIBLE_DEVICES \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
|
@ -92,10 +96,11 @@ This method allows you to get started quickly without having to build the distri
|
|||
```bash
|
||||
export INFERENCE_PORT=8000
|
||||
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
export LLAMA_STACK_PORT=5001
|
||||
export LLAMA_STACK_PORT=8321
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-remote-vllm \
|
||||
|
@ -117,6 +122,7 @@ cd /path/to/llama-stack
|
|||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
|
||||
|
@ -137,7 +143,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
|
|||
```bash
|
||||
export INFERENCE_PORT=8000
|
||||
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
export LLAMA_STACK_PORT=5001
|
||||
export LLAMA_STACK_PORT=8321
|
||||
|
||||
cd distributions/remote-vllm
|
||||
llama stack build --template remote-vllm --image-type conda
|
||||
|
|
|
@ -27,7 +27,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``)
|
||||
|
||||
### Models
|
||||
|
@ -59,9 +59,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
llamastack/distribution-sambanova \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
|
|
|
@ -33,7 +33,7 @@ You can use this distribution if you have GPUs and want to run an independent TG
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
|
||||
- `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`)
|
||||
- `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
|
||||
|
@ -50,6 +50,7 @@ export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
|||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
docker run --rm -it \
|
||||
--pull always \
|
||||
-v $HOME/.cache/huggingface:/data \
|
||||
-p $INFERENCE_PORT:$INFERENCE_PORT \
|
||||
--gpus $CUDA_VISIBLE_DEVICES \
|
||||
|
@ -70,6 +71,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
|||
export CUDA_VISIBLE_DEVICES=1
|
||||
|
||||
docker run --rm -it \
|
||||
--pull always \
|
||||
-v $HOME/.cache/huggingface:/data \
|
||||
-p $SAFETY_PORT:$SAFETY_PORT \
|
||||
--gpus $CUDA_VISIBLE_DEVICES \
|
||||
|
@ -90,9 +92,10 @@ Now you are ready to run Llama Stack with TGI as the inference provider. You can
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
llamastack/distribution-tgi \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
|
@ -109,6 +112,7 @@ cd /path/to/llama-stack
|
|||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||
|
|
|
@ -30,7 +30,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
|
|||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `TOGETHER_API_KEY`: Together.AI API Key (default: ``)
|
||||
|
||||
### Models
|
||||
|
@ -64,9 +64,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
|
|||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
llamastack/distribution-together \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
|
|
|
@ -54,6 +54,7 @@ mkdir -p ~/.llama
|
|||
Then you can start the server using the container tool of your choice. For example, if you are running Docker you can use the following command:
|
||||
```bash
|
||||
docker run -it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
llamastack/distribution-ollama \
|
||||
|
@ -74,6 +75,7 @@ Docker containers run in their own isolated network namespaces on Linux. To allo
|
|||
Linux users having issues running the above command should instead try the following:
|
||||
```bash
|
||||
docker run -it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
--network=host \
|
||||
|
@ -197,9 +199,7 @@ import os
|
|||
import uuid
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types import Document
|
||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||
|
||||
|
||||
def create_http_client():
|
||||
|
@ -225,7 +225,7 @@ client = (
|
|||
# Documents to be used for RAG
|
||||
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
|
||||
documents = [
|
||||
Document(
|
||||
RAGDocument(
|
||||
document_id=f"num-{i}",
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
|
@ -284,7 +284,7 @@ for prompt in user_prompts:
|
|||
messages=[{"role": "user", "content": prompt}],
|
||||
session_id=session_id,
|
||||
)
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
```
|
||||
|
||||
|
|
|
@ -118,6 +118,7 @@ Playground can also be started in a docker image:
|
|||
export LLAMA_STACK_URL=http://localhost:11434
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
-p 8501:8501 \
|
||||
-e LLAMA_STACK_ENDPOINT=$LLAMA_STACK_URL \
|
||||
quay.io/jland/llama-stack-playground
|
||||
|
|
|
@ -48,7 +48,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"HOST = \"localhost\" # Replace with your host\n",
|
||||
"PORT = 5001 # Replace with your port\n",
|
||||
"PORT = 8321 # Replace with your port\n",
|
||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
||||
]
|
||||
},
|
||||
|
@ -369,6 +369,9 @@
|
|||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "7da25939-a2a3-463c-958e-9cdfd710d158",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
|
@ -386,7 +389,5 @@
|
|||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@
|
|||
"source": [
|
||||
"#### 2. Set Up Local and Cloud Clients\n",
|
||||
"\n",
|
||||
"Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:5001`.\n"
|
||||
"Initialize both clients, specifying the `base_url` for each instance. In this case, we have the local distribution running on `http://localhost:8321` and the cloud distribution running on `http://localhost:8322`.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -236,6 +236,9 @@
|
|||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "e11939ac-dfbc-4a1c-83be-e494c7f803b8",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
|
@ -253,7 +256,5 @@
|
|||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"HOST = \"localhost\" # Replace with your host\n",
|
||||
"PORT = 5001 # Replace with your port\n",
|
||||
"PORT = 8321 # Replace with your port\n",
|
||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'"
|
||||
]
|
||||
},
|
||||
|
@ -281,6 +281,9 @@
|
|||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "b1b93b6e-22a2-4c24-8cb0-161fdafff29a",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
|
@ -298,7 +301,5 @@
|
|||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"HOST = \"localhost\" # Replace with your host\n",
|
||||
"CLOUD_PORT = 5001 # Replace with your cloud distro port\n",
|
||||
"CLOUD_PORT = 8321 # Replace with your cloud distro port\n",
|
||||
"MODEL_NAME='Llama3.2-11B-Vision-Instruct'"
|
||||
]
|
||||
},
|
||||
|
@ -180,6 +180,9 @@
|
|||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "37bbbfda-8e42-446c-89c7-59dd49e2d339",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
|
@ -197,7 +200,5 @@
|
|||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@
|
|||
"nest_asyncio.apply()\n",
|
||||
"\n",
|
||||
"HOST = \"localhost\"\n",
|
||||
"PORT = 5001\n",
|
||||
"PORT = 8321\n",
|
||||
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
||||
]
|
||||
},
|
||||
|
@ -335,6 +335,9 @@
|
|||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "f0abbf6d-ed52-40ad-afb4-f5ec99130249",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
|
@ -352,7 +355,5 @@
|
|||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"HOST = \"localhost\" # Replace with your host\n",
|
||||
"PORT = 5001 # Replace with your port\n",
|
||||
"PORT = 8321 # Replace with your port\n",
|
||||
"MODEL_NAME='meta-llama/Llama-3.2-3B-Instruct'\n",
|
||||
"MEMORY_BANK_ID=\"tutorial_bank\""
|
||||
]
|
||||
|
@ -378,6 +378,9 @@
|
|||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "73bc3357-0e5e-42ff-95b1-40b916d24c4f",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
|
@ -395,7 +398,5 @@
|
|||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"HOST = \"localhost\" # Replace with your host\n",
|
||||
"PORT = 5001 # Replace with your port\n",
|
||||
"PORT = 8321 # Replace with your port\n",
|
||||
"SHEILD_NAME=\"meta-llama/Llama-Guard-3-1B\""
|
||||
]
|
||||
},
|
||||
|
@ -112,6 +112,9 @@
|
|||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "9afaddb7-c2fb-4309-8fa0-761697de53f0",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
|
@ -129,7 +132,5 @@
|
|||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"HOST = \"localhost\" # Replace with your host\n",
|
||||
"PORT = 5001 # Replace with your port\n",
|
||||
"PORT = 8321 # Replace with your port\n",
|
||||
"MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n"
|
||||
]
|
||||
},
|
||||
|
@ -168,6 +168,9 @@
|
|||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "8de24775-c4a0-49c7-904e-608264f69292",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
|
@ -185,7 +188,5 @@
|
|||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|
|
@ -96,7 +96,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
3. **Set the ENV variables by exporting them to the terminal**:
|
||||
```bash
|
||||
export OLLAMA_URL="http://localhost:11434"
|
||||
export LLAMA_STACK_PORT=5001
|
||||
export LLAMA_STACK_PORT=8321
|
||||
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
|
||||
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
|
||||
```
|
||||
|
@ -112,7 +112,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
|||
```
|
||||
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
|
||||
|
||||
The server will start and listen on `http://localhost:5001`.
|
||||
The server will start and listen on `http://localhost:8321`.
|
||||
|
||||
---
|
||||
## Test with `llama-stack-client` CLI
|
||||
|
@ -120,11 +120,11 @@ After setting up the server, open a new terminal window and configure the llama-
|
|||
|
||||
1. Configure the CLI to point to the llama-stack server.
|
||||
```bash
|
||||
llama-stack-client configure --endpoint http://localhost:5001
|
||||
llama-stack-client configure --endpoint http://localhost:8321
|
||||
```
|
||||
**Expected Output:**
|
||||
```bash
|
||||
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5001
|
||||
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
|
||||
```
|
||||
2. Test the CLI by running inference:
|
||||
```bash
|
||||
|
@ -218,7 +218,7 @@ if INFERENCE_MODEL is None:
|
|||
raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.")
|
||||
|
||||
# Initialize the clien
|
||||
client = LlamaStackClient(base_url="http://localhost:5001")
|
||||
client = LlamaStackClient(base_url="http://localhost:8321")
|
||||
|
||||
# Create a chat completion reques
|
||||
response = client.inference.chat_completion(
|
||||
|
|
|
@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
@ -189,13 +188,11 @@ class AgentToolGroupWithArgs(BaseModel):
|
|||
args: Dict[str, Any]
|
||||
|
||||
|
||||
AgentToolGroup = register_schema(
|
||||
Union[
|
||||
AgentToolGroup = Union[
|
||||
str,
|
||||
AgentToolGroupWithArgs,
|
||||
],
|
||||
name="AgentTool",
|
||||
)
|
||||
]
|
||||
register_schema(AgentToolGroup, name="AgentTool")
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
|
@ -312,8 +309,7 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
|||
turn: Turn
|
||||
|
||||
|
||||
AgentTurnResponseEventPayload = register_schema(
|
||||
Annotated[
|
||||
AgentTurnResponseEventPayload = Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
|
@ -323,9 +319,8 @@ AgentTurnResponseEventPayload = register_schema(
|
|||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
],
|
||||
Field(discriminator="event_type"),
|
||||
],
|
||||
name="AgentTurnResponseEventPayload",
|
||||
)
|
||||
]
|
||||
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -387,7 +382,6 @@ class AgentStepResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Agents(Protocol):
|
||||
"""Agents API for creating and interacting with agentic systems.
|
||||
|
||||
|
@ -399,7 +393,7 @@ class Agents(Protocol):
|
|||
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
|
||||
"""
|
||||
|
||||
@webmethod(route="/agents", method="POST")
|
||||
@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
|
@ -411,7 +405,9 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
|
||||
)
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
@ -443,6 +439,7 @@ class Agents(Protocol):
|
|||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
)
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
|
@ -505,7 +502,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST")
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
|
|
@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
|
|||
|
||||
|
||||
# other modalities can be added here
|
||||
InterleavedContentItem = register_schema(
|
||||
Annotated[
|
||||
InterleavedContentItem = Annotated[
|
||||
Union[ImageContentItem, TextContentItem],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="InterleavedContentItem",
|
||||
)
|
||||
]
|
||||
register_schema(InterleavedContentItem, name="InterleavedContentItem")
|
||||
|
||||
# accept a single "str" as a special case since it is common
|
||||
InterleavedContent = register_schema(
|
||||
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
|
||||
name="InterleavedContent",
|
||||
)
|
||||
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
|
||||
register_schema(InterleavedContent, name="InterleavedContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
|
|||
|
||||
|
||||
# streaming completions send a stream of ContentDeltas
|
||||
ContentDelta = register_schema(
|
||||
Annotated[
|
||||
ContentDelta = Annotated[
|
||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ContentDelta",
|
||||
)
|
||||
]
|
||||
register_schema(ContentDelta, name="ContentDelta")
|
||||
|
|
|
@ -10,14 +10,14 @@ from pydantic import BaseModel
|
|||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Job(BaseModel):
|
||||
job_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class JobStatus(Enum):
|
||||
completed = "completed"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
scheduled = "scheduled"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Job(BaseModel):
|
||||
job_id: str
|
||||
status: JobStatus
|
||||
|
|
|
@ -72,8 +72,7 @@ class DialogType(BaseModel):
|
|||
type: Literal["dialog"] = "dialog"
|
||||
|
||||
|
||||
ParamType = register_schema(
|
||||
Annotated[
|
||||
ParamType = Annotated[
|
||||
Union[
|
||||
StringType,
|
||||
NumberType,
|
||||
|
@ -87,9 +86,8 @@ ParamType = register_schema(
|
|||
AgentTurnInputType,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ParamType",
|
||||
)
|
||||
]
|
||||
register_schema(ParamType, name="ParamType")
|
||||
|
||||
"""
|
||||
# TODO: recursive definition of ParamType in these containers
|
||||
|
|
|
@ -84,13 +84,11 @@ class RowsDataSource(BaseModel):
|
|||
rows: List[Dict[str, Any]]
|
||||
|
||||
|
||||
DataSource = register_schema(
|
||||
Annotated[
|
||||
DataSource = Annotated[
|
||||
Union[URIDataSource, RowsDataSource],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="DataSource",
|
||||
)
|
||||
]
|
||||
register_schema(DataSource, name="DataSource")
|
||||
|
||||
|
||||
class CommonDatasetFields(BaseModel):
|
||||
|
|
|
@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
|
|||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig
|
||||
from llama_stack.apis.common.job_types import Job, JobStatus
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -43,10 +43,8 @@ class AgentCandidate(BaseModel):
|
|||
config: AgentConfig
|
||||
|
||||
|
||||
EvalCandidate = register_schema(
|
||||
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
|
||||
name="EvalCandidate",
|
||||
)
|
||||
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
|
||||
register_schema(EvalCandidate, name="EvalCandidate")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -117,7 +115,7 @@ class Eval(Protocol):
|
|||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
|
|
|
@ -144,8 +144,7 @@ class CompletionMessage(BaseModel):
|
|||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
|
||||
|
||||
Message = register_schema(
|
||||
Annotated[
|
||||
Message = Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
|
@ -153,9 +152,8 @@ Message = register_schema(
|
|||
CompletionMessage,
|
||||
],
|
||||
Field(discriminator="role"),
|
||||
],
|
||||
name="Message",
|
||||
)
|
||||
]
|
||||
register_schema(Message, name="Message")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
|
|||
bnf: Dict[str, Any]
|
||||
|
||||
|
||||
ResponseFormat = register_schema(
|
||||
Annotated[
|
||||
ResponseFormat = Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ResponseFormat",
|
||||
)
|
||||
]
|
||||
register_schema(ResponseFormat, name="ResponseFormat")
|
||||
|
||||
|
||||
# This is an internally used class
|
||||
|
|
|
@ -24,17 +24,6 @@ class HealthInfo(BaseModel):
|
|||
# TODO: add a provider level status
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VersionInfo(BaseModel):
|
||||
version: str
|
||||
|
@ -46,9 +35,6 @@ class ListRoutesResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
async def list_routes(self) -> ListRoutesResponse: ...
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
|
|||
group_size: int
|
||||
|
||||
|
||||
AlgorithmConfig = register_schema(
|
||||
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
|
||||
name="AlgorithmConfig",
|
||||
)
|
||||
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
|
||||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -184,7 +182,7 @@ class PostTraining(Protocol):
|
|||
description="Model descriptor from `llama model list`",
|
||||
),
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
|
|
|
@ -36,6 +36,7 @@ class ScoringFnParamsType(Enum):
|
|||
@json_schema_type
|
||||
class AggregationFunctionType(Enum):
|
||||
average = "average"
|
||||
weighted_average = "weighted_average"
|
||||
median = "median"
|
||||
categorical_count = "categorical_count"
|
||||
accuracy = "accuracy"
|
||||
|
@ -78,17 +79,15 @@ class BasicScoringFnParams(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
ScoringFnParams = register_schema(
|
||||
Annotated[
|
||||
ScoringFnParams = Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
BasicScoringFnParams,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ScoringFnParams",
|
||||
)
|
||||
]
|
||||
register_schema(ScoringFnParams, name="ScoringFnParams")
|
||||
|
||||
|
||||
class CommonScoringFnFields(BaseModel):
|
||||
|
|
|
@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel):
|
|||
status: SpanStatus
|
||||
|
||||
|
||||
StructuredLogPayload = register_schema(
|
||||
Annotated[
|
||||
StructuredLogPayload = Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="StructuredLogPayload",
|
||||
)
|
||||
]
|
||||
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon):
|
|||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
Event = register_schema(
|
||||
Annotated[
|
||||
Event = Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="Event",
|
||||
)
|
||||
]
|
||||
register_schema(Event, name="Event")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -58,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
|
|||
template: str
|
||||
|
||||
|
||||
RAGQueryGeneratorConfig = register_schema(
|
||||
Annotated[
|
||||
RAGQueryGeneratorConfig = Annotated[
|
||||
Union[
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="RAGQueryGeneratorConfig",
|
||||
)
|
||||
]
|
||||
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -69,7 +69,7 @@ class ToolGroup(Resource):
|
|||
|
||||
@json_schema_type
|
||||
class ToolInvocationResult(BaseModel):
|
||||
content: InterleavedContent
|
||||
content: Optional[InterleavedContent] = None
|
||||
error_message: Optional[str] = None
|
||||
error_code: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore
|
||||
tool_store: ToolStore | None = None
|
||||
|
||||
rag_tool: RAGToolRuntime
|
||||
rag_tool: RAGToolRuntime | None = None
|
||||
|
||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
|
|
|
@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class VectorIO(Protocol):
|
||||
vector_db_store: VectorDBStore
|
||||
vector_db_store: VectorDBStore | None = None
|
||||
|
||||
# this will just block now until chunks are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
|
|
|
@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
|||
d = json.load(f)
|
||||
manifest = Manifest(**d)
|
||||
|
||||
if datetime.now(timezone.utc) > manifest.expires_on:
|
||||
if datetime.now(timezone.utc) > manifest.expires_on.astimezone(timezone.utc):
|
||||
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
|
||||
|
||||
console = Console()
|
||||
|
|
86
llama_stack/distribution/access_control.py
Normal file
86
llama_stack/distribution/access_control.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def check_access(
|
||||
obj_identifier: str,
|
||||
obj_attributes: Optional[AccessAttributes],
|
||||
user_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> bool:
|
||||
"""Check if the current user has access to the given object, based on access attributes.
|
||||
|
||||
Access control algorithm:
|
||||
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
|
||||
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
|
||||
3. For each attribute category in the resource's access_attributes:
|
||||
a. If the user lacks that category, access is DENIED
|
||||
b. If the user has the category but none of the required values, access is DENIED
|
||||
c. If the user has at least one matching value in each required category, access is GRANTED
|
||||
|
||||
Example:
|
||||
# Resource requires:
|
||||
access_attributes = AccessAttributes(
|
||||
roles=["admin", "data-scientist"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
|
||||
# User has:
|
||||
user_attributes = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "infra-team"],
|
||||
"projects": ["llama-3"]
|
||||
}
|
||||
|
||||
# Result: Access GRANTED
|
||||
# - User has the "data-scientist" role (matches one of the required roles)
|
||||
# - AND user is part of the "ml-team" (matches the required team)
|
||||
# - The extra "projects" attribute is ignored
|
||||
|
||||
Args:
|
||||
obj_identifier: The identifier of the resource object to check access for
|
||||
obj_attributes: The access attributes of the resource object
|
||||
user_attributes: The attributes of the current user
|
||||
|
||||
Returns:
|
||||
bool: True if access is granted, False if denied
|
||||
"""
|
||||
# If object has no access attributes, allow access by default
|
||||
if not obj_attributes:
|
||||
return True
|
||||
|
||||
# If no user attributes, deny access to objects with access control
|
||||
if not user_attributes:
|
||||
return False
|
||||
|
||||
dict_attribs = obj_attributes.model_dump(exclude_none=True)
|
||||
if not dict_attribs:
|
||||
return True
|
||||
|
||||
# Check each attribute category (requires ALL categories to match)
|
||||
# TODO: formalize this into a proper ABAC policy
|
||||
for attr_key, required_values in dict_attribs.items():
|
||||
user_values = user_attributes.get(attr_key, [])
|
||||
|
||||
if not user_values:
|
||||
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
|
||||
return False
|
||||
|
||||
if not any(val in user_values for val in required_values):
|
||||
logger.debug(
|
||||
f"Access denied to {obj_identifier}: "
|
||||
f"no match for attribute '{attr_key}', required one of {required_values}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(f"Access granted to {obj_identifier}")
|
||||
return True
|
|
@ -90,6 +90,7 @@ RUN apt-get update && apt-get install -y \
|
|||
procps psmisc lsof \
|
||||
traceroute \
|
||||
bubblewrap \
|
||||
gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
|
@ -235,7 +236,7 @@ image_tag="$image_name:$version_tag"
|
|||
# Detect platform architecture
|
||||
ARCH=$(uname -m)
|
||||
if [ -n "$BUILD_PLATFORM" ]; then
|
||||
CLI_ARGS+=("--platform $BUILD_PLATFORM")
|
||||
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
|
||||
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
|
||||
CLI_ARGS+=("--platform" "linux/arm64")
|
||||
elif [ "$ARCH" = "x86_64" ]; then
|
||||
|
|
|
@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
|
|||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.resource import Resource
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
|
@ -31,6 +32,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
|||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class AccessAttributes(BaseModel):
|
||||
"""Structured representation of user attributes for access control.
|
||||
|
||||
This model defines a structured approach to representing user attributes
|
||||
with common standard categories for access control.
|
||||
|
||||
Standard attribute categories include:
|
||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
||||
- namespaces: Namespace-based access control for resource isolation
|
||||
"""
|
||||
|
||||
# Standard attribute categories - the minimal set we need now
|
||||
roles: Optional[List[str]] = Field(
|
||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||
)
|
||||
|
||||
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
|
||||
projects: Optional[List[str]] = Field(
|
||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||
)
|
||||
|
||||
namespaces: Optional[List[str]] = Field(
|
||||
default=None, description="Namespace-based access control for resource isolation"
|
||||
)
|
||||
|
||||
|
||||
class ResourceWithACL(Resource):
|
||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
||||
|
||||
This class adds an optional access_attributes field that allows fine-grained control
|
||||
over which users can access each resource. When attributes are defined, a user must have
|
||||
matching attributes to access the resource.
|
||||
|
||||
Attribute Matching Algorithm:
|
||||
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
|
||||
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
|
||||
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
|
||||
4. Within each category, ANY value match is sufficient (OR relationship within a category)
|
||||
|
||||
Examples:
|
||||
# Resource visible to everyone (no access control)
|
||||
model = Model(identifier="llama-2", ...)
|
||||
|
||||
# Resource visible only to admins
|
||||
model = Model(
|
||||
identifier="gpt-4",
|
||||
access_attributes=AccessAttributes(roles=["admin"])
|
||||
)
|
||||
|
||||
# Resource visible to data scientists on the ML team
|
||||
model = Model(
|
||||
identifier="private-model",
|
||||
access_attributes=AccessAttributes(
|
||||
roles=["data-scientist", "researcher"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
)
|
||||
# ^ User must have at least one of the roles AND be on the ml-team
|
||||
|
||||
# Resource visible to users with specific project access
|
||||
vector_db = VectorDB(
|
||||
identifier="customer-embeddings",
|
||||
access_attributes=AccessAttributes(
|
||||
projects=["customer-insights"],
|
||||
namespaces=["confidential"]
|
||||
)
|
||||
)
|
||||
# ^ User must have access to the customer-insights project AND have confidential namespace
|
||||
"""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = None
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
class ModelWithACL(Model, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ShieldWithACL(Shield, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class VectorDBWithACL(VectorDB, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetWithACL(Dataset, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkWithACL(Benchmark, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ToolWithACL(Tool, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Union[
|
||||
Model,
|
||||
Shield,
|
||||
|
@ -45,14 +155,14 @@ RoutableObject = Union[
|
|||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
Union[
|
||||
Model,
|
||||
Shield,
|
||||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
Benchmark,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ModelWithACL,
|
||||
ShieldWithACL,
|
||||
VectorDBWithACL,
|
||||
DatasetWithACL,
|
||||
ScoringFnWithACL,
|
||||
BenchmarkWithACL,
|
||||
ToolWithACL,
|
||||
ToolGroupWithACL,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
|
@ -11,9 +11,7 @@ from pydantic import BaseModel
|
|||
from llama_stack.apis.inspect import (
|
||||
HealthInfo,
|
||||
Inspect,
|
||||
ListProvidersResponse,
|
||||
ListRoutesResponse,
|
||||
ProviderInfo,
|
||||
RouteInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
|
@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = []
|
||||
for api, providers in run_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
run_config = self.config.run_config
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ import inspect
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
|
|||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.server.endpoints import (
|
||||
find_matching_endpoint,
|
||||
initialize_endpoint_impls,
|
||||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
|
@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||
console.print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
|
||||
def _convert_path_to_regex(path: str) -> str:
|
||||
# Convert {param} to named capture groups
|
||||
# handle {param:path} as well which allows for forward slashes in the param value
|
||||
pattern = re.sub(
|
||||
r"{(\w+)(?::path)?}",
|
||||
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
|
||||
path,
|
||||
)
|
||||
|
||||
return f"^{pattern}$"
|
||||
|
||||
for api, api_endpoints in endpoints.items():
|
||||
if api not in self.impls:
|
||||
continue
|
||||
for endpoint in api_endpoints:
|
||||
impl = self.impls[api]
|
||||
func = getattr(impl, endpoint.name)
|
||||
if endpoint.method not in endpoint_impls:
|
||||
endpoint_impls[endpoint.method] = {}
|
||||
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
|
||||
|
||||
self.endpoint_impls = endpoint_impls
|
||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
||||
return True
|
||||
|
||||
async def request(
|
||||
|
@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return response
|
||||
|
||||
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
|
||||
"""Find the matching endpoint implementation for a given method and path.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: URL path to match against
|
||||
|
||||
Returns:
|
||||
A tuple of (endpoint_function, path_params)
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching endpoint is found
|
||||
"""
|
||||
impls = self.endpoint_impls.get(method)
|
||||
if not impls:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
for regex, func in impls.items():
|
||||
match = re.match(regex, path)
|
||||
if match:
|
||||
# Extract named groups from the regex match
|
||||
path_params = match.groupdict()
|
||||
return func, path_params
|
||||
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
async def _call_non_streaming(
|
||||
self,
|
||||
*,
|
||||
|
@ -326,10 +278,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
matched_func, path_params = self._find_matching_endpoint(options.method, path)
|
||||
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
|
||||
body |= path_params
|
||||
body = self._convert_body(path, options.method, body)
|
||||
await start_trace(options.url, {"__location__": "library_client"})
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
try:
|
||||
result = await matched_func(**body)
|
||||
finally:
|
||||
|
@ -371,13 +323,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
func, path_params = self._find_matching_endpoint(options.method, path)
|
||||
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
|
||||
async def gen():
|
||||
await start_trace(options.url, {"__location__": "library_client"})
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
try:
|
||||
async for chunk in await func(**body):
|
||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||
|
@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if not body:
|
||||
return {}
|
||||
|
||||
func, _ = self._find_matching_endpoint(method, path)
|
||||
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
|
|
|
@ -7,21 +7,26 @@
|
|||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, ContextManager, Dict, Optional
|
||||
from typing import Any, ContextManager, Dict, List, Optional
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Context variable for request provider data
|
||||
# Context variable for request provider data and auth attributes
|
||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class RequestProviderDataContext(ContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
|
||||
self.provider_data = provider_data
|
||||
def __init__(
|
||||
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
):
|
||||
self.provider_data = provider_data or {}
|
||||
if auth_attributes:
|
||||
self.provider_data["__auth_attributes"] = auth_attributes
|
||||
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
|
@ -80,7 +85,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
|
|||
return None
|
||||
|
||||
|
||||
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
|
||||
"""Context manager that sets request provider data from headers for the duration of the context"""
|
||||
def request_provider_data_context(
|
||||
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
) -> ContextManager:
|
||||
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
||||
provider_data = parse_request_provider_data(headers)
|
||||
return RequestProviderDataContext(provider_data)
|
||||
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||
|
||||
|
||||
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
|
||||
"""Helper to retrieve auth attributes from the provider data context"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return None
|
||||
return provider_data.get("__auth_attributes")
|
||||
|
|
|
@ -14,13 +14,7 @@ from llama_stack.apis.common.content_types import (
|
|||
)
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||
from llama_stack.apis.eval import (
|
||||
BenchmarkConfig,
|
||||
Eval,
|
||||
EvaluateResponse,
|
||||
Job,
|
||||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
|
@ -623,7 +617,7 @@ class EvalRouter(Eval):
|
|||
self,
|
||||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
|
||||
|
|
|
@ -41,11 +41,22 @@ from llama_stack.apis.tools import (
|
|||
ToolHost,
|
||||
)
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessAttributes,
|
||||
BenchmarkWithACL,
|
||||
DatasetWithACL,
|
||||
ModelWithACL,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
ScoringFnWithACL,
|
||||
ShieldWithACL,
|
||||
ToolGroupWithACL,
|
||||
ToolWithACL,
|
||||
VectorDBWithACL,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
@ -186,6 +197,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if not obj:
|
||||
return None
|
||||
|
||||
# Check if user has permission to access this object
|
||||
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
|
@ -202,6 +218,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
if not obj.access_attributes:
|
||||
creator_attributes = get_auth_attributes()
|
||||
if creator_attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
if obj.type == ResourceType.model.value:
|
||||
|
@ -214,7 +237,17 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
||||
# Apply attribute-based access control filtering
|
||||
if filtered_objs:
|
||||
filtered_objs = [
|
||||
obj
|
||||
for obj in filtered_objs
|
||||
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
|
||||
]
|
||||
|
||||
return filtered_objs
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
@ -251,7 +284,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = Model(
|
||||
model = ModelWithACL(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
|
@ -297,7 +330,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
)
|
||||
if params is None:
|
||||
params = {}
|
||||
shield = Shield(
|
||||
shield = ShieldWithACL(
|
||||
identifier=shield_id,
|
||||
provider_resource_id=provider_shield_id,
|
||||
provider_id=provider_id,
|
||||
|
@ -351,7 +384,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
|
@ -405,7 +438,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
dataset = Dataset(
|
||||
dataset = DatasetWithACL(
|
||||
identifier=dataset_id,
|
||||
provider_resource_id=provider_dataset_id,
|
||||
provider_id=provider_id,
|
||||
|
@ -452,7 +485,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
scoring_fn = ScoringFn(
|
||||
scoring_fn = ScoringFnWithACL(
|
||||
identifier=scoring_fn_id,
|
||||
description=description,
|
||||
return_type=return_type,
|
||||
|
@ -494,7 +527,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
)
|
||||
if provider_benchmark_id is None:
|
||||
provider_benchmark_id = benchmark_id
|
||||
benchmark = Benchmark(
|
||||
benchmark = BenchmarkWithACL(
|
||||
identifier=benchmark_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
|
@ -537,7 +570,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
Tool(
|
||||
ToolWithACL(
|
||||
identifier=tool_def.name,
|
||||
toolgroup_id=toolgroup_id,
|
||||
description=tool_def.description or "",
|
||||
|
@ -562,7 +595,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
await self.register_object(tool)
|
||||
|
||||
await self.dist_registry.register(
|
||||
ToolGroup(
|
||||
ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
|
@ -575,7 +608,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
tools = await self.list_tools(toolgroup_id).data
|
||||
tools = (await self.list_tools(toolgroup_id)).data
|
||||
for tool in tools:
|
||||
await self.unregister_object(tool)
|
||||
await self.unregister_object(tool_group)
|
||||
|
|
|
@ -5,16 +5,118 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
api_key: str = Field(description="The API key extracted from the Authorization header")
|
||||
|
||||
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""Middleware that authenticates requests using an external auth endpoint.
|
||||
|
||||
This middleware:
|
||||
1. Extracts the Bearer token from the Authorization header
|
||||
2. Sends it to the configured auth endpoint along with request details
|
||||
3. Validates the response and extracts user attributes
|
||||
4. Makes these attributes available to the route handlers for access control
|
||||
|
||||
Authentication Request Format:
|
||||
```json
|
||||
{
|
||||
"api_key": "the-api-key-extracted-from-auth-header",
|
||||
"request": {
|
||||
"path": "/models/list",
|
||||
"headers": {
|
||||
"content-type": "application/json",
|
||||
"user-agent": "..."
|
||||
// All headers except Authorization
|
||||
},
|
||||
"params": {
|
||||
"limit": ["100"],
|
||||
"offset": ["0"]
|
||||
// Query parameters as key -> list of values
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Expected Auth Endpoint Response Format:
|
||||
```json
|
||||
{
|
||||
"access_attributes": { // Structured attribute format
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team", "nlp-team"],
|
||||
"projects": ["llama-3", "project-x"],
|
||||
"namespaces": ["research"]
|
||||
},
|
||||
"message": "Optional message about auth result"
|
||||
}
|
||||
```
|
||||
|
||||
Attribute-Based Access Control:
|
||||
The attributes returned by the auth endpoint are used to determine which
|
||||
resources the user can access. Resources can specify required attributes
|
||||
using the access_attributes field. For a user to access a resource:
|
||||
|
||||
1. All attribute categories specified in the resource must be present in the user's attributes
|
||||
2. For each category, the user must have at least one matching value
|
||||
|
||||
If the auth endpoint doesn't return any attributes, the user will only be able to
|
||||
access resources that don't have access_attributes defined.
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_endpoint):
|
||||
self.app = app
|
||||
self.auth_endpoint = auth_endpoint
|
||||
|
@ -32,25 +134,57 @@ class AuthenticationMiddleware:
|
|||
path = scope.get("path", "")
|
||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||
|
||||
# Remove sensitive headers
|
||||
if "authorization" in request_headers:
|
||||
del request_headers["authorization"]
|
||||
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
|
||||
auth_data = {
|
||||
"api_key": api_key,
|
||||
"request": {
|
||||
"path": path,
|
||||
"headers": request_headers,
|
||||
"params": params,
|
||||
},
|
||||
}
|
||||
# Build the auth request model
|
||||
auth_request = AuthRequest(
|
||||
api_key=api_key,
|
||||
request=AuthRequestContext(
|
||||
path=path,
|
||||
headers=request_headers,
|
||||
params=params,
|
||||
),
|
||||
)
|
||||
|
||||
# Validate with authentication endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(self.auth_endpoint, json=auth_data)
|
||||
response = await client.post(
|
||||
self.auth_endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed: {response.status_code}")
|
||||
return await self._send_auth_error(send, "Authentication failed")
|
||||
|
||||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||
user_attributes = {
|
||||
"namespaces": [api_key],
|
||||
}
|
||||
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
except Exception:
|
||||
logger.exception("Error parsing authentication response")
|
||||
return await self._send_auth_error(send, "Invalid authentication response format")
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
return await self._send_auth_error(send, "Authentication service timeout")
|
||||
except Exception:
|
||||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
|
|||
route: str
|
||||
method: str
|
||||
name: str
|
||||
descriptive_name: str | None = None
|
||||
|
||||
|
||||
def toolgroup_protocol_map():
|
||||
|
@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
method = "delete"
|
||||
else:
|
||||
method = "post"
|
||||
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
||||
endpoints.append(
|
||||
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
|
||||
)
|
||||
|
||||
apis[api] = endpoints
|
||||
|
||||
return apis
|
||||
|
||||
|
||||
def initialize_endpoint_impls(impls):
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
|
||||
def _convert_path_to_regex(path: str) -> str:
|
||||
# Convert {param} to named capture groups
|
||||
# handle {param:path} as well which allows for forward slashes in the param value
|
||||
pattern = re.sub(
|
||||
r"{(\w+)(?::path)?}",
|
||||
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
|
||||
path,
|
||||
)
|
||||
|
||||
return f"^{pattern}$"
|
||||
|
||||
for api, api_endpoints in endpoints.items():
|
||||
if api not in impls:
|
||||
continue
|
||||
for endpoint in api_endpoints:
|
||||
impl = impls[api]
|
||||
func = getattr(impl, endpoint.name)
|
||||
if endpoint.method not in endpoint_impls:
|
||||
endpoint_impls[endpoint.method] = {}
|
||||
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
|
||||
func,
|
||||
endpoint.descriptive_name or endpoint.route,
|
||||
)
|
||||
|
||||
return endpoint_impls
|
||||
|
||||
|
||||
def find_matching_endpoint(method, path, endpoint_impls):
|
||||
"""Find the matching endpoint implementation for a given method and path.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: URL path to match against
|
||||
endpoint_impls: A dictionary of endpoint implementations
|
||||
|
||||
Returns:
|
||||
A tuple of (endpoint_function, path_params, descriptive_name)
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching endpoint is found
|
||||
"""
|
||||
impls = endpoint_impls.get(method.lower())
|
||||
if not impls:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
for regex, (func, descriptive_name) in impls.items():
|
||||
match = re.match(regex, path)
|
||||
if match:
|
||||
# Extract named groups from the regex match
|
||||
path_params = match.groupdict()
|
||||
return func, path_params, descriptive_name
|
||||
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
|
|
@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
|
|||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.server.endpoints import (
|
||||
find_matching_endpoint,
|
||||
initialize_endpoint_impls,
|
||||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
redact_sensitive_fields,
|
||||
|
@ -179,8 +183,11 @@ async def sse_generator(event_gen):
|
|||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
# Use context manager for request provider data
|
||||
with request_provider_data_context(request.headers):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
|
@ -219,14 +226,30 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app):
|
||||
def __init__(self, app, impls):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
path = scope.get("path", "")
|
||||
await start_trace(path, {"__location__": "server"})
|
||||
try:
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
if not hasattr(self, "endpoint_impls"):
|
||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
||||
|
||||
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
@ -348,7 +371,6 @@ def main():
|
|||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
|
@ -366,7 +388,7 @@ def main():
|
|||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig()))
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
|
@ -412,6 +434,7 @@ def main():
|
|||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls)
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
|
|
@ -12,9 +12,12 @@ import pydantic
|
|||
|
||||
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
class DistributionRegistry(Protocol):
|
||||
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
|
||||
|
@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
|
|||
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
|
||||
all_objects = []
|
||||
for value in values:
|
||||
try:
|
||||
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
|
||||
all_objects.append(obj)
|
||||
except pydantic.ValidationError as e:
|
||||
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
|
||||
continue
|
||||
|
||||
return all_objects
|
||||
|
||||
|
||||
|
@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
if not json_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
|
||||
except pydantic.ValidationError as e:
|
||||
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
|
||||
return None
|
||||
|
||||
async def update(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await self.kvstore.set(
|
||||
|
|
|
@ -5,9 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.shared.document import Document
|
||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
||||
|
@ -35,7 +33,7 @@ def rag_chat_page():
|
|||
)
|
||||
if st.button("Create Vector Database"):
|
||||
documents = [
|
||||
Document(
|
||||
RAGDocument(
|
||||
document_id=uploaded_file.name,
|
||||
content=data_url_from_file(uploaded_file),
|
||||
)
|
||||
|
@ -167,7 +165,7 @@ def rag_chat_page():
|
|||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
if log.role == "tool_execution":
|
||||
retrieval_response += log.content.replace("====", "").strip()
|
||||
|
|
|
@ -186,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
|
|||
top_k: int = Field(..., ge=1)
|
||||
|
||||
|
||||
SamplingStrategy = register_schema(
|
||||
Annotated[
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="SamplingStrategy",
|
||||
)
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -244,6 +244,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
template_str = textwrap.dedent(
|
||||
"""
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
|
|
@ -15,8 +15,11 @@ import json
|
|||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||
|
||||
|
@ -92,7 +95,15 @@ def parse_python_list_for_function_calls(input_string):
|
|||
|
||||
# Extract keyword arguments
|
||||
for keyword in node.keywords:
|
||||
try:
|
||||
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
|
||||
) from e
|
||||
|
||||
result.append((function_name, function_args))
|
||||
|
||||
|
|
|
@ -6,14 +6,12 @@
|
|||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -60,7 +58,6 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
ToolGroups,
|
||||
ToolInvocationResult,
|
||||
ToolRuntime,
|
||||
|
@ -180,23 +177,27 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async with tracing.span("create_and_execute_turn") as span:
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
await self._initialize_tools()
|
||||
async with tracing.span("resume_turn") as span:
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
|
||||
await self._initialize_tools()
|
||||
async for chunk in self._run_turn(request):
|
||||
yield chunk
|
||||
|
||||
|
@ -449,8 +450,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# if document is passed in a turn, we parse the raw text of the document
|
||||
# and sent it as a user message
|
||||
if documents:
|
||||
await self.handle_documents(session_id, documents, input_messages)
|
||||
contexts = []
|
||||
for document in documents:
|
||||
raw_document_text = await get_raw_document_text(document)
|
||||
contexts.append(raw_document_text)
|
||||
|
||||
attached_context = "\n".join(contexts)
|
||||
input_messages[-1].context = attached_context
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
|
@ -825,7 +834,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
|
||||
self.tool_defs, self.tool_name_to_args = (
|
||||
list(tool_name_to_def.values()),
|
||||
tool_name_to_args,
|
||||
)
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
@ -876,144 +888,27 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
session_id: str,
|
||||
documents: List[Document],
|
||||
input_messages: List[Message],
|
||||
) -> None:
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
for d in documents:
|
||||
if isinstance(d.content, URL):
|
||||
url_items.append(d.content)
|
||||
elif pattern.match(d.content):
|
||||
url_items.append(URL(uri=d.content))
|
||||
else:
|
||||
content_items.append(d)
|
||||
|
||||
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||
if code_interpreter_tool:
|
||||
for c in content_items:
|
||||
temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
|
||||
with open(temp_file_path, "w") as temp_file:
|
||||
temp_file.write(c.content)
|
||||
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||
|
||||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_vector_db(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_vector_db(session_id, documents)
|
||||
else:
|
||||
# if no memory or code_interpreter tool is available,
|
||||
# we try to load the data from the URLs and content items as a message to inference
|
||||
# and add it to the last message's context
|
||||
input_messages[-1].context = "\n".join(
|
||||
[doc.content for doc in content_items] + await load_data_from_urls(url_items)
|
||||
)
|
||||
|
||||
async def _ensure_vector_db(self, session_id: str) -> str:
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
if session_info.vector_db_id is None:
|
||||
vector_db_id = f"vector_db_{session_id}"
|
||||
|
||||
# TODO: the semantic for registration is definitely not "creation"
|
||||
# so we need to fix it if we expect the agent to create a new vector db
|
||||
# for each session
|
||||
await self.vector_io_api.register_vector_db(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
)
|
||||
await self.storage.add_vector_db_to_session(session_id, vector_db_id)
|
||||
else:
|
||||
vector_db_id = session_info.vector_db_id
|
||||
|
||||
return vector_db_id
|
||||
|
||||
async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
|
||||
vector_db_id = await self._ensure_vector_db(session_id)
|
||||
documents = [
|
||||
RAGDocument(
|
||||
document_id=str(uuid.uuid4()),
|
||||
content=a.content,
|
||||
mime_type=a.mime_type,
|
||||
metadata={},
|
||||
)
|
||||
for a in data
|
||||
]
|
||||
await self.tool_runtime_api.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=vector_db_id,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
|
||||
async def load_data_from_urls(urls: List[URL]) -> List[str]:
|
||||
data = []
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
if uri.startswith("file://"):
|
||||
filepath = uri[len("file://") :]
|
||||
with open(filepath, "r") as f:
|
||||
data.append(f.read())
|
||||
elif uri.startswith("http"):
|
||||
async def load_data_from_url(url: str) -> str:
|
||||
if url.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
r = await client.get(url)
|
||||
resp = r.text
|
||||
data.append(resp)
|
||||
return data
|
||||
return resp
|
||||
raise ValueError(f"Unexpected URL: {type(url)}")
|
||||
|
||||
|
||||
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
|
||||
contents = []
|
||||
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
if uri.startswith("file://"):
|
||||
filepath = uri[len("file://") :]
|
||||
elif uri.startswith("http"):
|
||||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
logger.info(f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
resp = r.text
|
||||
with open(filepath, "w") as fp:
|
||||
fp.write(resp)
|
||||
async def get_raw_document_text(document: Document) -> str:
|
||||
if not document.mime_type.startswith("text/"):
|
||||
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
|
||||
if isinstance(document.content, URL):
|
||||
return await load_data_from_url(document.content.uri)
|
||||
elif isinstance(document.content, str):
|
||||
return document.content
|
||||
elif isinstance(document.content, TextContentItem):
|
||||
return document.content.text
|
||||
else:
|
||||
raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
contents.append(
|
||||
TextContentItem(
|
||||
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(message.content, list):
|
||||
message.content.extend(contents)
|
||||
else:
|
||||
if isinstance(message.content, str):
|
||||
message.content = [TextContentItem(text=message.content)] + contents
|
||||
else:
|
||||
message.content = [message.content] + contents
|
||||
raise ValueError(f"Unexpected document content type: {type(document.content)}")
|
||||
|
||||
|
||||
def _interpret_content_as_attachment(
|
||||
|
|
|
@ -13,6 +13,9 @@ from typing import List, Optional
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
|
|||
# TODO: is this used anywhere?
|
||||
vector_db_id: Optional[str] = None
|
||||
started_at: datetime
|
||||
access_attributes: Optional[AccessAttributes] = None
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
|
@ -33,11 +37,18 @@ class AgentPersistence:
|
|||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Get current user's auth attributes for new sessions
|
||||
auth_attributes = get_auth_attributes()
|
||||
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
|
||||
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
access_attributes=access_attributes,
|
||||
)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
|
@ -51,12 +62,34 @@ class AgentPersistence:
|
|||
if not value:
|
||||
return None
|
||||
|
||||
return AgentSessionInfo(**json.loads(value))
|
||||
session_info = AgentSessionInfo(**json.loads(value))
|
||||
|
||||
# Check access to session
|
||||
if not self._check_session_access(session_info):
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||
"""Check if current user has access to the session."""
|
||||
# Handle backward compatibility for old sessions without access control
|
||||
if not hasattr(session_info, "access_attributes"):
|
||||
return True
|
||||
|
||||
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if not session_info:
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
|
||||
session_info = await self.get_session_info(session_id)
|
||||
session_info = await self.get_session_if_accessible(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
session_info.vector_db_id = vector_db_id
|
||||
await self.kvstore.set(
|
||||
|
@ -65,12 +98,18 @@ class AgentPersistence:
|
|||
)
|
||||
|
||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
value=turn.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
values = await self.kvstore.range(
|
||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||
|
@ -87,6 +126,9 @@ class AgentPersistence:
|
|||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
|
@ -95,24 +137,36 @@ class AgentPersistence:
|
|||
return Turn(**json.loads(value))
|
||||
|
||||
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=step.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
|||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
EVAL_TASKS_PREFIX = "benchmarks:"
|
||||
|
@ -102,7 +102,7 @@ class MetaReferenceEvalImpl(
|
|||
# need job scheduler queue (ray/celery) w/ jobs api
|
||||
job_id = str(len(self.jobs))
|
||||
self.jobs[job_id] = res
|
||||
return Job(job_id=job_id)
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
|
@ -216,17 +216,18 @@ class MetaReferenceEvalImpl(
|
|||
|
||||
return EvaluateResponse(generations=generations, scores=score_response.results)
|
||||
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
if job_id in self.jobs:
|
||||
return JobStatus.completed
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
return None
|
||||
raise ValueError(f"Job {job_id} not found")
|
||||
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(benchmark_id, job_id)
|
||||
job = await self.job_status(benchmark_id, job_id)
|
||||
status = job.status
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
|
||||
|
|
|
@ -23,7 +23,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import (
|
||||
RegexParserMathResponseScoringFn,
|
||||
)
|
||||
|
@ -36,6 +38,8 @@ FIXED_FNS = [
|
|||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
BFCLScoringFn,
|
||||
IfEvalScoringFn,
|
||||
DocVQAScoringFn,
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,240 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.docvqa import docvqa
|
||||
|
||||
CONTRACTIONS = {
|
||||
"aint": "ain't",
|
||||
"arent": "aren't",
|
||||
"cant": "can't",
|
||||
"couldve": "could've",
|
||||
"couldnt": "couldn't",
|
||||
"couldn'tve": "couldn't've",
|
||||
"couldnt've": "couldn't've",
|
||||
"didnt": "didn't",
|
||||
"doesnt": "doesn't",
|
||||
"dont": "don't",
|
||||
"hadnt": "hadn't",
|
||||
"hadnt've": "hadn't've",
|
||||
"hadn'tve": "hadn't've",
|
||||
"hasnt": "hasn't",
|
||||
"havent": "haven't",
|
||||
"hed": "he'd",
|
||||
"hed've": "he'd've",
|
||||
"he'dve": "he'd've",
|
||||
"hes": "he's",
|
||||
"howd": "how'd",
|
||||
"howll": "how'll",
|
||||
"hows": "how's",
|
||||
"Id've": "I'd've",
|
||||
"I'dve": "I'd've",
|
||||
"Im": "I'm",
|
||||
"Ive": "I've",
|
||||
"isnt": "isn't",
|
||||
"itd": "it'd",
|
||||
"itd've": "it'd've",
|
||||
"it'dve": "it'd've",
|
||||
"itll": "it'll",
|
||||
"let's": "let's",
|
||||
"maam": "ma'am",
|
||||
"mightnt": "mightn't",
|
||||
"mightnt've": "mightn't've",
|
||||
"mightn'tve": "mightn't've",
|
||||
"mightve": "might've",
|
||||
"mustnt": "mustn't",
|
||||
"mustve": "must've",
|
||||
"neednt": "needn't",
|
||||
"notve": "not've",
|
||||
"oclock": "o'clock",
|
||||
"oughtnt": "oughtn't",
|
||||
"ow's'at": "'ow's'at",
|
||||
"'ows'at": "'ow's'at",
|
||||
"'ow'sat": "'ow's'at",
|
||||
"shant": "shan't",
|
||||
"shed've": "she'd've",
|
||||
"she'dve": "she'd've",
|
||||
"she's": "she's",
|
||||
"shouldve": "should've",
|
||||
"shouldnt": "shouldn't",
|
||||
"shouldnt've": "shouldn't've",
|
||||
"shouldn'tve": "shouldn't've",
|
||||
"somebody'd": "somebodyd",
|
||||
"somebodyd've": "somebody'd've",
|
||||
"somebody'dve": "somebody'd've",
|
||||
"somebodyll": "somebody'll",
|
||||
"somebodys": "somebody's",
|
||||
"someoned": "someone'd",
|
||||
"someoned've": "someone'd've",
|
||||
"someone'dve": "someone'd've",
|
||||
"someonell": "someone'll",
|
||||
"someones": "someone's",
|
||||
"somethingd": "something'd",
|
||||
"somethingd've": "something'd've",
|
||||
"something'dve": "something'd've",
|
||||
"somethingll": "something'll",
|
||||
"thats": "that's",
|
||||
"thered": "there'd",
|
||||
"thered've": "there'd've",
|
||||
"there'dve": "there'd've",
|
||||
"therere": "there're",
|
||||
"theres": "there's",
|
||||
"theyd": "they'd",
|
||||
"theyd've": "they'd've",
|
||||
"they'dve": "they'd've",
|
||||
"theyll": "they'll",
|
||||
"theyre": "they're",
|
||||
"theyve": "they've",
|
||||
"twas": "'twas",
|
||||
"wasnt": "wasn't",
|
||||
"wed've": "we'd've",
|
||||
"we'dve": "we'd've",
|
||||
"weve": "we've",
|
||||
"werent": "weren't",
|
||||
"whatll": "what'll",
|
||||
"whatre": "what're",
|
||||
"whats": "what's",
|
||||
"whatve": "what've",
|
||||
"whens": "when's",
|
||||
"whered": "where'd",
|
||||
"wheres": "where's",
|
||||
"whereve": "where've",
|
||||
"whod": "who'd",
|
||||
"whod've": "who'd've",
|
||||
"who'dve": "who'd've",
|
||||
"wholl": "who'll",
|
||||
"whos": "who's",
|
||||
"whove": "who've",
|
||||
"whyll": "why'll",
|
||||
"whyre": "why're",
|
||||
"whys": "why's",
|
||||
"wont": "won't",
|
||||
"wouldve": "would've",
|
||||
"wouldnt": "wouldn't",
|
||||
"wouldnt've": "wouldn't've",
|
||||
"wouldn'tve": "wouldn't've",
|
||||
"yall": "y'all",
|
||||
"yall'll": "y'all'll",
|
||||
"y'allll": "y'all'll",
|
||||
"yall'd've": "y'all'd've",
|
||||
"y'alld've": "y'all'd've",
|
||||
"y'all'dve": "y'all'd've",
|
||||
"youd": "you'd",
|
||||
"youd've": "you'd've",
|
||||
"you'dve": "you'd've",
|
||||
"youll": "you'll",
|
||||
"youre": "you're",
|
||||
"youve": "you've",
|
||||
"1st": "first",
|
||||
"2nd": "second",
|
||||
"3rd": "third",
|
||||
}
|
||||
NUMBERS = {
|
||||
"none": "0",
|
||||
"zero": "0",
|
||||
"one": "1",
|
||||
"two": "2",
|
||||
"three": "3",
|
||||
"four": "4",
|
||||
"five": "5",
|
||||
"six": "6",
|
||||
"seven": "7",
|
||||
"eight": "8",
|
||||
"nine": "9",
|
||||
"ten": "10",
|
||||
}
|
||||
ARTICLES = [
|
||||
"a",
|
||||
"an",
|
||||
"the",
|
||||
"to",
|
||||
"in",
|
||||
"from",
|
||||
"by",
|
||||
] # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
|
||||
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
|
||||
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
|
||||
PUNCTUATION = [
|
||||
";",
|
||||
r"/",
|
||||
"[",
|
||||
"]",
|
||||
'"',
|
||||
"{",
|
||||
"}",
|
||||
"(",
|
||||
")",
|
||||
"=",
|
||||
"+",
|
||||
"\\",
|
||||
"_",
|
||||
"-",
|
||||
">",
|
||||
"<",
|
||||
"@",
|
||||
"`",
|
||||
",",
|
||||
"?",
|
||||
"!",
|
||||
]
|
||||
|
||||
|
||||
def normalize_answer(s: str) -> str:
|
||||
# process punctuation
|
||||
for p in PUNCTUATION:
|
||||
if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
|
||||
s = s.replace(p, "")
|
||||
else:
|
||||
s = s.replace(p, " ")
|
||||
s = PERIOD_STRIP.sub("", s, re.UNICODE)
|
||||
|
||||
# process digits and articles
|
||||
temp_text = s.lower().split()
|
||||
out_text = []
|
||||
for word in temp_text:
|
||||
word = NUMBERS.setdefault(word, word)
|
||||
if word not in ARTICLES:
|
||||
out_text.append(word)
|
||||
|
||||
# standardize contractions
|
||||
for word_id, word in enumerate(out_text):
|
||||
if word in CONTRACTIONS:
|
||||
out_text[word_id] = CONTRACTIONS[word]
|
||||
return " ".join(out_text)
|
||||
|
||||
|
||||
class DocVQAScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
docvqa basically matches the generated answer against several allowed
|
||||
choices, but we need to normalize the answer to avoid penalizing
|
||||
trivial differences
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
docvqa.identifier: docvqa,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "docvqa",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answers = json.loads(input_row["expected_answer"])
|
||||
generated_answer = input_row["generated_answer"]
|
||||
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
docvqa = ScoringFn(
|
||||
identifier="basic::docvqa",
|
||||
description="DocVQA Visual Question & Answer scoring function",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="docvqa",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
ifeval = ScoringFn(
|
||||
identifier="basic::ifeval",
|
||||
description="Eval intruction follow capacity by checkping how many instructions can be followed in each example",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="ifeval",
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.weighted_average],
|
||||
),
|
||||
)
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.ifeval import (
|
||||
ifeval,
|
||||
)
|
||||
|
||||
|
||||
class IfEvalScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn Instruction-Following Eval (IFEval) benchmark
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
ifeval.identifier: ifeval,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
|
||||
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
instruction_list = input_row["instruction_id_list"]
|
||||
generated_answer = input_row["generated_answer"].strip()
|
||||
|
||||
is_following_list = []
|
||||
results = dict(
|
||||
{k + "_correct": 0.0 for k in INSTRUCTION_LIST},
|
||||
**{k + "_total": 0.0 for k in INSTRUCTION_LIST},
|
||||
)
|
||||
|
||||
for index, instruction_id in enumerate(instruction_list):
|
||||
instruction_cls = INSTRUCTION_DICT[instruction_id]
|
||||
instruction = instruction_cls(instruction_id)
|
||||
results[instruction_id + "_total"] += 1.0
|
||||
results[instruction_id.split(":")[0] + "_total"] += 1.0
|
||||
|
||||
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
|
||||
print(clean_input_row)
|
||||
instruction.build_description(**clean_input_row)
|
||||
args = instruction.get_instruction_args()
|
||||
if args and "prompt" in args:
|
||||
instruction.build_description(prompt=input_row["prompt"])
|
||||
|
||||
if generated_answer and instruction.check_following(generated_answer):
|
||||
is_following_list.append(True)
|
||||
results[instruction_id + "_correct"] += 1.0
|
||||
results[instruction_id.split(":")[0] + "_correct"] += 1.0
|
||||
else:
|
||||
is_following_list.append(False)
|
||||
|
||||
if len(is_following_list) == 0:
|
||||
return {
|
||||
"score": 0.0,
|
||||
"weight": 0.0,
|
||||
}
|
||||
|
||||
return {
|
||||
"score": float(sum(is_following_list)) / float(len(is_following_list)),
|
||||
"weight": float(len(is_following_list)),
|
||||
}
|
3319
llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
Normal file
3319
llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -6,12 +6,14 @@
|
|||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
__all__ = ["TelemetryConfig", "TelemetrySink"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
|
||||
async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
|
||||
from .telemetry import TelemetryAdapter
|
||||
|
||||
impl = TelemetryAdapter(config, deps)
|
||||
|
|
|
@ -13,19 +13,20 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
|||
|
||||
|
||||
class TelemetrySink(str, Enum):
|
||||
OTEL = "otel"
|
||||
OTEL_TRACE = "otel_trace"
|
||||
OTEL_METRIC = "otel_metric"
|
||||
SQLITE = "sqlite"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class TelemetryConfig(BaseModel):
|
||||
otel_endpoint: str = Field(
|
||||
otel_trace_endpoint: str = Field(
|
||||
default="http://localhost:4318/v1/traces",
|
||||
description="The OpenTelemetry collector endpoint URL",
|
||||
description="The OpenTelemetry collector endpoint URL for traces",
|
||||
)
|
||||
service_name: str = Field(
|
||||
default="llama-stack",
|
||||
description="The service name to use for telemetry",
|
||||
otel_metric_endpoint: str = Field(
|
||||
default="http://localhost:4318/v1/metrics",
|
||||
description="The OpenTelemetry collector endpoint URL for metrics",
|
||||
)
|
||||
sinks: List[TelemetrySink] = Field(
|
||||
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
|
||||
|
@ -46,7 +47,6 @@ class TelemetryConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
|
||||
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
|
||||
}
|
||||
|
|
|
@ -101,6 +101,6 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
"""Shutdown the processor."""
|
||||
pass
|
||||
|
||||
def force_flush(self, timeout_millis: float = None) -> bool:
|
||||
def force_flush(self, timeout_millis: float | None = None) -> bool:
|
||||
"""Force flush any pending spans."""
|
||||
return True
|
||||
|
|
|
@ -12,6 +12,7 @@ from datetime import datetime, timezone
|
|||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
from opentelemetry.trace.span import format_span_id, format_trace_id
|
||||
|
||||
|
||||
class SQLiteSpanProcessor(SpanProcessor):
|
||||
|
@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
trace_id = format(span.get_span_context().trace_id, "032x")
|
||||
span_id = format(span.get_span_context().span_id, "016x")
|
||||
trace_id = format_trace_id(span.get_span_context().trace_id)
|
||||
span_id = format_span_id(span.get_span_context().span_id)
|
||||
service_name = span.resource.attributes.get("service.name", "unknown")
|
||||
|
||||
parent_span_id = None
|
||||
parent_context = span.parent
|
||||
if parent_context:
|
||||
parent_span_id = format(parent_context.span_id, "016x")
|
||||
parent_span_id = format_span_id(parent_context.span_id)
|
||||
|
||||
# Insert into traces
|
||||
cursor.execute(
|
||||
|
@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
(
|
||||
trace_id,
|
||||
service_name,
|
||||
(span_id if not parent_span_id else None),
|
||||
(span_id if span.attributes.get("__root_span__") == "true" else None),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
||||
),
|
||||
|
|
|
@ -44,7 +44,7 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace
|
|||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
_GLOBAL_STORAGE = {
|
||||
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
|
||||
"active_spans": {},
|
||||
"counters": {},
|
||||
"gauges": {},
|
||||
|
@ -54,30 +54,21 @@ _global_lock = threading.Lock()
|
|||
_TRACER_PROVIDER = None
|
||||
|
||||
|
||||
def string_to_trace_id(s: str) -> int:
|
||||
# Convert the string to bytes and then to an integer
|
||||
return int.from_bytes(s.encode(), byteorder="big", signed=False)
|
||||
|
||||
|
||||
def string_to_span_id(s: str) -> int:
|
||||
# Use only the first 8 bytes (64 bits) for span ID
|
||||
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
||||
|
||||
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
||||
def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = deps.get(Api.datasetio)
|
||||
self.meter = None
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||
# service name is always the same, use zero-width space to avoid clutter
|
||||
ResourceAttributes.SERVICE_NAME: "",
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -91,15 +82,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
provider = TracerProvider(resource=resource)
|
||||
trace.set_tracer_provider(provider)
|
||||
_TRACER_PROVIDER = provider
|
||||
if TelemetrySink.OTEL in self.config.sinks:
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
if TelemetrySink.OTEL_TRACE in self.config.sinks:
|
||||
span_exporter = OTLPSpanExporter(
|
||||
endpoint=self.config.otel_trace_endpoint,
|
||||
)
|
||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||
span_processor = BatchSpanProcessor(span_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
if TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
endpoint=self.config.otel_metric_endpoint,
|
||||
)
|
||||
)
|
||||
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||
|
@ -109,7 +101,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
|
||||
|
||||
if TelemetrySink.OTEL in self.config.sinks:
|
||||
if TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
|
||||
|
@ -135,7 +127,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
span_id = event.span_id
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
|
@ -146,7 +138,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
"__ttl__": ttl_seconds,
|
||||
**event.attributes,
|
||||
**(event.attributes or {}),
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
|
@ -154,6 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
|
||||
|
||||
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["counters"]:
|
||||
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||
name=name,
|
||||
|
@ -163,6 +156,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
return _GLOBAL_STORAGE["counters"][name]
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||
name=name,
|
||||
|
@ -182,6 +176,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
|
||||
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
|
||||
name=name,
|
||||
|
@ -192,8 +187,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
trace_id = string_to_trace_id(event.trace_id)
|
||||
span_id = int(event.span_id, 16)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
|
@ -204,14 +198,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
return
|
||||
|
||||
parent_span = None
|
||||
context = None
|
||||
if event.payload.parent_span_id:
|
||||
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
||||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
|
||||
context = trace.Context(trace_id=trace_id)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span, context)
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
else:
|
||||
context = trace.set_span_in_context(
|
||||
trace.NonRecordingSpan(
|
||||
trace.SpanContext(
|
||||
trace_id=int(event.trace_id, 16),
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
|
||||
)
|
||||
)
|
||||
)
|
||||
event.attributes["__root_span__"] = "true"
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
|
|
|
@ -69,7 +69,7 @@ def popen_not_allowed(*args, **kwargs):
|
|||
)
|
||||
|
||||
|
||||
_subprocess.Popen = popen_not_allowed
|
||||
_subprocess.Popen = popen_not_allowed # type: ignore
|
||||
|
||||
|
||||
import atexit as _atexit
|
||||
|
@ -104,7 +104,7 @@ def _open_connections():
|
|||
return _NETWORK_CONNECTIONS
|
||||
|
||||
|
||||
_builtins._open_connections = _open_connections
|
||||
_builtins._open_connections = _open_connections # type: ignore
|
||||
|
||||
|
||||
@_atexit.register
|
||||
|
|
|
@ -161,9 +161,9 @@ _set_seeds()\
|
|||
def process_matplotlib_response(response, matplotlib_dump_dir: str):
|
||||
image_data = response["image_data"]
|
||||
# Convert the base64 string to a bytes object
|
||||
images = [base64.b64decode(d["image_base64"]) for d in image_data]
|
||||
images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
|
||||
# Create a list of PIL images from the bytes objects
|
||||
images = [Image.open(BytesIO(img)) for img in images]
|
||||
images = [Image.open(BytesIO(img)) for img in images_raw]
|
||||
# Create a list of image paths
|
||||
image_paths = []
|
||||
for i, img in enumerate(images):
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue