Merge branch 'main' into aidand-391-guided-decoding-vllm_3

This commit is contained in:
Aidan Do 2024-12-08 17:36:57 +11:00 committed by GitHub
commit 44bb23ebc8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
177 changed files with 5725 additions and 594 deletions

View file

@ -0,0 +1,355 @@
name: "Run Llama-stack Tests"
on:
#### Temporarily disable PR runs until tests run as intended within mainline.
#TODO Add this back.
#pull_request_target:
# types: ["opened"]
# branches:
# - 'main'
# paths:
# - 'llama_stack/**/*.py'
# - 'tests/**/*.py'
workflow_dispatch:
inputs:
runner:
description: 'GHA Runner Scale Set label to run workflow on.'
required: true
default: "llama-stack-gha-runner-gpu"
checkout_reference:
description: "The branch, tag, or SHA to checkout"
required: true
default: "main"
debug:
description: 'Run debugging steps?'
required: false
default: "true"
sleep_time:
description: '[DEBUG] sleep time for debugging'
required: true
default: "0"
provider_id:
description: 'ID of your provider'
required: true
default: "meta_reference"
model_id:
description: 'Shorthand name for target model ID (llama_3b or llama_8b)'
required: true
default: "llama_3b"
model_override_3b:
description: 'Specify shorthand model for <llama_3b> '
required: false
default: "Llama3.2-3B-Instruct"
model_override_8b:
description: 'Specify shorthand model for <llama_8b> '
required: false
default: "Llama3.1-8B-Instruct"
env:
# ID used for each test's provider config
PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}"
# Path to model checkpoints within EFS volume
MODEL_CHECKPOINT_DIR: "/data/llama"
# Path to directory to run tests from
TESTS_PATH: "${{ github.workspace }}/llama_stack/providers/tests"
# Keep track of a list of model IDs that are valid to use within pytest fixture marks
AVAILABLE_MODEL_IDs: "llama_3b llama_8b"
# Shorthand name for model ID, used in pytest fixture marks
MODEL_ID: "${{ inputs.model_id || 'llama_3b' }}"
# Override the `llama_3b` / `llama_8b' models, else use the default.
LLAMA_3B_OVERRIDE: "${{ inputs.model_override_3b || 'Llama3.2-3B-Instruct' }}"
LLAMA_8B_OVERRIDE: "${{ inputs.model_override_8b || 'Llama3.1-8B-Instruct' }}"
# Defines which directories in TESTS_PATH to exclude from the test loop
EXCLUDED_DIRS: "__pycache__"
# Defines the output xml reports generated after a test is run
REPORTS_GEN: ""
jobs:
execute_workflow:
name: Execute workload on Self-Hosted GPU k8s runner
permissions:
pull-requests: write
defaults:
run:
shell: bash
runs-on: ${{ inputs.runner != '' && inputs.runner || 'llama-stack-gha-runner-gpu' }}
if: always()
steps:
##############################
#### INITIAL DEBUG CHECKS ####
##############################
- name: "[DEBUG] Check content of the EFS mount"
id: debug_efs_volume
continue-on-error: true
if: inputs.debug == 'true'
run: |
echo "========= Content of the EFS mount ============="
ls -la ${{ env.MODEL_CHECKPOINT_DIR }}
- name: "[DEBUG] Get runner container OS information"
id: debug_os_info
if: ${{ inputs.debug == 'true' }}
run: |
cat /etc/os-release
- name: "[DEBUG] Print environment variables"
id: debug_env_vars
if: ${{ inputs.debug == 'true' }}
run: |
echo "PROVIDER_ID = ${PROVIDER_ID}"
echo "MODEL_CHECKPOINT_DIR = ${MODEL_CHECKPOINT_DIR}"
echo "AVAILABLE_MODEL_IDs = ${AVAILABLE_MODEL_IDs}"
echo "MODEL_ID = ${MODEL_ID}"
echo "LLAMA_3B_OVERRIDE = ${LLAMA_3B_OVERRIDE}"
echo "LLAMA_8B_OVERRIDE = ${LLAMA_8B_OVERRIDE}"
echo "EXCLUDED_DIRS = ${EXCLUDED_DIRS}"
echo "REPORTS_GEN = ${REPORTS_GEN}"
############################
#### MODEL INPUT CHECKS ####
############################
- name: "Check if env.model_id is valid"
id: check_model_id
run: |
if [[ " ${AVAILABLE_MODEL_IDs[@]} " =~ " ${MODEL_ID} " ]]; then
echo "Model ID '${MODEL_ID}' is valid."
else
echo "Model ID '${MODEL_ID}' is invalid. Terminating workflow."
exit 1
fi
#######################
#### CODE CHECKOUT ####
#######################
- name: "Checkout 'meta-llama/llama-stack' repository"
id: checkout_repo
uses: actions/checkout@v4
with:
ref: ${{ inputs.branch }}
- name: "[DEBUG] Content of the repository after checkout"
id: debug_content_after_checkout
if: ${{ inputs.debug == 'true' }}
run: |
ls -la ${GITHUB_WORKSPACE}
##########################################################
#### OPTIONAL SLEEP DEBUG ####
# #
# Use to "exec" into the test k8s POD and run tests #
# manually to identify what dependencies are being used. #
# #
##########################################################
- name: "[DEBUG] sleep"
id: debug_sleep
if: ${{ inputs.debug == 'true' && inputs.sleep_time != '' }}
run: |
sleep ${{ inputs.sleep_time }}
############################
#### UPDATE SYSTEM PATH ####
############################
- name: "Update path: execute"
id: path_update_exec
run: |
# .local/bin is needed for certain libraries installed below to be recognized
# when calling their executable to install sub-dependencies
mkdir -p ${HOME}/.local/bin
echo "${HOME}/.local/bin" >> "$GITHUB_PATH"
#####################################
#### UPDATE CHECKPOINT DIRECTORY ####
#####################################
- name: "Update checkpoint directory"
id: checkpoint_update
run: |
echo "Checkpoint directory: ${MODEL_CHECKPOINT_DIR}/$LLAMA_3B_OVERRIDE"
if [ "${MODEL_ID}" = "llama_3b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" ]; then
echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" >> "$GITHUB_ENV"
elif [ "${MODEL_ID}" = "llama_8b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" ]; then
echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" >> "$GITHUB_ENV"
else
echo "MODEL_ID & LLAMA_*B_OVERRIDE are not a valid pairing. Terminating workflow."
exit 1
fi
- name: "[DEBUG] Checkpoint update check"
id: debug_checkpoint_update
if: ${{ inputs.debug == 'true' }}
run: |
echo "MODEL_CHECKPOINT_DIR (after update) = ${MODEL_CHECKPOINT_DIR}"
##################################
#### DEPENDENCY INSTALLATIONS ####
##################################
- name: "Installing 'apt' required packages"
id: install_apt
run: |
echo "[STEP] Installing 'apt' required packages"
sudo apt update -y
sudo apt install -y python3 python3-pip npm wget
- name: "Installing packages with 'curl'"
id: install_curl
run: |
curl -fsSL https://ollama.com/install.sh | sh
- name: "Installing packages with 'wget'"
id: install_wget
run: |
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
chmod +x Miniconda3-latest-Linux-x86_64.sh
./Miniconda3-latest-Linux-x86_64.sh -b install -c pytorch -c nvidia faiss-gpu=1.9.0
# Add miniconda3 bin to system path
echo "${HOME}/miniconda3/bin" >> "$GITHUB_PATH"
- name: "Installing packages with 'npm'"
id: install_npm_generic
run: |
sudo npm install -g junit-merge
- name: "Installing pip dependencies"
id: install_pip_generic
run: |
echo "[STEP] Installing 'llama-stack' models"
pip install -U pip setuptools
pip install -r requirements.txt
pip install -e .
pip install -U \
torch torchvision \
pytest pytest_asyncio \
fairscale lm-format-enforcer \
zmq chardet pypdf \
pandas sentence_transformers together \
aiosqlite
- name: "Installing packages with conda"
id: install_conda_generic
run: |
conda install -q -c pytorch -c nvidia faiss-gpu=1.9.0
#############################################################
#### TESTING TO BE DONE FOR BOTH PRS AND MANUAL DISPATCH ####
#############################################################
- name: "Run Tests: Loop"
id: run_tests_loop
working-directory: "${{ github.workspace }}"
run: |
pattern=""
for dir in llama_stack/providers/tests/*; do
if [ -d "$dir" ]; then
dir_name=$(basename "$dir")
if [[ ! " $EXCLUDED_DIRS " =~ " $dir_name " ]]; then
for file in "$dir"/test_*.py; do
test_name=$(basename "$file")
new_file="result-${dir_name}-${test_name}.xml"
if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and ${MODEL_ID}" \
--junitxml="${{ github.workspace }}/${new_file}"; then
echo "Ran test: ${test_name}"
else
echo "Did NOT run test: ${test_name}"
fi
pattern+="${new_file} "
done
fi
fi
done
echo "REPORTS_GEN=$pattern" >> "$GITHUB_ENV"
- name: "Test Summary: Merge"
id: test_summary_merge
working-directory: "${{ github.workspace }}"
run: |
echo "Merging the following test result files: ${REPORTS_GEN}"
# Defaults to merging them into 'merged-test-results.xml'
junit-merge ${{ env.REPORTS_GEN }}
############################################
#### AUTOMATIC TESTING ON PULL REQUESTS ####
############################################
#### Run tests ####
- name: "PR - Run Tests"
id: pr_run_tests
working-directory: "${{ github.workspace }}"
if: github.event_name == 'pull_request_target'
run: |
echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE} | path: ${{ github.workspace }}"
# (Optional) Add more tests here.
# Merge test results with 'merged-test-results.xml' from above.
# junit-merge <new-test-results> merged-test-results.xml
#### Create test summary ####
- name: "PR - Test Summary"
id: pr_test_summary_create
if: github.event_name == 'pull_request_target'
uses: test-summary/action@v2
with:
paths: "${{ github.workspace }}/merged-test-results.xml"
output: test-summary.md
- name: "PR - Upload Test Summary"
id: pr_test_summary_upload
if: github.event_name == 'pull_request_target'
uses: actions/upload-artifact@v3
with:
name: test-summary
path: test-summary.md
#### Update PR request ####
- name: "PR - Update comment"
id: pr_update_comment
if: github.event_name == 'pull_request_target'
uses: thollander/actions-comment-pull-request@v2
with:
filePath: test-summary.md
########################
#### MANUAL TESTING ####
########################
#### Run tests ####
- name: "Manual - Run Tests: Prep"
id: manual_run_tests
working-directory: "${{ github.workspace }}"
if: github.event_name == 'workflow_dispatch'
run: |
echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${{ github.workspace }}"
#TODO Use this when collection errors are resolved
# pytest -s -v -m "${PROVIDER_ID} and ${MODEL_ID}" --junitxml="${{ github.workspace }}/merged-test-results.xml"
# (Optional) Add more tests here.
# Merge test results with 'merged-test-results.xml' from above.
# junit-merge <new-test-results> merged-test-results.xml
#### Create test summary ####
- name: "Manual - Test Summary"
id: manual_test_summary
if: always() && github.event_name == 'workflow_dispatch'
uses: test-summary/action@v2
with:
paths: "${{ github.workspace }}/merged-test-results.xml"

View file

@ -80,6 +80,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Cerebras | Single Node | | :heavy_check_mark: | | | |
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
@ -93,12 +94,13 @@ Additionally, we have designed every element of the Stack such that APIs as well
| **Distribution** | **Llama Stack Docker** | Start This Distribution |
|:----------------: |:------------------------------------------: |:-----------------------: |
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-gpu.html) |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/ollama.html) |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/tgi.html) |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/together.html) |
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/fireworks.html) |
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
## Installation
@ -111,7 +113,8 @@ You have two ways to install this repository:
```
2. **Install from source**:
If you prefer to install from the source code, follow these steps:
If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
Then, follow these steps:
```bash
mkdir -p ~/local
cd ~/local
@ -128,7 +131,7 @@ You have two ways to install this repository:
Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest/index.html) page for more details.
* [CLI reference](https://llama-stack.readthedocs.io/en/latest/cli_reference/index.html)
* [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html)
* Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
* [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
* Quick guide to start a Llama Stack server.
@ -136,7 +139,7 @@ Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest
* The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) of the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack).
* A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guide you through all the key components of llama stack with code samples.
* [Contributing](CONTRIBUTING.md)
* [Adding a new API Provider](https://llama-stack.readthedocs.io/en/latest/api_providers/new_api_provider.html) to walk-through how to add a new API provider.
* [Adding a new API Provider](https://llama-stack.readthedocs.io/en/latest/contributing/new_api_provider.html) to walk-through how to add a new API provider.
## Llama Stack Client SDKs

View file

@ -0,0 +1 @@
../../llama_stack/templates/cerebras/build.yaml

View file

@ -0,0 +1,16 @@
services:
llamastack:
image: llamastack/distribution-cerebras
network_mode: "host"
volumes:
- ~/.llama:/root/.llama
- ./run.yaml:/root/llamastack-run-cerebras.yaml
ports:
- "5000:5000"
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-cerebras.yaml"
deploy:
restart_policy:
condition: on-failure
delay: 3s
max_attempts: 5
window: 60s

View file

@ -0,0 +1 @@
../../llama_stack/templates/cerebras/run.yaml

View file

@ -2,9 +2,11 @@
"hf-serverless": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
@ -13,6 +15,7 @@
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
@ -29,9 +32,11 @@
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
@ -39,6 +44,7 @@
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
@ -56,9 +62,11 @@
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
@ -66,6 +74,7 @@
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
@ -110,9 +119,11 @@
],
"fireworks": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
@ -121,6 +132,7 @@
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
@ -138,9 +150,11 @@
"tgi": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
@ -149,6 +163,7 @@
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
@ -165,10 +180,12 @@
],
"bedrock": [
"aiosqlite",
"autoevals",
"blobfile",
"boto3",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
@ -176,6 +193,7 @@
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
@ -193,9 +211,11 @@
"meta-reference-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"fairscale",
"faiss-cpu",
"fastapi",
@ -205,6 +225,7 @@
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
@ -225,9 +246,11 @@
"meta-reference-quantized-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"fairscale",
"faiss-cpu",
"fastapi",
@ -238,6 +261,7 @@
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
@ -256,12 +280,40 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"cerebras": [
"aiosqlite",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"ollama": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
@ -270,6 +322,7 @@
"nltk",
"numpy",
"ollama",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
@ -287,9 +340,11 @@
"hf-endpoint": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
@ -298,6 +353,7 @@
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",

View file

@ -2291,6 +2291,39 @@
"required": true
}
}
},
"/alpha/datasets/unregister": {
"post": {
"responses": {
"200": {
"description": "OK"
}
},
"tags": [
"Datasets"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/UnregisterDatasetRequest"
}
}
},
"required": true
}
}
}
},
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
@ -7917,6 +7950,18 @@
"required": [
"model_id"
]
},
"UnregisterDatasetRequest": {
"type": "object",
"properties": {
"dataset_id": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"dataset_id"
]
}
},
"responses": {}
@ -8529,6 +8574,10 @@
"name": "UnregisterModelRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterModelRequest\" />"
},
{
"name": "UnregisterDatasetRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterDatasetRequest\" />"
},
{
"name": "UnstructuredLogEvent",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
@ -8718,6 +8767,7 @@
"URL",
"UnregisterMemoryBankRequest",
"UnregisterModelRequest",
"UnregisterDatasetRequest",
"UnstructuredLogEvent",
"UserMessage",
"VectorMemoryBank",

View file

@ -3253,6 +3253,14 @@ components:
required:
- model_id
type: object
UnregisterDatasetRequest:
additionalProperties: false
properties:
dataset_id:
type: string
required:
- dataset_id
type: object
UnstructuredLogEvent:
additionalProperties: false
properties:
@ -3789,6 +3797,27 @@ paths:
description: OK
tags:
- Datasets
/alpha/datasets/unregister:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/UnregisterDatasetRequest'
required: true
responses:
'200':
description: OK
tags:
- Datasets
/alpha/eval-tasks/get:
get:
parameters:
@ -5242,6 +5271,9 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest"
/>
name: UnregisterModelRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterDatasetRequest"
/>
name: UnregisterDatasetRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
/>
name: UnstructuredLogEvent
@ -5418,6 +5450,7 @@ x-tagGroups:
- URL
- UnregisterMemoryBankRequest
- UnregisterModelRequest
- UnregisterDatasetRequest
- UnstructuredLogEvent
- UserMessage
- VectorMemoryBank

View file

@ -11,5 +11,12 @@
- memory / RAG; pre-ingesting content or attaching content in a turn
- how does tool calling work
- can you do evaluation?
```
For details on how to use the telemetry system to debug your applications, export traces to a dataset, and run evaluations, see the [Telemetry](telemetry) section.
```{toctree}
:hidden:
:maxdepth: 3
telemetry
```

View file

@ -0,0 +1,243 @@
# Telemetry
```{note}
The telemetry system is currently experimental and subject to change. We welcome feedback and contributions to help improve it.
```
The Llama Stack telemetry system provides comprehensive tracing, metrics, and logging capabilities. It supports multiple sink types including OpenTelemetry, SQLite, and Console output.
## Key Concepts
### Events
The telemetry system supports three main types of events:
- **Unstructured Log Events**: Free-form log messages with severity levels
```python
unstructured_log_event = UnstructuredLogEvent(
message="This is a log message",
severity=LogSeverity.INFO
)
```
- **Metric Events**: Numerical measurements with units
```python
metric_event = MetricEvent(
metric="my_metric",
value=10,
unit="count"
)
```
- **Structured Log Events**: System events like span start/end. Extensible to add more structured log types.
```python
structured_log_event = SpanStartPayload(
name="my_span",
parent_span_id="parent_span_id"
)
```
### Spans and Traces
- **Spans**: Represent operations with timing and hierarchical relationships
- **Traces**: Collection of related spans forming a complete request flow
### Sinks
- **OpenTelemetry**: Send events to an OpenTelemetry Collector. This is useful for visualizing traces in a service like Jaeger.
- **SQLite**: Store events in a local SQLite database. This is needed if you want to query the events later through the Llama Stack API.
- **Console**: Print events to the console.
## APIs
The telemetry API is designed to be flexible for different user flows like debugging/visualization in UI, monitoring, and saving traces to datasets.
The telemetry system exposes the following HTTP endpoints:
### Log Event
```http
POST /telemetry/log-event
```
Logs a telemetry event (unstructured log, metric, or structured log) with optional TTL.
### Query Traces
```http
POST /telemetry/query-traces
```
Retrieves traces based on filters with pagination support. Parameters:
- `attribute_filters`: List of conditions to filter traces
- `limit`: Maximum number of traces to return (default: 100)
- `offset`: Number of traces to skip (default: 0)
- `order_by`: List of fields to sort by
### Get Span Tree
```http
POST /telemetry/get-span-tree
```
Retrieves a hierarchical view of spans starting from a specific span. Parameters:
- `span_id`: ID of the root span to retrieve
- `attributes_to_return`: Optional list of specific attributes to include
- `max_depth`: Optional maximum depth of the span tree to return
### Query Spans
```http
POST /telemetry/query-spans
```
Retrieves spans matching specified filters and returns selected attributes. Parameters:
- `attribute_filters`: List of conditions to filter traces
- `attributes_to_return`: List of specific attributes to include in results
- `max_depth`: Optional maximum depth of spans to traverse (default: no limit)
Returns a flattened list of spans with requested attributes.
### Save Spans to Dataset
This is useful for saving traces to a dataset for running evaluations. For example, you can save the input/output of each span that is part of an agent session/turn to a dataset and then run an eval task on it. See example in [Example: Save Spans to Dataset](#example-save-spans-to-dataset).
```http
POST /telemetry/save-spans-to-dataset
```
Queries spans and saves their attributes to a dataset. Parameters:
- `attribute_filters`: List of conditions to filter traces
- `attributes_to_save`: List of span attributes to save to the dataset
- `dataset_id`: ID of the dataset to save to
- `max_depth`: Optional maximum depth of spans to traverse (default: no limit)
## Providers
### Meta-Reference Provider
Currently, only the meta-reference provider is implemented. It can be configured to send events to three sink types:
1) OpenTelemetry Collector
2) SQLite
3) Console
## Configuration
Here's an example that sends telemetry signals to all three sink types. Your configuration might use only one.
```yaml
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
sinks: ['console', 'sqlite', 'otel']
otel_endpoint: "http://localhost:4318/v1/traces"
sqlite_db_path: "/path/to/telemetry.db"
```
## Jaeger to visualize traces
The `otel` sink works with any service compatible with the OpenTelemetry collector. Let's use Jaeger to visualize this data.
Start a Jaeger instance with the OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686 using the following command:
```bash
$ docker run --rm \
--name jaeger jaegertracing/jaeger:2.0.0 \
-p 16686:16686 -p 4318:4318 \
--set receivers.otlp.protocols.http.endpoint=0.0.0.0:4318
```
Once the Jaeger instance is running, you can visualize traces by navigating to http://localhost:16686.
## Querying Traces Stored in SQLIte
The `sqlite` sink allows you to query traces without an external system. Here are some example queries:
Querying Traces for a agent session
The client SDK is not updated to support the new telemetry API. It will be updated soon. You can manually query traces using the following curl command:
``` bash
curl -X POST 'http://localhost:5000/alpha/telemetry/query-traces' \
-H 'Content-Type: application/json' \
-d '{
"attribute_filters": [
{
"key": "session_id",
"op": "eq",
"value": "dd667b87-ca4b-4d30-9265-5a0de318fc65" }],
"limit": 100,
"offset": 0,
"order_by": ["start_time"]
[
{
"trace_id": "6902f54b83b4b48be18a6f422b13e16f",
"root_span_id": "5f37b85543afc15a",
"start_time": "2024-12-04T08:08:30.501587",
"end_time": "2024-12-04T08:08:36.026463"
},
........
]
}'
```
Querying spans for a specifc root span id
``` bash
curl -X POST 'http://localhost:5000/alpha/telemetry/get-span-tree' \
-H 'Content-Type: application/json' \
-d '{ "span_id" : "6cceb4b48a156913", "max_depth": 2 }'
{
"span_id": "6cceb4b48a156913",
"trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
"parent_span_id": "892a66d726c7f990",
"name": "retrieve_rag_context",
"start_time": "2024-12-04T09:28:21.781995",
"end_time": "2024-12-04T09:28:21.913352",
"attributes": {
"input": [
"{\"role\":\"system\",\"content\":\"You are a helpful assistant\"}",
"{\"role\":\"user\",\"content\":\"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.\",\"context\":null}"
]
},
"children": [
{
"span_id": "1a2df181854064a8",
"trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
"parent_span_id": "6cceb4b48a156913",
"name": "MemoryRouter.query_documents",
"start_time": "2024-12-04T09:28:21.787620",
"end_time": "2024-12-04T09:28:21.906512",
"attributes": {
"input": null
},
"children": [],
"status": "ok"
}
],
"status": "ok"
}
```
## Example: Save Spans to Dataset
Save all spans for a specific agent session to a dataset.
``` bash
curl -X POST 'http://localhost:5000/alpha/telemetry/save-spans-to-dataset' \
-H 'Content-Type: application/json' \
-d '{
"attribute_filters": [
{
"key": "session_id",
"op": "eq",
"value": "dd667b87-ca4b-4d30-9265-5a0de318fc65"
}
],
"attributes_to_save": ["input", "output"],
"dataset_id": "my_dataset",
"max_depth": 10
}'
```
Save all spans for a specific agent turn to a dataset.
```bash
curl -X POST 'http://localhost:5000/alpha/telemetry/save-spans-to-dataset' \
-H 'Content-Type: application/json' \
-d '{
"attribute_filters": [
{
"key": "turn_id",
"op": "eq",
"value": "123e4567-e89b-12d3-a456-426614174000"
}
],
"attributes_to_save": ["input", "output"],
"dataset_id": "my_dataset",
"max_depth": 10
}'
```

View file

@ -8,7 +8,7 @@ This guide contains references to walk you through adding a new API provider.
- {repopath}`Remote Providers::llama_stack/providers/remote`
- {repopath}`Inline Providers::llama_stack/providers/inline`
3. [Build a Llama Stack distribution](https://llama-stack.readthedocs.io/en/latest/distribution_dev/building_distro.html) with your API provider.
3. [Build a Llama Stack distribution](https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html) with your API provider.
4. Test your code!
## Testing your newly added API providers

View file

@ -66,121 +66,247 @@ llama stack build --list-templates
```
```
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| Template Name | Providers | Description |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| hf-serverless | { | Like local, but use Hugging Face Inference API (serverless) for running LLM |
| | "inference": "remote::hf::serverless", | inference. |
| | "memory": "meta-reference", | See https://hf.co/docs/api-inference. |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| together | { | Use Together.ai for running LLM inference |
| | "inference": "remote::together", | |
| | "memory": [ | |
| | "meta-reference", | |
| | "remote::weaviate" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| fireworks | { | Use Fireworks.ai for running LLM inference |
| | "inference": "remote::fireworks", | |
| | "memory": [ | |
| | "meta-reference", | |
| | "remote::weaviate", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| databricks | { | Use Databricks for running LLM inference |
| | "inference": "remote::databricks", | |
| | "memory": "meta-reference", | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| vllm | { | Like local, but use vLLM for running LLM inference |
| | "inference": "vllm", | |
| | "memory": "meta-reference", | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| tgi | { | Use TGI for running LLM inference |
| | "inference": "remote::tgi", | |
| | "memory": [ | |
| | "meta-reference", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| bedrock | { | Use Amazon Bedrock APIs. |
| | "inference": "remote::bedrock", | |
| | "memory": "meta-reference", | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| meta-reference-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs |
| | "inference": "meta-reference", | |
| | "memory": [ | |
| | "meta-reference", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| meta-reference-quantized-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs |
| | "inference": "meta-reference-quantized", | |
| | "memory": [ | |
| | "meta-reference", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| ollama | { | Use ollama for running LLM inference |
| | "inference": "remote::ollama", | |
| | "memory": [ | |
| | "meta-reference", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| hf-endpoint | { | Like local, but use Hugging Face Inference Endpoints for running LLM inference. |
| | "inference": "remote::hf::endpoint", | See https://hf.co/docs/api-endpoints. |
| | "memory": "meta-reference", | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| Template Name | Providers | Description |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| tgi | { | Use (an external) TGI server for running LLM inference |
| | "inference": [ | |
| | "remote::tgi" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| remote-vllm | { | Use (an external) vLLM server for running LLM inference |
| | "inference": [ | |
| | "remote::vllm" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| vllm-gpu | { | Use a built-in vLLM engine for running LLM inference |
| | "inference": [ | |
| | "inline::vllm" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| meta-reference-quantized-gpu | { | Use Meta Reference with fp8, int4 quantization for running LLM inference |
| | "inference": [ | |
| | "inline::meta-reference-quantized" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| meta-reference-gpu | { | Use Meta Reference for running LLM inference |
| | "inference": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| hf-serverless | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference |
| | "inference": [ | |
| | "remote::hf::serverless" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| together | { | Use Together.AI for running LLM inference |
| | "inference": [ | |
| | "remote::together" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| ollama | { | Use (an external) Ollama server for running LLM inference |
| | "inference": [ | |
| | "remote::ollama" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| bedrock | { | Use AWS Bedrock for running LLM inference and safety |
| | "inference": [ | |
| | "remote::bedrock" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "remote::bedrock" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| hf-endpoint | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference |
| | "inference": [ | |
| | "remote::hf::endpoint" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| fireworks | { | Use Fireworks.AI for running LLM inference |
| | "inference": [ | |
| | "remote::fireworks" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| cerebras | { | Use Cerebras for running LLM inference |
| | "inference": [ | |
| | "remote::cerebras" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "memory": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
```
You may then pick a template to build your distribution with providers fitted to your liking.

View file

@ -55,7 +55,7 @@ models:
shields: []
```
Let's break this down into the different sections. It starts by specifying the set of APIs that the stack server will serve:
Let's break this down into the different sections. The first section specifies the set of APIs that the stack server will serve:
```yaml
apis:
- agents
@ -65,7 +65,8 @@ apis:
- telemetry
```
Next up is the most critical section -- the set of providers that the stack will use to serve the above APIs. Let's take the `inference` API as an example:
## Providers
Next up is the most critical part: the set of providers that the stack will use to serve the above APIs. Consider the `inference` API:
```yaml
providers:
inference:
@ -74,8 +75,12 @@ providers:
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
```
A _provider instance_ is identified with an (identifier, type, configuration) tuple. The identifier is a string you can choose freely. You may instantiate any number of provider instances of the same type. The configuration dictionary is provider-specific. Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server (via docker or via `llama stack run`), you can specify `--env OLLAMA_URL=http://my-server:11434` to override the default value.
A few things to note:
- A _provider instance_ is identified with an (identifier, type, configuration) tuple. The identifier is a string you can choose freely.
- You can instantiate any number of provider instances of the same type.
- The configuration dictionary is provider-specific. Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server (via docker or via `llama stack run`), you can specify `--env OLLAMA_URL=http://my-server:11434` to override the default value.
## Resources
Finally, let's look at the `models` section:
```yaml
models:
@ -87,3 +92,73 @@ models:
A Model is an instance of a "Resource" (see [Concepts](../concepts/index)) and is associated with a specific inference provider (in this case, the provider with identifier `ollama`). This is an instance of a "pre-registered" model. While we always encourage the clients to always register models before using them, some Stack servers may come up a list of "already known and available" models.
What's with the `provider_model_id` field? This is an identifier for the model inside the provider's model catalog. Contrast it with `model_id` which is the identifier for the same model for Llama Stack's purposes. For example, you may want to name "llama3.2:vision-11b" as "image_captioning_model" when you use it in your Stack interactions. When omitted, the server will set `provider_model_id` to be the same as `model_id`.
## Extending to handle Safety
Configuring Safety can be a little involved so it is instructive to go through an example.
The Safety API works with the associated Resource called a `Shield`. Providers can support various kinds of Shields. Good examples include the [Llama Guard](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/) system-safety models, or [Bedrock Guardrails](https://aws.amazon.com/bedrock/guardrails/).
To configure a Bedrock Shield, you would need to add:
- A Safety API provider instance with type `remote::bedrock`
- A Shield resource served by this provider.
```yaml
...
providers:
safety:
- provider_id: bedrock
provider_type: remote::bedrock
config:
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID}
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY}
...
shields:
- provider_id: bedrock
params:
guardrailVersion: ${env.GUARDRAIL_VERSION}
provider_shield_id: ${env.GUARDRAIL_ID}
...
```
The situation is more involved if the Shield needs _Inference_ of an associated model. This is the case with Llama Guard. In that case, you would need to add:
- A Safety API provider instance with type `inline::llama-guard`
- An Inference API provider instance for serving the model.
- A Model resource associated with this provider.
- A Shield resource served by the Safety provider.
The yaml configuration for this setup, assuming you were using vLLM as your inference server, would look like:
```yaml
...
providers:
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
inference:
# this vLLM server serves the "normal" inference model (e.g., llama3.2:3b)
- provider_id: vllm-0
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:http://localhost:8000}
# this vLLM server serves the llama-guard model (e.g., llama-guard:3b)
- provider_id: vllm-1
provider_type: remote::vllm
config:
url: ${env.SAFETY_VLLM_URL:http://localhost:8001}
...
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-0
provider_model_id: null
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: vllm-1
provider_model_id: null
shields:
- provider_id: llama-guard
shield_id: ${env.SAFETY_MODEL} # Llama Guard shields are identified by the corresponding LlamaGuard model
provider_shield_id: null
...
```

View file

@ -21,7 +21,7 @@ print(response)
```python
response = await client.inference.chat_completion(
messages=[UserMessage(content="What is the capital of France?", role="user")],
model="Llama3.1-8B-Instruct",
model_id="Llama3.1-8B-Instruct",
stream=False,
)
print("\nChat completion response:")

View file

@ -35,6 +35,6 @@ If so, we suggest:
- **Do you want to run Llama Stack inference on your iOS / Android device** If so, we suggest:
- [iOS SDK](ondevice_distro/ios_sdk)
- Android (coming soon)
- [Android](ondevice_distro/android_sdk)
You can also build your own [custom distribution](building_distro).

View file

@ -0,0 +1,247 @@
# Llama Stack Client Kotlin API Library
We are excited to share a guide for a Kotlin Library that brings front the benefits of Llama Stack to your Android device. This library is a set of SDKs that provide a simple and effective way to integrate AI capabilities into your Android app whether it is local (on-device) or remote inference.
Features:
- Local Inferencing: Run Llama models purely on-device with real-time processing. We currently utilize ExecuTorch as the local inference distributor and may support others in the future.
- [ExecuTorch](https://github.com/pytorch/executorch/tree/main) is a complete end-to-end solution within the PyTorch framework for inferencing capabilities on-device with high portability and seamless performance.
- Remote Inferencing: Perform inferencing tasks remotely with Llama models hosted on a remote connection (or serverless localhost).
- Simple Integration: With easy-to-use APIs, a developer can quickly integrate Llama Stack in their Android app. The difference with local vs remote inferencing is also minimal.
Latest Release Notes: [v0.0.54.1](https://github.com/meta-llama/llama-stack-client-kotlin/releases/tag/v0.0.54.1)
## Android Demo App
Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app)
The key files in the app are `LlamaStackLocalInference.kt`, `LlamaStackRemoteInference.kts`, and `MainActivity.java`. With encompassed business logic, the app shows how to use Llama Stack for both the environments.
## Quick Start
### Add Dependencies
#### Kotlin Library
Add the following dependency in your `build.gradle.kts` file:
```
dependencies {
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.0.54.1")
}
```
This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/`
If you plan on doing remote inferencing this is sufficient to get started.
#### Dependency for Local
For local inferencing, it is required to include the ExecuTorch library into your app.
Include the ExecuTorch library by:
1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/blob/release/0.0.54.1/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
2. Move the script to the top level of your Android app where the app directory resides:
<p align="center">
<img src="https://raw.githubusercontent.com/meta-llama/llama-stack-client-kotlin/refs/heads/release/0.0.54.1/doc/img/example_android_app_directory.png" style="width:300px">
</p>
3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate with commit: [0a12e33](https://github.com/pytorch/executorch/commit/0a12e33d22a3d44d1aa2af5f0d0673d45b962553).
4. Add the `executorch.aar` dependency in your `build.gradle.kts` file:
```
dependencies {
...
implementation(files("libs/executorch.aar"))
...
}
```
## Llama Stack APIs in Your Android App
Breaking down the demo app, this section will show the core pieces that are used to initialize and run inference with Llama Stack using the Kotlin library.
### Setup Remote Inferencing
Start a Llama Stack server on localhost. Here is an example of how you can do this using the firework.ai distribution:
```
conda create -n stack-fireworks python=3.10
conda activate stack-fireworks
pip install llama-stack=0.0.54
llama stack build --template fireworks --image-type conda
export FIREWORKS_API_KEY=<SOME_KEY>
llama stack run /Users/<your_username>/.llama/distributions/llamastack-fireworks/fireworks-run.yaml --port=5050
```
Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations)
How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#settings)
### Initialize the Client
A client serves as the primary interface for interacting with a specific inference type and its associated parameters. Only after client is initialized then you can configure and start inferences.
<table>
<tr>
<th>Local Inference</th>
<th>Remote Inference</th>
</tr>
<tr>
<td>
```
client = LlamaStackClientLocalClient
.builder()
.modelPath(modelPath)
.tokenizerPath(tokenizerPath)
.temperature(temperature)
.build()
```
</td>
<td>
```
// remoteURL is a string like "http://localhost:5050"
client = LlamaStackClientOkHttpClient
.builder()
.baseUrl(remoteURL)
.build()
```
</td>
</tr>
</table>
### Run Inference
With the Kotlin Library managing all the major operational logic, there are minimal to no changes when running simple chat inference for local or remote:
```
val result = client!!.inference().chatCompletion(
InferenceChatCompletionParams.builder()
.modelId(modelName)
.putAdditionalQueryParam("seq_len", sequenceLength.toString())
.messages(listOfMessages)
.build()
)
// response contains string with response from model
var response = result.asChatCompletionResponse().completionMessage().content().string();
```
### Setup Tool Calling
Android demo app for more details: [Tool Calling](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#tool-calling)
## Advanced Users
The purpose of this section is to share more details with users that would like to dive deeper into the Llama Stack Kotlin Library. Whether youre interested in contributing to the open source library, debugging or just want to learn more, this section is for you!
### Prerequisite
You must complete the following steps:
1. Clone the repo (`git clone https://github.com/meta-llama/llama-stack-client-kotlin.git -b release/0.0.54.1`)
2. Port the appropriate ExecuTorch libraries over into your Llama Stack Kotlin library environment.
```
cd llama-stack-client-kotlin-client-local
sh download-prebuilt-et-lib.sh --unzip
```
Now you will notice that the `jni/` , `libs/`, and `AndroidManifest.xml` files from the `executorch.aar` file are present in the local module. This way the local client module will be able to realize the ExecuTorch SDK.
### Building for Development/Debugging
If youd like to contribute to the Kotlin library via development, debug, or add play around with the library with various print statements, run the following command in your terminal under the llama-stack-client-kotlin directory.
```
sh build-libs.sh
```
Output: .jar files located in the build-jars directory
Copy the .jar files over to the lib directory in your Android app. At the same time make sure to remove the llama-stack-client-kotlin dependency within your build.gradle.kts file in your app (or if you are using the demo app) to avoid having multiple llama stack client dependencies.
### Additional Options for Local Inferencing
Currently we provide additional properties support with local inferencing. In order to get the tokens/sec metric for each inference call, add the following code in your Android app after you run your chatCompletion inference function. The Reference app has this implementation as well:
```
var tps = (result.asChatCompletionResponse()._additionalProperties()["tps"] as JsonNumber).value as Float
```
We will be adding more properties in the future.
### Additional Options for Remote Inferencing
#### Network options
##### Retries
Requests that experience certain errors are automatically retried 2 times by default, with a short exponential backoff. Connection errors (for example, due to a network connectivity problem), 408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors will all be retried by default.
You can provide a `maxRetries` on the client builder to configure this:
```kotlin
val client = LlamaStackClientOkHttpClient.builder()
.fromEnv()
.maxRetries(4)
.build()
```
##### Timeouts
Requests time out after 1 minute by default. You can configure this on the client builder:
```kotlin
val client = LlamaStackClientOkHttpClient.builder()
.fromEnv()
.timeout(Duration.ofSeconds(30))
.build()
```
##### Proxies
Requests can be routed through a proxy. You can configure this on the client builder:
```kotlin
val client = LlamaStackClientOkHttpClient.builder()
.fromEnv()
.proxy(new Proxy(
Type.HTTP,
new InetSocketAddress("proxy.com", 8080)
))
.build()
```
##### Environments
Requests are made to the production environment by default. You can connect to other environments, like `sandbox`, via the client builder:
```kotlin
val client = LlamaStackClientOkHttpClient.builder()
.fromEnv()
.sandbox()
.build()
```
### Error Handling
This library throws exceptions in a single hierarchy for easy handling:
- **`LlamaStackClientException`** - Base exception for all exceptions
- **`LlamaStackClientServiceException`** - HTTP errors with a well-formed response body we were able to parse. The exception message and the `.debuggingRequestId()` will be set by the server.
| 400 | BadRequestException |
| ------ | ----------------------------- |
| 401 | AuthenticationException |
| 403 | PermissionDeniedException |
| 404 | NotFoundException |
| 422 | UnprocessableEntityException |
| 429 | RateLimitException |
| 5xx | InternalServerException |
| others | UnexpectedStatusCodeException |
- **`LlamaStackClientIoException`** - I/O networking errors
- **`LlamaStackClientInvalidDataException`** - any other exceptions on the client side, e.g.:
- We failed to serialize the request body
- We failed to parse the response body (has access to response code and body)
## Reporting Issues
If you encountered any bugs or issues following this guide please file a bug/issue on our [Github issue tracker](https://github.com/meta-llama/llama-stack-client-kotlin/issues).
## Known Issues
We're aware of the following issues and are working to resolve them:
1. Streaming response is a work-in-progress for local and remote inference
2. Due to #1, agents are not supported at the time. LS agents only work in streaming mode
3. Changing to another model is a work in progress for local and remote platforms
## Thanks
We'd like to extend our thanks to the ExecuTorch team for providing their support as we integrated ExecuTorch as one of the local inference distributors for Llama Stack. Checkout [ExecuTorch Github repo](https://github.com/pytorch/executorch/tree/main) for more information.
---
The API interface is generated using the OpenAPI standard with [Stainless](https://www.stainlessapi.com/).

View file

@ -1,6 +1,3 @@
---
orphan: true
---
# Bedrock Distribution
```{toctree}
@ -15,9 +12,12 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::bedrock` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `remote::bedrock` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -0,0 +1,61 @@
# Cerebras Distribution
The `llamastack/distribution-cerebras` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `remote::cerebras` |
| memory | `inline::meta-reference` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
### Models
The following models are available by default:
- `meta-llama/Llama-3.1-8B-Instruct (llama3.1-8b)`
- `meta-llama/Llama-3.1-70B-Instruct (llama3.1-70b)`
### Prerequisite: API Keys
Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/).
## Running Llama Stack with Cerebras
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-cerebras \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```
### Via Conda
```bash
llama stack build --template cerebras --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```

View file

@ -15,9 +15,12 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::fireworks` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -15,9 +15,12 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `inline::meta-reference` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
@ -36,7 +39,7 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ ls ~/.llama/checkpoints

View file

@ -15,9 +15,12 @@ The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `inline::meta-reference-quantized` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
@ -36,7 +39,7 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ ls ~/.llama/checkpoints

View file

@ -15,9 +15,12 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::ollama` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
@ -118,9 +121,9 @@ llama stack run ./run-with-safety.yaml \
### (Optional) Update Model Serving Configuration
> [!NOTE]
> Please check the [OLLAMA_SUPPORTED_MODELS](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers.remote/inference/ollama/ollama.py) for the supported Ollama models.
```{note}
Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
```
To serve a new model with `ollama`
```bash

View file

@ -16,9 +16,12 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::tgi` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -15,9 +15,12 @@ The `llamastack/distribution-together` distribution consists of the following pr
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::together` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |

View file

@ -62,7 +62,7 @@ llama-stack-client --endpoint http://localhost:$LLAMA_STACK_PORT models list
You can test basic Llama inference completion using the CLI too.
```bash
llama-stack-client --endpoint http://localhost:$LLAMA_STACK_PORT \
inference chat_completion \
inference chat-completion \
--message "hello, what model are you?"
```
@ -118,6 +118,7 @@ async def run_main():
model=os.environ["INFERENCE_MODEL"],
instructions="You are a helpful assistant",
tools=[{"type": "memory"}], # enable Memory aka RAG
enable_session_persistence=True,
)
agent = Agent(client, agent_config)
@ -139,7 +140,7 @@ async def run_main():
attachments=attachments,
session_id=session_id,
)
async for log in EventLogger().log(response):
for log in EventLogger().log(response):
log.print()

View file

@ -45,6 +45,7 @@ Llama Stack already has a number of "adapters" available for some popular Infere
| **API Provider** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| Meta Reference | Single Node | Y | Y | Y | Y | Y |
| Cerebras | Single Node | | Y | | | |
| Fireworks | Hosted | Y | Y | Y | | |
| AWS Bedrock | Hosted | | Y | | Y | |
| Together | Hosted | Y | Y | | Y | |
@ -53,6 +54,7 @@ Llama Stack already has a number of "adapters" available for some popular Infere
| Chroma | Single Node | | | Y | | |
| Postgres | Single Node | | | Y | | |
| PyTorch ExecuTorch | On-device iOS | Y | Y | | |
| PyTorch ExecuTorch | On-device Android | | Y | | |
## Dive In

View file

@ -27,8 +27,6 @@ $ llama-stack-client configure
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5000
```
## Provider Commands
### `llama-stack-client providers list`
```bash
$ llama-stack-client providers list
@ -119,8 +117,25 @@ $ llama-stack-client memory_banks list
+--------------+----------------+--------+-------------------+------------------------+--------------------------+
```
## Shield Management
### `llama-stack-client memory_banks register`
```bash
$ llama-stack-client memory_banks register <memory-bank-id> --type <type> [--provider-id <provider-id>] [--provider-memory-bank-id <provider-memory-bank-id>] [--chunk-size <chunk-size>] [--embedding-model <embedding-model>] [--overlap-size <overlap-size>]
```
Options:
- `--type`: Required. Type of memory bank. Choices: "vector", "keyvalue", "keyword", "graph"
- `--provider-id`: Optional. Provider ID for the memory bank
- `--provider-memory-bank-id`: Optional. Provider's memory bank ID
- `--chunk-size`: Optional. Chunk size in tokens (for vector type). Default: 512
- `--embedding-model`: Optional. Embedding model (for vector type). Default: "all-MiniLM-L6-v2"
- `--overlap-size`: Optional. Overlap size in tokens (for vector type). Default: 64
### `llama-stack-client memory_banks unregister`
```bash
$ llama-stack-client memory_banks unregister <memory-bank-id>
```
## Shield Management
### `llama-stack-client shields list`
```bash
$ llama-stack-client shields list
@ -134,16 +149,51 @@ $ llama-stack-client shields list
+--------------+----------+----------------+-------------+
```
## Evaluation Tasks
### `llama-stack-client shields register`
```bash
$ llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>]
```
Options:
- `--shield-id`: Required. ID of the shield
- `--provider-id`: Optional. Provider ID for the shield
- `--provider-shield-id`: Optional. Provider's shield ID
- `--params`: Optional. JSON configuration parameters for the shield
## Eval Task Management
### `llama-stack-client eval_tasks list`
```bash
$ llama-stack-client eval run_benchmark <task_id1> <task_id2> --num-examples 10 --output-dir ./ --eval-task-config ~/eval_task_config.json
$ llama-stack-client eval_tasks list
```
where `eval_task_config.json` is the path to the eval task config file in JSON format. An example eval_task_config
### `llama-stack-client eval_tasks register`
```bash
$ llama-stack-client eval_tasks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
```
$ cat ~/eval_task_config.json
Options:
- `--eval-task-id`: Required. ID of the eval task
- `--dataset-id`: Required. ID of the dataset to evaluate
- `--scoring-functions`: Required. One or more scoring functions to use for evaluation
- `--provider-id`: Optional. Provider ID for the eval task
- `--provider-eval-task-id`: Optional. Provider's eval task ID
- `--metadata`: Optional. Metadata for the eval task in JSON format
## Eval execution
### `llama-stack-client eval run-benchmark`
```bash
$ llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
```
Options:
- `--eval-task-config`: Required. Path to the eval task config file in JSON format
- `--output-dir`: Required. Path to the directory where evaluation results will be saved
- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
- `--visualize`: Optional flag. If set, visualizes evaluation results after completion
Example eval_task_config.json:
```json
{
"type": "benchmark",
"eval_candidate": {
@ -160,3 +210,14 @@ $ cat ~/eval_task_config.json
}
}
```
### `llama-stack-client eval run-scoring`
```bash
$ llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
```
Options:
- `--eval-task-config`: Required. Path to the eval task config file in JSON format
- `--output-dir`: Required. Path to the directory where scoring results will be saved
- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
- `--visualize`: Optional flag. If set, visualizes scoring results after completion

View file

@ -13,13 +13,13 @@ Based on your developer needs, below are references to guides to help you get st
* Developer Need: I want to start a local Llama Stack server with my GPU using meta-reference implementations.
* Effort: 5min
* Guide:
- Please see our [meta-reference-gpu](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/meta-reference-gpu.html) on starting up a meta-reference Llama Stack server.
- Please see our [meta-reference-gpu](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) on starting up a meta-reference Llama Stack server.
### Llama Stack Server with Remote Providers
* Developer need: I want a Llama Stack distribution with a remote provider.
* Effort: 10min
* Guide
- Please see our [Distributions Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/index.html) on starting up distributions with remote providers.
- Please see our [Distributions Guide](https://llama-stack.readthedocs.io/en/latest/concepts/index.html#distributions) on starting up distributions with remote providers.
### On-Device (iOS) Llama Stack
@ -38,4 +38,4 @@ Based on your developer needs, below are references to guides to help you get st
* Developer Need: I want to add a new API provider to Llama Stack.
* Effort: 3hr
* Guide
- Please see our [Adding a New API Provider](https://llama-stack.readthedocs.io/en/latest/api_providers/new_api_provider.html) guide for adding a new API provider.
- Please see our [Adding a New API Provider](https://llama-stack.readthedocs.io/en/latest/contributing/new_api_provider.html) guide for adding a new API provider.

View file

@ -231,7 +231,7 @@
"source": [
"Thanks for checking out this notebook! \n",
"\n",
"The next one will be a guide on [Prompt Engineering](./01_Prompt_Engineering101.ipynb), please continue learning!"
"The next one will be a guide on [Prompt Engineering](./02_Prompt_Engineering101.ipynb), please continue learning!"
]
}
],

View file

@ -276,7 +276,7 @@
"source": [
"Thanks for checking out this notebook! \n",
"\n",
"The next one will be a guide on how to chat with images, continue to the notebook [here](./02_Image_Chat101.ipynb). Happy learning!"
"The next one will be a guide on how to chat with images, continue to the notebook [here](./03_Image_Chat101.ipynb). Happy learning!"
]
}
],

View file

@ -175,7 +175,7 @@
"source": [
"Thanks for checking out this notebook! \n",
"\n",
"The next one in the series will teach you one of the favorite applications of Large Language Models: [Tool Calling](./03_Tool_Calling101.ipynb). Enjoy!"
"The next one in the series will teach you one of the favorite applications of Large Language Models: [Tool Calling](./04_Tool_Calling101.ipynb). Enjoy!"
]
}
],

View file

@ -373,7 +373,7 @@
"source": [
"Awesome, now we can embed all our notes with Llama-stack and ask it about the meaning of life :)\n",
"\n",
"Next up, we will learn about the safety features and how to use them: [notebook link](./05_Safety101.ipynb)"
"Next up, we will learn about the safety features and how to use them: [notebook link](./06_Safety101.ipynb)."
]
}
],

View file

@ -107,7 +107,7 @@
"source": [
"Thanks for leaning about the Safety API of Llama-Stack. \n",
"\n",
"Finally, we learn about the Agents API, [here](./06_Agents101.ipynb)"
"Finally, we learn about the Agents API, [here](./07_Agents101.ipynb)."
]
}
],

View file

@ -1,37 +1,21 @@
# Llama Stack: from Zero to Hero
Llama-Stack allows you to configure your distribution from various providers, allowing you to focus on going from zero to production super fast.
Llama Stack defines and standardizes the set of core building blocks needed to bring generative AI applications to market. These building blocks are presented in the form of interoperable APIs with a broad set of Providers providing their implementations. These building blocks are assembled into Distributions which are easy for developers to get from zero to production.
This guide will walk you through how to build a local distribution, using Ollama as an inference provider.
This guide will walk you through an end-to-end workflow with Llama Stack with Ollama as the inference provider and ChromaDB as the memory provider. Please note the steps for configuring your provider and distribution will vary a little depending on the services you use. However, the user experience will remain universal - this is the power of Llama-Stack.
We also have a set of notebooks walking you through how to use Llama-Stack APIs:
If you're looking for more specific topics, we have a [Zero to Hero Guide](#next-steps) that covers everything from Tool Calling to Agents in detail. Feel free to skip to the end to explore the advanced topics you're interested in.
- Inference
- Prompt Engineering
- Chatting with Images
- Tool Calling
- Memory API for RAG
- Safety API
- Agentic API
Below, we will learn how to get started with Ollama as an inference provider, please note the steps for configuring your provider will vary a little depending on the service. However, the user experience will remain universal-this is the power of Llama-Stack.
Prototype locally using Ollama, deploy to the cloud with your favorite provider or own deployment. Use any API from any provider while focussing on development.
# Ollama Quickstart Guide
This guide will walk you through setting up an end-to-end workflow with Llama Stack with ollama, enabling you to perform text generation using the `Llama3.2-3B-Instruct` model. Follow these steps to get started quickly.
If you're looking for more specific topics like tool calling or agent setup, we have a [Zero to Hero Guide](#next-steps) that covers everything from Tool Calling to Agents in detail. Feel free to skip to the end to explore the advanced topics you're interested in.
> If you'd prefer not to set up a local server, explore our notebook on [tool calling with the Together API](Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb). This guide will show you how to leverage Together.ai's Llama Stack Server API, allowing you to get started with Llama Stack without the need for a locally built and running server.
> If you'd prefer not to set up a local server, explore our notebook on [tool calling with the Together API](Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb). This notebook will show you how to leverage together.ai's Llama Stack Server API, allowing you to get started with Llama Stack without the need for a locally built and running server.
## Table of Contents
1. [Setup ollama](#setup-ollama)
1. [Setup and run ollama](#setup-ollama)
2. [Install Dependencies and Set Up Environment](#install-dependencies-and-set-up-environment)
3. [Build, Configure, and Run Llama Stack](#build-configure-and-run-llama-stack)
4. [Run Ollama Model](#run-ollama-model)
5. [Next Steps](#next-steps)
4. [Test with llama-stack-client CLI](#test-with-llama-stack-client-cli)
5. [Test with curl](#test-with-curl)
6. [Test with Python](#test-with-python)
7. [Next Steps](#next-steps)
---
@ -39,107 +23,137 @@ If you're looking for more specific topics like tool calling or agent setup, we
1. **Download Ollama App**:
- Go to [https://ollama.com/download](https://ollama.com/download).
- Download and unzip `Ollama-darwin.zip`.
- Follow instructions based on the OS you are on. For example, if you are on a Mac, download and unzip `Ollama-darwin.zip`.
- Run the `Ollama` application.
1. **Download the Ollama CLI**:
- Ensure you have the `ollama` command line tool by downloading and installing it from the same website.
Ensure you have the `ollama` command line tool by downloading and installing it from the same website.
1. **Start ollama server**:
- Open the terminal and run:
```
ollama serve
```
Open the terminal and run:
```
ollama serve
```
1. **Run the model**:
- Open the terminal and run:
```bash
ollama run llama3.2:3b-instruct-fp16
```
**Note**: The supported models for llama stack for now is listed in [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L43)
Open the terminal and run:
```bash
ollama run llama3.2:3b-instruct-fp16 --keepalive -1m
```
**Note**:
- The supported models for llama stack for now is listed in [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L43)
- `keepalive -1m` is used so that ollama continues to keep the model in memory indefinitely. Otherwise, ollama frees up memory and you would have to run `ollama run` again.
---
## Install Dependencies and Set Up Environment
1. **Create a Conda Environment**:
- Create a new Conda environment with Python 3.10:
```bash
conda create -n ollama python=3.10
```
- Activate the environment:
```bash
conda activate ollama
```
Create a new Conda environment with Python 3.10:
```bash
conda create -n ollama python=3.10
```
Activate the environment:
```bash
conda activate ollama
```
2. **Install ChromaDB**:
- Install `chromadb` using `pip`:
```bash
pip install chromadb
```
Install `chromadb` using `pip`:
```bash
pip install chromadb
```
3. **Run ChromaDB**:
- Start the ChromaDB server:
```bash
chroma run --host localhost --port 8000 --path ./my_chroma_data
```
Start the ChromaDB server:
```bash
chroma run --host localhost --port 8000 --path ./my_chroma_data
```
4. **Install Llama Stack**:
- Open a new terminal and install `llama-stack`:
```bash
conda activate hack
pip install llama-stack==0.0.53
```
Open a new terminal and install `llama-stack`:
```bash
conda activate ollama
pip install llama-stack==0.0.55
```
---
## Build, Configure, and Run Llama Stack
1. **Build the Llama Stack**:
- Build the Llama Stack using the `ollama` template:
```bash
llama stack build --template ollama --image-type conda
```
After this step, you will see the console output:
```
Build Successful! Next steps:
Build the Llama Stack using the `ollama` template:
```bash
llama stack build --template ollama --image-type conda
```
**Expected Output:**
```
...
Build Successful! Next steps:
1. Set the environment variables: LLAMASTACK_PORT, OLLAMA_URL, INFERENCE_MODEL, SAFETY_MODEL
2. `llama stack run /Users/username/.llama/distributions/llamastack-ollama/ollama-run.yaml`
```
2. `llama stack run /Users/<username>/.llama/distributions/llamastack-ollama/ollama-run.yaml
```
2. **Set the ENV variables by exporting them to the terminal**:
```bash
export OLLAMA_URL="http://localhost:11434"
export LLAMA_STACK_PORT=5001
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
```
3. **Set the ENV variables by exporting them to the terminal**:
```bash
export OLLAMA_URL="http://localhost:11434"
export LLAMA_STACK_PORT=5051
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
```
3. **Run the Llama Stack**:
- Run the stack with command shared by the API from earlier:
```bash
llama stack run ollama \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env SAFETY_MODEL=$SAFETY_MODEL \
--env OLLAMA_URL=http://localhost:11434
```
Note: Everytime you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model
Run the stack with command shared by the API from earlier:
```bash
llama stack run ollama \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env SAFETY_MODEL=$SAFETY_MODEL \
--env OLLAMA_URL=$OLLAMA_URL
```
Note: Everytime you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
The server will start and listen on `http://localhost:5051`.
---
## Test with `llama-stack-client` CLI
After setting up the server, open a new terminal window and install the llama-stack-client package.
## Testing with `curl`
1. Install the llama-stack-client package
```bash
conda activate ollama
pip install llama-stack-client
```
2. Configure the CLI to point to the llama-stack server.
```bash
llama-stack-client configure --endpoint http://localhost:5051
```
**Expected Output:**
```bash
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5051
```
3. Test the CLI by running inference:
```bash
llama-stack-client inference chat-completion --message "Write me a 2-sentence poem about the moon"
```
**Expected Output:**
```bash
ChatCompletionResponse(
completion_message=CompletionMessage(
content='Here is a 2-sentence poem about the moon:\n\nSilver crescent shining bright in the night,\nA beacon of wonder, full of gentle light.',
role='assistant',
stop_reason='end_of_turn',
tool_calls=[]
),
logprobs=None
)
```
## Test with `curl`
After setting up the server, open a new terminal window and verify it's working by sending a `POST` request using `curl`:
```bash
curl http://localhost:5051/inference/chat_completion \
curl http://localhost:$LLAMA_STACK_PORT/inference/chat_completion \
-H "Content-Type: application/json" \
-d '{
"model": "Llama3.2-3B-Instruct",
@ -168,15 +182,16 @@ You can check the available models with the command `llama-stack-client models l
---
## Testing with Python
## Test with Python
You can also interact with the Llama Stack server using a simple Python script. Below is an example:
### 1. Active Conda Environment and Install Required Python Packages
### 1. Activate Conda Environment and Install Required Python Packages
The `llama-stack-client` library offers a robust and efficient python methods for interacting with the Llama Stack server.
```bash
conda activate your-llama-stack-conda-env
conda activate ollama
pip install llama-stack-client
```
Note, the client library gets installed by default if you install the server library
@ -188,6 +203,8 @@ touch test_llama_stack.py
### 3. Create a Chat Completion Request in Python
In `test_llama_stack.py`, write the following code:
```python
from llama_stack_client import LlamaStackClient
@ -227,15 +244,15 @@ This command initializes the model to interact with your local Llama Stack insta
## Next Steps
**Explore Other Guides**: Dive deeper into specific topics by following these guides:
- [Understanding Distribution](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html#decide-your-inference-provider)
- [Understanding Distribution](https://llama-stack.readthedocs.io/en/latest/concepts/index.html#distributions)
- [Inference 101](00_Inference101.ipynb)
- [Local and Cloud Model Toggling 101](00_Local_Cloud_Inference101.ipynb)
- [Prompt Engineering](01_Prompt_Engineering101.ipynb)
- [Chat with Image - LlamaStack Vision API](02_Image_Chat101.ipynb)
- [Tool Calling: How to and Details](03_Tool_Calling101.ipynb)
- [Memory API: Show Simple In-Memory Retrieval](04_Memory101.ipynb)
- [Using Safety API in Conversation](05_Safety101.ipynb)
- [Agents API: Explain Components](06_Agents101.ipynb)
- [Local and Cloud Model Toggling 101](01_Local_Cloud_Inference101.ipynb)
- [Prompt Engineering](02_Prompt_Engineering101.ipynb)
- [Chat with Image - LlamaStack Vision API](03_Image_Chat101.ipynb)
- [Tool Calling: How to and Details](04_Tool_Calling101.ipynb)
- [Memory API: Show Simple In-Memory Retrieval](05_Memory101.ipynb)
- [Using Safety API in Conversation](06_Safety101.ipynb)
- [Agents API: Explain Components](07_Agents101.ipynb)
**Explore Client SDKs**: Utilize our client SDKs for various languages to integrate Llama Stack into your applications:
@ -244,7 +261,7 @@ This command initializes the model to interact with your local Llama Stack insta
- [Swift SDK](https://github.com/meta-llama/llama-stack-client-swift)
- [Kotlin SDK](https://github.com/meta-llama/llama-stack-client-kotlin)
**Advanced Configuration**: Learn how to customize your Llama Stack distribution by referring to the [Building a Llama Stack Distribution](https://llama-stack.readthedocs.io/en/latest/distributions/index.html#building-your-own-distribution) guide.
**Advanced Configuration**: Learn how to customize your Llama Stack distribution by referring to the [Building a Llama Stack Distribution](https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html) guide.
**Explore Example Apps**: Check out [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) for example applications built using Llama Stack.

View file

@ -71,7 +71,7 @@
}
],
"source": [
"!pip install llama-stack-client"
"!pip install llama-stack-client==0.0.50"
]
},
{

View file

@ -23,6 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.deployment_types import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
@ -418,6 +419,7 @@ class AgentStepResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Agents(Protocol):
@webmethod(route="/agents/create")
async def create_agent(

View file

@ -37,3 +37,8 @@ class DatasetIO(Protocol):
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ...
@webmethod(route="/datasetio/append-rows", method="POST")
async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...

View file

@ -78,6 +78,21 @@ class DatasetsClient(Datasets):
return [DatasetDefWithProvider(**x) for x in response.json()]
async def unregister_dataset(
self,
dataset_id: str,
) -> None:
async with httpx.AsyncClient() as client:
response = await client.delete(
f"{self.base_url}/datasets/unregister",
params={
"dataset_id": dataset_id,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")

View file

@ -64,3 +64,9 @@ class Datasets(Protocol):
@webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[Dataset]: ...
@webmethod(route="/datasets/unregister", method="POST")
async def unregister_dataset(
self,
dataset_id: str,
) -> None: ...

View file

@ -21,6 +21,8 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
@ -220,6 +222,7 @@ class ModelStore(Protocol):
@runtime_checkable
@trace_protocol
class Inference(Protocol):
model_store: ModelStore

View file

@ -16,6 +16,7 @@ from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type
@ -43,6 +44,7 @@ class MemoryBankStore(Protocol):
@runtime_checkable
@trace_protocol
class Memory(Protocol):
memory_bank_store: MemoryBankStore

View file

@ -20,6 +20,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type
@ -129,6 +130,7 @@ class MemoryBankInput(BaseModel):
@runtime_checkable
@trace_protocol
class MemoryBanks(Protocol):
@webmethod(route="/memory-banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...

View file

@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
class CommonModelFields(BaseModel):
@ -43,6 +44,7 @@ class ModelInput(CommonModelFields):
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[Model]: ...

View file

@ -17,6 +17,8 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.datatypes import RemoteProviderConfig
from llama_stack.apis.safety import * # noqa: F403
@ -45,7 +47,7 @@ class SafetyClient(Safety):
) -> RunShieldResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/safety/run_shield",
f"{self.base_url}/{LLAMA_STACK_API_VERSION}/safety/run-shield",
json=dict(
shield_id=shield_id,
messages=[encodable_dict(m) for m in messages],
@ -91,7 +93,7 @@ async def run_main(host: str, port: int, image_path: str = None):
]:
cprint(f"User>{message.content}", "green")
response = await client.run_shield(
shield_id="llama_guard",
shield_id="meta-llama/Llama-Guard-3-1B",
messages=[message],
)
print(response)

View file

@ -10,6 +10,8 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
@ -43,6 +45,7 @@ class ShieldStore(Protocol):
@runtime_checkable
@trace_protocol
class Safety(Protocol):
shield_store: ShieldStore

View file

@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
class CommonShieldFields(BaseModel):
@ -38,6 +39,7 @@ class ShieldInput(CommonShieldFields):
@runtime_checkable
@trace_protocol
class Shields(Protocol):
@webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[Shield]: ...

View file

@ -6,12 +6,24 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
# Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7
@json_schema_type
class SpanStatus(Enum):
@ -29,6 +41,11 @@ class Span(BaseModel):
end_time: Optional[datetime] = None
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
def set_attribute(self, key: str, value: Any):
if self.attributes is None:
self.attributes = {}
self.attributes[key] = value
@json_schema_type
class Trace(BaseModel):
@ -123,10 +140,66 @@ Event = Annotated[
]
@json_schema_type
class EvalTrace(BaseModel):
session_id: str
step: str
input: str
output: str
expected_output: str
@json_schema_type
class SpanWithChildren(Span):
children: List["SpanWithChildren"] = Field(default_factory=list)
status: Optional[SpanStatus] = None
@json_schema_type
class QueryCondition(BaseModel):
key: str
op: Literal["eq", "ne", "gt", "lt"]
value: Any
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/log-event")
async def log_event(self, event: Event) -> None: ...
@webmethod(route="/telemetry/get-trace", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...
@webmethod(route="/telemetry/log-event")
async def log_event(
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
) -> None: ...
@webmethod(route="/telemetry/query-traces", method="POST")
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
@webmethod(route="/telemetry/get-span-tree", method="POST")
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren: ...
@webmethod(route="/telemetry/query-spans", method="POST")
async def query_spans(
self,
attribute_filters: List[QueryCondition],
attributes_to_return: List[str],
max_depth: Optional[int] = None,
) -> List[Span]: ...
@webmethod(route="/telemetry/save-spans-to-dataset", method="POST")
async def save_spans_to_dataset(
self,
attribute_filters: List[QueryCondition],
attributes_to_save: List[str],
dataset_id: str,
max_depth: Optional[int] = None,
) -> None: ...

View file

@ -10,6 +10,7 @@ from typing import List
import pkg_resources
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.utils.exec import run_with_pty
@ -45,7 +46,7 @@ class ApiInput(BaseModel):
def get_provider_dependencies(
config_providers: Dict[str, List[Provider]]
config_providers: Dict[str, List[Provider]],
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
all_providers = get_provider_registry()
@ -90,11 +91,12 @@ def get_provider_dependencies(
def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)
print(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
cprint(
f"Please install needed dependencies using the following commands:\n\npip install {' '.join(normal_deps)}",
"yellow",
)
for special_dep in special_deps:
log.info(f"\tpip install {special_dep}")
cprint(f"pip install {special_dep}", "yellow")
print()

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
import yaml
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
from pydantic import TypeAdapter
from rich.console import Console
from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
replace_env_vars,
)
T = TypeVar("T")
def stream_across_asyncio_run_boundary(
async_gen_maker,
pool_executor: ThreadPoolExecutor,
) -> Generator[T, None, None]:
result_queue = queue.Queue()
stop_event = threading.Event()
async def consumer():
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
try:
async for item in gen:
result_queue.put(item)
except Exception as e:
print(f"Error in generator {e}")
result_queue.put(e)
except asyncio.CancelledError:
return
finally:
result_queue.put(StopIteration)
stop_event.set()
def run_async():
# Run our own loop to avoid double async generator cleanup which is done
# by asyncio.run()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
task = loop.create_task(consumer())
loop.run_until_complete(task)
finally:
# Handle pending tasks like a generator's athrow()
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
future = pool_executor.submit(run_async)
try:
# yield results as they come in
while not stop_event.is_set() or not result_queue.empty():
try:
item = result_queue.get(timeout=0.1)
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
except queue.Empty:
continue
finally:
future.result()
class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_template_name, custom_provider_registry
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
def initialize(self):
asyncio.run(self.async_client.initialize())
def get(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.get(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.get(*args, **kwargs))
def post(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.post(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.post(*args, **kwargs))
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None,
):
super().__init__()
if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name)
if not config_path.exists():
raise ValueError(f"Config file {config_path} does not exist")
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
config = parse_and_maybe_upgrade_config(config_dict)
else:
# template
config = get_stack_run_config_from_template(config_path_or_template_name)
self.config_path_or_template_name = config_path_or_template_name
self.config = config
self.custom_provider_registry = custom_provider_registry
async def initialize(self):
try:
self.impls = await construct_stack(
self.config, self.custom_provider_registry
)
except ModuleNotFoundError as e:
cprint(
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
"yellow",
)
print_pip_install_help(self.config.providers)
raise e
console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
console.print(yaml.dump(self.config.model_dump(), indent=2))
endpoints = get_all_api_endpoints()
endpoint_impls = {}
for api, api_endpoints in endpoints.items():
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
endpoint_impls[endpoint.route] = func
self.endpoint_impls = endpoint_impls
async def get(
self,
path: str,
*,
stream=False,
**kwargs,
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
if stream:
return self._call_streaming(path, "GET")
else:
return await self._call_non_streaming(path, "GET")
async def post(
self,
path: str,
*,
body: dict = None,
stream=False,
**kwargs,
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
if stream:
return self._call_streaming(path, "POST", body)
else:
return await self._call_non_streaming(path, "POST", body)
async def _call_non_streaming(self, path: str, method: str, body: dict = None):
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
return await func(**body)
async def _call_streaming(self, path: str, method: str, body: dict = None):
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
async for chunk in await func(**body):
yield chunk
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
if not body:
return {}
func = self.endpoint_impls[path]
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
# Convert parameters to Pydantic models where needed
converted_body = {}
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
converted_body[param_name] = self._convert_param(
param.annotation, value
)
return converted_body
def _convert_param(self, annotation: Any, value: Any) -> Any:
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
return value
origin = get_origin(annotation)
if origin is list:
item_type = get_args(annotation)[0]
try:
return [self._convert_param(item_type, item) for item in value]
except Exception:
print(f"Error converting list {value}")
return value
elif origin is dict:
key_type, val_type = get_args(annotation)
try:
return {k: self._convert_param(val_type, v) for k, v in value.items()}
except Exception:
print(f"Error converting dict {value}")
return value
try:
# Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value)
except Exception as e:
cprint(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
"yellow",
)
return value

View file

@ -35,7 +35,7 @@ class NeedsRequestProviderData:
provider_data = validator(**val)
return provider_data
except Exception as e:
log.error("Error parsing provider data", e)
log.error(f"Error parsing provider data: {e}")
def set_request_provider_data(headers: Dict[str, str]):

View file

@ -222,6 +222,12 @@ class DatasetIORouter(DatasetIO):
filter_condition=filter_condition,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
)
class ScoringRouter(Scoring):
def __init__(

View file

@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
)
await self.register_object(dataset)
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
if dataset is None:
raise ValueError(f"Dataset {dataset_id} not found")
await self.unregister_object(dataset)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:

View file

@ -43,9 +43,9 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.inline.meta_reference.telemetry.console import (
ConsoleConfig,
ConsoleTelemetryImpl,
from llama_stack.providers.inline.telemetry.meta_reference import (
TelemetryAdapter,
TelemetryConfig,
)
from .endpoints import get_all_api_endpoints
@ -217,7 +217,7 @@ class TracingMiddleware:
async def __call__(self, scope, receive, send):
path = scope["path"]
await start_trace(path, {"location": "server"})
await start_trace(path, {"__location__": "server"})
try:
return await self.app(scope, receive, send)
finally:
@ -290,7 +290,7 @@ def main():
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
setup_logger(TelemetryAdapter(TelemetryConfig()))
all_endpoints = get_all_api_endpoints()

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig
def main(config_path: str):
client = LlamaStackAsLibraryClient(config_path)
client.initialize()
models = client.models.list()
print("\nModels:")
for model in models:
print(model)
if not models:
print("No models found, skipping chat completion test")
return
model_id = models[0].identifier
response = client.inference.chat_completion(
messages=[UserMessage(content="What is the capital of France?", role="user")],
model_id=model_id,
stream=False,
)
print("\nChat completion response (non-stream):")
print(response)
response = client.inference.chat_completion(
messages=[UserMessage(content="What is the capital of France?", role="user")],
model_id=model_id,
stream=True,
)
print("\nChat completion response (stream):")
for log in EventLogger().log(response):
log.print()
print("\nAgent test:")
agent_config = AgentConfig(
model=model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": "greedy",
"temperature": 1.0,
"top_p": 0.9,
},
tools=(
[
{
"type": "brave_search",
"engine": "brave",
"api_key": os.getenv("BRAVE_SEARCH_API_KEY"),
}
]
if os.getenv("BRAVE_SEARCH_API_KEY")
else []
),
tool_choice="auto",
tool_prompt_format="json",
input_shields=[],
output_shields=[],
enable_session_persistence=False,
)
agent = Agent(client, agent_config)
user_prompts = [
"Hello",
"Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
]
session_id = agent.create_session("test-session")
for prompt in user_prompts:
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
for log in AgentEventLogger().log(response):
log.print()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("config_path", help="Path to the config YAML file")
args = parser.parse_args()
main(args.config_path)

View file

@ -0,0 +1,42 @@
# (Experimental) LLama Stack UI
## Docker Setup
:warning: This is a work in progress.
## Developer Setup
1. Start up Llama Stack API server. More details [here](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
```
llama stack build --template together --image-type conda
llama stack run together
```
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
```bash
$ llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \
--metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
--schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'
```
```bash
$ llama-stack-client eval_tasks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
--scoring-functions basic::regex_parser_multiple_choice_answer
```
3. Start Streamlit UI
```bash
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
```

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
def main():
# Evaluation pages
application_evaluation_page = st.Page(
"page/evaluations/app_eval.py",
title="Evaluations (Scoring)",
icon="📊",
default=False,
)
native_evaluation_page = st.Page(
"page/evaluations/native_eval.py",
title="Evaluations (Generation + Scoring)",
icon="📊",
default=False,
)
# Playground pages
chat_page = st.Page(
"page/playground/chat.py", title="Chat", icon="💬", default=True
)
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
# Distribution pages
resources_page = st.Page(
"page/distribution/resources.py", title="Resources", icon="🔍", default=False
)
provider_page = st.Page(
"page/distribution/providers.py",
title="API Providers",
icon="🔍",
default=False,
)
pg = st.navigation(
{
"Playground": [
chat_page,
rag_page,
application_evaluation_page,
native_evaluation_page,
],
"Inspect": [provider_page, resources_page],
},
expanded=False,
)
pg.run()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Optional
from llama_stack_client import LlamaStackClient
class LlamaStackApi:
def __init__(self):
self.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:5000"),
provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
},
)
def run_scoring(
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
return self.client.scoring.score(
input_rows=[row], scoring_functions=scoring_params
)
llama_stack_api = LlamaStackApi()

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import os
import pandas as pd
import streamlit as st
def process_dataset(file):
if file is None:
return "No file uploaded", None
try:
# Determine file type and read accordingly
file_ext = os.path.splitext(file.name)[1].lower()
if file_ext == ".csv":
df = pd.read_csv(file)
elif file_ext in [".xlsx", ".xls"]:
df = pd.read_excel(file)
else:
return "Unsupported file format. Please upload a CSV or Excel file.", None
return df
except Exception as e:
st.error(f"Error processing file: {str(e)}")
return None
def data_url_from_file(file) -> str:
file_content = file.getvalue()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type = file.type
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
def datasets():
st.header("Datasets")
datasets_info = {
d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
}
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
st.json(datasets_info[selected_dataset], expanded=True)

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
def eval_tasks():
# Eval Tasks Section
st.header("Eval Tasks")
eval_tasks_info = {
d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
}
selected_eval_task = st.selectbox(
"Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
)
st.json(eval_tasks_info[selected_eval_task], expanded=True)

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
def memory_banks():
st.header("Memory Banks")
memory_banks_info = {
m.identifier: m.to_dict() for m in llama_stack_api.client.memory_banks.list()
}
if len(memory_banks_info) > 0:
selected_memory_bank = st.selectbox(
"Select a memory bank", list(memory_banks_info.keys())
)
st.json(memory_banks_info[selected_memory_bank])
else:
st.info("No memory banks found")

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
def models():
# Models Section
st.header("Models")
models_info = {
m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()
}
selected_model = st.selectbox("Select a model", list(models_info.keys()))
st.json(models_info[selected_model])

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
def providers():
st.header("🔍 API Providers")
apis_providers_info = llama_stack_api.client.providers.list()
# selected_api = st.selectbox("Select an API", list(apis_providers_info.keys()))
for api in apis_providers_info.keys():
st.markdown(f"###### {api}")
st.dataframe([p.to_dict() for p in apis_providers_info[api]], width=500)
providers()

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from page.distribution.datasets import datasets
from page.distribution.eval_tasks import eval_tasks
from page.distribution.memory_banks import memory_banks
from page.distribution.models import models
from page.distribution.scoring_functions import scoring_functions
from page.distribution.shields import shields
from streamlit_option_menu import option_menu
def resources_page():
options = [
"Models",
"Memory Banks",
"Shields",
"Scoring Functions",
"Datasets",
"Eval Tasks",
]
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
None,
options,
icons=icons,
orientation="horizontal",
styles={
"nav-link": {
"font-size": "12px",
},
},
)
if selected_resource == "Eval Tasks":
eval_tasks()
elif selected_resource == "Memory Banks":
memory_banks()
elif selected_resource == "Datasets":
datasets()
elif selected_resource == "Models":
models()
elif selected_resource == "Scoring Functions":
scoring_functions()
elif selected_resource == "Shields":
shields()
resources_page()

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
def scoring_functions():
st.header("Scoring Functions")
scoring_functions_info = {
s.identifier: s.to_dict()
for s in llama_stack_api.client.scoring_functions.list()
}
selected_scoring_function = st.selectbox(
"Select a scoring function", list(scoring_functions_info.keys())
)
st.json(scoring_functions_info[selected_scoring_function], expanded=True)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
def shields():
# Shields Section
st.header("Shields")
shields_info = {
s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()
}
selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
st.json(shields_info[selected_shield])

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api
from modules.utils import process_dataset
def application_evaluation_page():
st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Scoring)")
# File uploader
uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"])
if uploaded_file is None:
st.error("No file uploaded")
return
# Process uploaded file
df = process_dataset(uploaded_file)
if df is None:
st.error("Error processing file")
return
# Display dataset information
st.success("Dataset loaded successfully!")
# Display dataframe preview
st.subheader("Dataset Preview")
st.dataframe(df)
# Select Scoring Functions to Run Evaluation On
st.subheader("Select Scoring Functions")
scoring_functions = llama_stack_api.client.scoring_functions.list()
scoring_functions = {sf.identifier: sf for sf in scoring_functions}
scoring_functions_names = list(scoring_functions.keys())
selected_scoring_functions = st.multiselect(
"Choose one or more scoring functions",
options=scoring_functions_names,
help="Choose one or more scoring functions.",
)
available_models = llama_stack_api.client.models.list()
available_models = [m.identifier for m in available_models]
scoring_params = {}
if selected_scoring_functions:
st.write("Selected:")
for scoring_fn_id in selected_scoring_functions:
scoring_fn = scoring_functions[scoring_fn_id]
st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}")
new_params = None
if scoring_fn.params:
new_params = {}
for param_name, param_value in scoring_fn.params.to_dict().items():
if param_name == "type":
new_params[param_name] = param_value
continue
if param_name == "judge_model":
value = st.selectbox(
f"Select **{param_name}** for {scoring_fn_id}",
options=available_models,
index=0,
key=f"{scoring_fn_id}_{param_name}",
)
new_params[param_name] = value
else:
value = st.text_area(
f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format",
value=json.dumps(param_value, indent=2),
height=80,
)
try:
new_params[param_name] = json.loads(value)
except json.JSONDecodeError:
st.error(
f"Invalid JSON for **{param_name}** in {scoring_fn_id}"
)
st.json(new_params)
scoring_params[scoring_fn_id] = new_params
# Add run evaluation button & slider
total_rows = len(df)
num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows)
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = df.to_dict(orient="records")
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
score_res = llama_stack_api.run_scoring(
r,
scoring_function_ids=selected_scoring_functions,
scoring_params=scoring_params,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for fn_id in selected_scoring_functions:
if fn_id not in output_res:
output_res[fn_id] = []
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
# Display current row results using separate containers
progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})"
)
results_container.json(
score_res.to_json(),
expanded=2,
)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
application_evaluation_page()

View file

@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api
def select_eval_task_1():
# Select Eval Tasks
st.subheader("1. Choose An Eval Task")
eval_tasks = llama_stack_api.client.eval_tasks.list()
eval_tasks = {et.identifier: et for et in eval_tasks}
eval_tasks_names = list(eval_tasks.keys())
selected_eval_task = st.selectbox(
"Choose an eval task.",
options=eval_tasks_names,
help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.",
)
with st.expander("View Eval Task"):
st.json(eval_tasks[selected_eval_task], expanded=True)
st.session_state["selected_eval_task"] = selected_eval_task
st.session_state["eval_tasks"] = eval_tasks
if st.button("Confirm", key="confirm_1"):
st.session_state["selected_eval_task_1_next"] = True
def define_eval_candidate_2():
if not st.session_state.get("selected_eval_task_1_next", None):
return
st.subheader("2. Define Eval Candidate")
st.info(
"""
Define the configurations for the evaluation candidate model or agent used for generation.
Select "model" if you want to run generation with inference API, or "agent" if you want to run generation with agent API through specifying AgentConfig.
"""
)
with st.expander("Define Eval Candidate", expanded=True):
# Define Eval Candidate
candidate_type = st.radio("Candidate Type", ["model", "agent"])
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models]
selected_model = st.selectbox(
"Choose a model",
available_models,
index=0,
)
# Sampling Parameters
st.markdown("##### Sampling Parameters")
strategy = st.selectbox(
"Strategy",
["greedy", "top_p", "top_k"],
index=0,
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
)
repetition_penalty = st.slider(
"Repetition Penalty",
min_value=1.0,
max_value=2.0,
value=1.0,
step=0.1,
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
)
if candidate_type == "model":
eval_candidate = {
"type": "model",
"model": selected_model,
"sampling_params": {
"strategy": strategy,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
"repetition_penalty": repetition_penalty,
},
}
elif candidate_type == "agent":
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful AI assistant.",
help="Initial instructions given to the AI to set its behavior and context",
)
tools_json = st.text_area(
"Tools Configuration (JSON)",
value=json.dumps(
[
{
"type": "brave_search",
"engine": "brave",
"api_key": "ENTER_BRAVE_API_KEY_HERE",
}
]
),
help="Enter tool configurations in JSON format. Each tool should have a name, description, and parameters.",
height=200,
)
try:
tools = json.loads(tools_json)
except json.JSONDecodeError:
st.error("Invalid JSON format for tools configuration")
tools = []
eval_candidate = {
"type": "agent",
"config": {
"model": selected_model,
"instructions": system_prompt,
"tools": tools,
"tool_choice": "auto",
"tool_prompt_format": "json",
"input_shields": [],
"output_shields": [],
"enable_session_persistence": False,
},
}
st.session_state["eval_candidate"] = eval_candidate
if st.button("Confirm", key="confirm_2"):
st.session_state["selected_eval_candidate_2_next"] = True
def run_evaluation_3():
if not st.session_state.get("selected_eval_candidate_2_next", None):
return
st.subheader("3. Run Evaluation")
# Add info box to explain configurations being used
st.info(
"""
Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
"""
)
selected_eval_task = st.session_state["selected_eval_task"]
eval_tasks = st.session_state["eval_tasks"]
eval_candidate = st.session_state["eval_candidate"]
dataset_id = eval_tasks[selected_eval_task].dataset_id
rows = llama_stack_api.client.datasetio.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
total_rows = len(rows.rows)
# Add number of examples control
num_rows = st.number_input(
"Number of Examples to Evaluate",
min_value=1,
max_value=total_rows,
value=5,
help="Number of examples from the dataset to evaluate. ",
)
eval_task_config = {
"type": "benchmark",
"eval_candidate": eval_candidate,
"scoring_params": {},
}
with st.expander("View Evaluation Task", expanded=True):
st.json(eval_tasks[selected_eval_task], expanded=True)
with st.expander("View Evaluation Task Configuration", expanded=True):
st.json(eval_task_config, expanded=True)
# Add run button and handle evaluation
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = rows.rows
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
eval_res = llama_stack_api.client.eval.evaluate_rows(
task_id=selected_eval_task,
input_rows=[r],
scoring_functions=eval_tasks[selected_eval_task].scoring_functions,
task_config=eval_task_config,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for k in eval_res.generations[0].keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(eval_res.generations[0][k])
for scoring_fn in eval_tasks[selected_eval_task].scoring_functions:
if scoring_fn not in output_res:
output_res[scoring_fn] = []
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})"
)
results_container.json(eval_res, expanded=2)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
def native_evaluation_page():
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Generation + Scoring)")
select_eval_task_1()
define_eval_candidate_2()
run_evaluation_3()
native_evaluation_page()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from modules.api import llama_stack_api
# Sidebar configurations
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models]
selected_model = st.selectbox(
"Choose a model",
available_models,
index=0,
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
)
repetition_penalty = st.slider(
"Repetition Penalty",
min_value=1.0,
max_value=2.0,
value=1.0,
step=0.1,
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
)
stream = st.checkbox("Stream", value=True)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful AI assistant.",
help="Initial instructions given to the AI to set its behavior and context",
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
st.session_state.messages = []
st.rerun()
# Main chat interface
st.title("🦙 Chat")
# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Chat input
if prompt := st.chat_input("Example: What is Llama Stack?"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Display assistant response
with st.chat_message("assistant"):
message_placeholder = st.empty()
full_response = ""
response = llama_stack_api.client.inference.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
model_id=selected_model,
stream=stream,
sampling_params={
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
"repetition_penalty": repetition_penalty,
},
)
if stream:
for chunk in response:
if chunk.event.event_type == "progress":
full_response += chunk.event.delta
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
else:
full_response = response
message_placeholder.markdown(full_response.completion_message.content)
st.session_state.messages.append(
{"role": "assistant", "content": full_response}
)

View file

@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file
def rag_chat_page():
st.title("🦙 RAG")
with st.sidebar:
# File/Directory Upload Section
st.subheader("Upload Documents")
uploaded_files = st.file_uploader(
"Upload file(s) or directory",
accept_multiple_files=True,
type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
)
# Process uploaded files
if uploaded_files:
st.success(f"Successfully uploaded {len(uploaded_files)} files")
# Add memory bank name input field
memory_bank_name = st.text_input(
"Memory Bank Name",
value="rag_bank",
help="Enter a unique identifier for this memory bank",
)
if st.button("Create Memory Bank"):
documents = [
Document(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)
for i, uploaded_file in enumerate(uploaded_files)
]
providers = llama_stack_api.client.providers.list()
llama_stack_api.client.memory_banks.register(
memory_bank_id=memory_bank_name, # Use the user-provided name
params={
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size_in_tokens": 512,
"overlap_size_in_tokens": 64,
},
provider_id=providers["memory"][0].provider_id,
)
# insert documents using the custom bank name
llama_stack_api.client.memory.insert(
bank_id=memory_bank_name, # Use the user-provided name
documents=documents,
)
st.success("Memory bank created successfully!")
st.subheader("Configure Agent")
# select memory banks
memory_banks = llama_stack_api.client.memory_banks.list()
memory_banks = [bank.identifier for bank in memory_banks]
selected_memory_banks = st.multiselect(
"Select Memory Banks",
memory_banks,
)
memory_bank_configs = [
{"bank_id": bank_id, "type": "vector"} for bank_id in selected_memory_banks
]
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models]
selected_model = st.selectbox(
"Choose a model",
available_models,
index=0,
)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful assistant. ",
help="Initial instructions given to the AI to set its behavior and context",
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
st.session_state.messages = []
st.rerun()
# Chat Interface
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat history
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
selected_model = llama_stack_api.client.models.list()[0].identifier
agent_config = AgentConfig(
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": "greedy",
"temperature": temperature,
"top_p": top_p,
},
tools=[
{
"type": "memory",
"memory_bank_configs": memory_bank_configs,
"query_generator_config": {"type": "default", "sep": " "},
"max_tokens_in_context": 4096,
"max_chunks": 10,
}
],
tool_choice="auto",
tool_prompt_format="json",
input_shields=[],
output_shields=[],
enable_session_persistence=False,
)
agent = Agent(llama_stack_api.client, agent_config)
session_id = agent.create_session("rag-session")
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
# Display assistant response
with st.chat_message("assistant"):
retrieval_message_placeholder = st.empty()
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in EventLogger().log(response):
log.print()
if log.role == "memory_retrieval":
retrieval_response += log.content.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
else:
full_response += log.content
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
st.session_state.messages.append(
{"role": "assistant", "content": full_response}
)
rag_chat_page()

View file

@ -0,0 +1,4 @@
streamlit
pandas
llama-stack-client>=0.0.55
streamlit-option-menu

View file

@ -4,11 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
from .config_dirs import DEFAULT_CHECKPOINT_DIR
def model_local_dir(descriptor: str) -> str:
path = os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
return path.replace(":", "-")
return str(Path(DEFAULT_CHECKPOINT_DIR) / (descriptor.replace(":", "-")))

View file

@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
async def unregister_dataset(self, dataset_id: str) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ...

View file

@ -144,87 +144,91 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
@tracing.span("create_and_execute_turn")
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages.extend(request.messages)
messages.extend(request.messages)
turn_id = str(uuid.uuid4())
start_time = datetime.now()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
start_time = datetime.now()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)
)
steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
)
output_message = chunk
continue
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
steps.append(event.payload.step_details)
if isinstance(chunk, CompletionMessage):
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
)
output_message = chunk
continue
yield chunk
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
steps.append(event.payload.step_details)
assert output_message is not None
yield chunk
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
assert output_message is not None
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
)
yield chunk
yield chunk
async def run(
self,
@ -273,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin):
yield final_response
@tracing.span("run_shields")
async def run_multiple_shields_wrapper(
self,
turn_id: str,
@ -281,23 +284,47 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
if len(shields) == 0:
return
with tracing.span("run_shields") as span:
span.set_attribute("turn_id", turn_id)
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
return
step_id = str(uuid.uuid4())
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
step_id = str(uuid.uuid4())
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)
)
)
)
await self.run_multiple_shields(messages, shields)
await self.run_multiple_shields(messages, shields)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
),
)
)
)
span.set_attribute("output", e.violation.model_dump_json())
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
@ -305,30 +332,12 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
violation=None,
),
)
)
)
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=None,
),
)
)
)
span.set_attribute("output", "no violations")
async def _run(
self,
@ -356,10 +365,15 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
with tracing.span("retrieve_rag_context"):
with tracing.span("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", rag_context)
span.set_attribute("bank_ids", bank_ids)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -416,7 +430,7 @@ class ChatAgent(ShieldRunnerMixin):
content = ""
stop_reason = None
with tracing.span("inference"):
with tracing.span("inference") as span:
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
@ -436,7 +450,6 @@ class ChatAgent(ShieldRunnerMixin):
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -466,6 +479,13 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute(
"output", f"content: {content} tool_calls: {tool_calls}"
)
stop_reason = stop_reason or StopReason.out_of_tokens
@ -549,7 +569,13 @@ class ChatAgent(ShieldRunnerMixin):
)
)
with tracing.span("tool_execution"):
with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
) as span:
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
@ -558,6 +584,7 @@ class ChatAgent(ShieldRunnerMixin):
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(

View file

@ -6,9 +6,12 @@
import json
import logging
import shutil
import uuid
from typing import AsyncGenerator
from termcolor import colored
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
@ -44,6 +47,15 @@ class MetaReferenceAgentsImpl(Agents):
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
# check if "bwrap" is available
if not shutil.which("bwrap"):
print(
colored(
"Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
"yellow",
)
)
async def create_agent(
self,
agent_config: AgentConfig,

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -3,14 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import urlparse
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
@ -97,6 +100,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl=dataset_impl,
)
async def unregister_dataset(self, dataset_id: str) -> None:
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
self,
dataset_id: str,
@ -128,3 +134,41 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
total_count=len(rows),
next_page_token=str(end),
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_info = self.dataset_infos.get(dataset_id)
if dataset_info is None:
raise ValueError(f"Dataset with id {dataset_id} not found")
dataset_impl = dataset_info.dataset_impl
dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat(
[dataset_impl.df, new_rows_df], ignore_index=True
)
url = str(dataset_info.dataset_def.url)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
dataset_impl.df.to_csv(file_path, index=False)
elif parsed_url.scheme == "data":
# For data URLs, we need to update the base64-encoded content
if not parsed_url.path.startswith("text/csv;base64,"):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
"utf-8"
)
dataset_info.dataset_def.url = URL(
uri=f"data:text/csv;base64,{base64_content}"
)
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -3,12 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
class MetaReferenceEvalConfig(BaseModel):

View file

@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from tqdm import tqdm
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
@ -17,7 +19,6 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from tqdm import tqdm
from .config import MetaReferenceEvalConfig

View file

@ -27,7 +27,6 @@ from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from llama_stack.providers.utils.telemetry import tracing
from .config import FaissImplConfig
@ -95,7 +94,6 @@ class FaissIndex(EmbeddingIndex):
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
@tracing.span(name="add_chunks")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import ConsoleConfig
async def get_provider_impl(config: ConsoleConfig, _deps):
from .console import ConsoleTelemetryImpl
impl = ConsoleTelemetryImpl(config)
await impl.initialize()
return impl

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
class LogFormat(Enum):
TEXT = "text"
JSON = "json"
@json_schema_type
class ConsoleConfig(BaseModel):
log_format: LogFormat = LogFormat.TEXT

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
from typing import Optional
from typing import List, Optional
from .config import LogFormat
@ -49,8 +49,27 @@ class ConsoleTelemetryImpl(Telemetry):
if formatted:
print(formatted)
async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError()
async def query_traces(
self,
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
raise NotImplementedError("Console telemetry does not support trace querying")
async def get_spans(
self,
span_id: str,
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> SpanWithChildren:
raise NotImplementedError("Console telemetry does not support span querying")
COLORS = {

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -5,11 +5,17 @@
# the root directory of this source tree.
from typing import Dict
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import BraintrustScoringConfig
class BraintrustProviderDataValidator(BaseModel):
openai_api_key: str
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],

View file

@ -12,9 +12,12 @@ from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
import os
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
@ -24,7 +27,9 @@ from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
):
def __init__(
self,
config: BraintrustScoringConfig,
@ -79,12 +84,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def set_api_key(self) -> None:
# api key is in the request headers
if self.config.openai_api_key is None:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.openai_api_key:
raise ValueError(
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
)
self.config.openai_api_key = provider_data.openai_api_key
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@ -105,6 +123,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
@ -118,6 +137,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
await self.set_api_key()
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in self.supported_fn_defs_registry:

View file

@ -6,4 +6,8 @@
from llama_stack.apis.scoring import * # noqa: F401, F403
class BraintrustScoringConfig(BaseModel): ...
class BraintrustScoringConfig(BaseModel):
openai_api_key: Optional[str] = Field(
default=None,
description="The OpenAI API Key",
)

View file

@ -10,7 +10,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
params=None,
provider_id="braintrust",
provider_resource_id="answer-correctness",

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
llm_as_judge_base = ScoringFn(
@ -14,4 +14,8 @@ llm_as_judge_base = ScoringFn(
return_type=NumberType(),
provider_id="llm-as-judge",
provider_resource_id="llm-as-judge-base",
params=LLMAsJudgeScoringFnParams(
judge_model="meta-llama/Llama-3.1-405B-Instruct",
prompt_template="Enter custom LLM as Judge Prompt Template",
),
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import TelemetryConfig, TelemetrySink
from .telemetry import TelemetryAdapter
__all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
impl = TelemetryAdapter(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from pydantic import BaseModel, Field
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum):
OTEL = "otel"
SQLITE = "sqlite"
CONSOLE = "console"
class TelemetryConfig(BaseModel):
otel_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
)
sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
)
sqlite_db_path: str = Field(
default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(),
description="The path to the SQLite database to use for storing traces",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:['console', 'sqlite']}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:${runtime.base_dir}/trace_store.db}",
}

View file

@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
from opentelemetry.trace.status import StatusCode
# Colors for console output
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
class ConsoleSpanProcessor(SpanProcessor):
def __init__(self, print_attributes: bool = False):
self.print_attributes = print_attributes
def on_start(self, span: ReadableSpan, parent_context=None) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[START]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
def on_end(self, span: ReadableSpan) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[END]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
if span.status.status_code == StatusCode.ERROR:
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
elif span.status.status_code != StatusCode.UNSET:
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
print(span_context)
if self.print_attributes and span.attributes:
for key, value in span.attributes.items():
if key.startswith("__"):
continue
str_value = str(value)
if len(str_value) > 1000:
str_value = str_value[:997] + "..."
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
for event in span.events:
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, (dict, list)):
message = json.dumps(message, indent=2)
severity_colors = {
"error": f"{COLORS['bold']}{COLORS['red']}",
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
"info": COLORS["white"],
"debug": COLORS["dim"],
}
msg_color = severity_colors.get(severity, COLORS["white"])
print(
f" {event_time} "
f"{msg_color}[{severity.upper()}] "
f"{message}{COLORS['reset']}"
)
if event.attributes:
for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]:
continue
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
def shutdown(self) -> None:
"""Shutdown the processor."""
pass
def force_flush(self, timeout_millis: float = None) -> bool:
"""Force flush any pending spans."""
return True

View file

@ -0,0 +1,242 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import sqlite3
import threading
from datetime import datetime, timedelta
from typing import Dict
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
class SQLiteSpanProcessor(SpanProcessor):
def __init__(self, conn_string, ttl_days=30):
"""Initialize the SQLite span processor with a connection string."""
self.conn_string = conn_string
self.ttl_days = ttl_days
self.cleanup_task = None
self._thread_local = threading.local()
self._connections: Dict[int, sqlite3.Connection] = {}
self._lock = threading.Lock()
self.setup_database()
def _get_connection(self) -> sqlite3.Connection:
"""Get a thread-specific database connection."""
thread_id = threading.get_ident()
with self._lock:
if thread_id not in self._connections:
conn = sqlite3.connect(self.conn_string)
self._connections[thread_id] = conn
return self._connections[thread_id]
def setup_database(self):
"""Create the necessary tables if they don't exist."""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self.conn_string), exist_ok=True)
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS traces (
trace_id TEXT PRIMARY KEY,
service_name TEXT,
root_span_id TEXT,
start_time TIMESTAMP,
end_time TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS spans (
span_id TEXT PRIMARY KEY,
trace_id TEXT REFERENCES traces(trace_id),
parent_span_id TEXT,
name TEXT,
start_time TIMESTAMP,
end_time TIMESTAMP,
attributes TEXT,
status TEXT,
kind TEXT
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS span_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
span_id TEXT REFERENCES spans(span_id),
name TEXT,
timestamp TIMESTAMP,
attributes TEXT
)
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_traces_created_at
ON traces(created_at)
"""
)
conn.commit()
cursor.close()
# Start periodic cleanup in a separate thread
self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True)
self.cleanup_task.start()
def _cleanup_old_data(self):
"""Delete records older than TTL."""
try:
conn = self._get_connection()
cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat()
cursor = conn.cursor()
# Delete old span events
cursor.execute(
"""
DELETE FROM span_events
WHERE span_id IN (
SELECT span_id FROM spans
WHERE trace_id IN (
SELECT trace_id FROM traces
WHERE created_at < ?
)
)
""",
(cutoff_date,),
)
# Delete old spans
cursor.execute(
"""
DELETE FROM spans
WHERE trace_id IN (
SELECT trace_id FROM traces
WHERE created_at < ?
)
""",
(cutoff_date,),
)
# Delete old traces
cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,))
conn.commit()
cursor.close()
except Exception as e:
print(f"Error during cleanup: {e}")
def _periodic_cleanup(self):
"""Run cleanup periodically."""
import time
while True:
time.sleep(3600) # Sleep for 1 hour
self._cleanup_old_data()
def on_start(self, span: Span, parent_context=None):
"""Called when a span starts."""
pass
def on_end(self, span: Span):
"""Called when a span ends. Export the span data to SQLite."""
try:
conn = self._get_connection()
cursor = conn.cursor()
trace_id = format(span.get_span_context().trace_id, "032x")
span_id = format(span.get_span_context().span_id, "016x")
service_name = span.resource.attributes.get("service.name", "unknown")
parent_span_id = None
parent_context = span.parent
if parent_context:
parent_span_id = format(parent_context.span_id, "016x")
# Insert into traces
cursor.execute(
"""
INSERT INTO traces (
trace_id, service_name, root_span_id, start_time, end_time
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(trace_id) DO UPDATE SET
root_span_id = COALESCE(root_span_id, excluded.root_span_id),
start_time = MIN(excluded.start_time, start_time),
end_time = MAX(excluded.end_time, end_time)
""",
(
trace_id,
service_name,
(span_id if not parent_span_id else None),
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
),
)
# Insert into spans
cursor.execute(
"""
INSERT INTO spans (
span_id, trace_id, parent_span_id, name,
start_time, end_time, attributes, status,
kind
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
span_id,
trace_id,
parent_span_id,
span.name,
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
json.dumps(dict(span.attributes)),
span.status.status_code.name,
span.kind.name,
),
)
for event in span.events:
cursor.execute(
"""
INSERT INTO span_events (
span_id, name, timestamp, attributes
) VALUES (?, ?, ?, ?)
""",
(
span_id,
event.name,
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
json.dumps(dict(event.attributes)),
),
)
conn.commit()
cursor.close()
except Exception as e:
print(f"Error exporting span to SQLite: {e}")
def shutdown(self):
"""Cleanup any resources."""
with self._lock:
for conn in self._connections.values():
if conn:
conn.close()
self._connections.clear()
def force_flush(self, timeout_millis=30000):
"""Force export of spans."""
pass

Some files were not shown because too many files have changed in this diff Show more