forked from phoenix-oss/llama-stack-mirror

Compare commits: kvant...create-pul (1 commit)

Commit: 075c5401f5

351 changed files with 9284 additions and 25653 deletions
.github/PULL_REQUEST_TEMPLATE.md (vendored, 10 lines changed)

@@ -1,8 +1,10 @@
# What does this PR do?
-<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->
+[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]

-<!-- If resolving an issue, uncomment and update the line below -->
-<!-- Closes #[issue-number] -->
+[//]: # (If resolving an issue, uncomment and update the line below)
+[//]: # (Closes #[issue-number])

## Test Plan
-<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->
+[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]
+
+[//]: # (## Documentation)
.github/actions/setup-runner/action.yml (vendored, 22 lines removed)

@@ -1,22 +0,0 @@ (entire file removed; former contents below)
name: Setup runner
description: Prepare a runner for the tests (install uv, python, project dependencies, etc.)
runs:
  using: "composite"
  steps:
    - name: Install uv
      uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
      with:
        python-version: "3.10"
        activate-environment: true
        version: 0.7.6

    - name: Install dependencies
      shell: bash
      run: |
        uv sync --all-groups
        uv pip install ollama faiss-cpu
        # always test against the latest version of the client
        # TODO: this is not necessarily a good idea. we need to test against both published and latest
        # to find out backwards compatibility issues.
        uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
        uv pip install -e .
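For context, the workflows that consumed this composite action referenced it by its repository-local path; a minimal sketch of such a call site (job name and checkout pin are illustrative, the step names come from the hunks further down) looks like:

jobs:
  example-job:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      # the composite action bundled uv setup and project dependency installation
      - name: Install dependencies
        uses: ./.github/actions/setup-runner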
.github/workflows/Dockerfile (vendored, 1 line removed)

@@ -1 +0,0 @@ (entire file removed; former contents below)
FROM localhost:5000/distribution-kvant:dev
.github/workflows/ci-playground.yaml (vendored, 73 lines removed)

@@ -1,73 +0,0 @@ (entire file removed; former contents below)
name: Build and Push playground container
run-name: Build and Push playground container
on:
  workflow_dispatch:
  #schedule:
  #  - cron: "0 10 * * *"
  push:
    branches:
      - main
      - kvant
    tags:
      - 'v*'
  pull_request:
    branches:
      - main
      - kvant
env:
  IMAGE: git.kvant.cloud/${{github.repository}}-playground
jobs:
  build-playground:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set current time
        uses: https://github.com/gerred/actions/current-time@master
        id: current_time

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to git.kvant.cloud registry
        uses: docker/login-action@v3
        with:
          registry: git.kvant.cloud
          username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
          password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          # list of Docker images to use as base name for tags
          images: |
            ${{env.IMAGE}}
          # generate Docker tags based on the following events/attributes
          tags: |
            type=schedule
            type=ref,event=branch
            type=ref,event=pr
            type=ref,event=tag
            type=semver,pattern={{version}}

      - name: Build and push to gitea registry
        uses: docker/build-push-action@v6
        with:
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          context: .
          file: llama_stack/distribution/ui/Containerfile
          provenance: mode=max
          sbom: true
          build-args: |
            BUILD_DATE=${{ steps.current_time.outputs.time }}
          cache-from: |
            type=registry,ref=${{ env.IMAGE }}:buildcache
            type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
            type=registry,ref=${{ env.IMAGE }}:main
          cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true
.github/workflows/ci.yaml (vendored, 98 lines removed)

@@ -1,98 +0,0 @@ (entire file removed; former contents below)
name: Build and Push container
run-name: Build and Push container
on:
  workflow_dispatch:
  #schedule:
  #  - cron: "0 10 * * *"
  push:
    branches:
      - main
      - kvant
    tags:
      - 'v*'
  pull_request:
    branches:
      - main
      - kvant
env:
  IMAGE: git.kvant.cloud/${{github.repository}}
jobs:
  build:
    runs-on: ubuntu-latest
    services:
      registry:
        image: registry:2
        ports:
          - 5000:5000
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set current time
        uses: https://github.com/gerred/actions/current-time@master
        id: current_time

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          driver-opts: network=host

      - name: Login to git.kvant.cloud registry
        uses: docker/login-action@v3
        with:
          registry: git.kvant.cloud
          username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
          password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          # list of Docker images to use as base name for tags
          images: |
            ${{env.IMAGE}}
          # generate Docker tags based on the following events/attributes
          tags: |
            type=schedule
            type=ref,event=branch
            type=ref,event=pr
            type=ref,event=tag
            type=semver,pattern={{version}}

      - name: Install uv
        uses: https://github.com/astral-sh/setup-uv@v5
        with:
          # Install a specific version of uv.
          version: "0.7.8"

      - name: Build
        env:
          USE_COPY_NOT_MOUNT: true
          LLAMA_STACK_DIR: .
        run: |
          uvx --from . llama stack build --template kvant --image-type container

          # docker tag distribution-kvant:dev ${{env.IMAGE}}:kvant
          # docker push ${{env.IMAGE}}:kvant

          docker tag distribution-kvant:dev localhost:5000/distribution-kvant:dev
          docker push localhost:5000/distribution-kvant:dev

      - name: Build and push to gitea registry
        uses: docker/build-push-action@v6
        with:
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          context: .github/workflows
          provenance: mode=max
          sbom: true
          build-args: |
            BUILD_DATE=${{ steps.current_time.outputs.time }}
          cache-from: |
            type=registry,ref=${{ env.IMAGE }}:buildcache
            type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
            type=registry,ref=${{ env.IMAGE }}:main
          cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true
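Taken together, the two removed kvant files formed one pipeline: the Build step pushed the freshly built distribution image into the job-local registry service on port 5000, and the final docker/build-push-action step rebuilt from the .github/workflows context, whose one-line Dockerfile simply re-tags that image for the external registry. A rough local re-creation of that hand-off (a hedged sketch; the final image tag is a placeholder, not from this diff) would be:

# stand up a throwaway registry like the job's registry:2 service
docker run -d -p 5000:5000 --name registry registry:2

# build the kvant distribution container and stage it in the local registry
uvx --from . llama stack build --template kvant --image-type container
docker tag distribution-kvant:dev localhost:5000/distribution-kvant:dev
docker push localhost:5000/distribution-kvant:dev

# the workflow Dockerfile is only "FROM localhost:5000/distribution-kvant:dev"
docker build -t <target-registry>/distribution-kvant:kvant .github/workflows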
(The following hunks belong to several other CI workflow files whose file headers were not captured in this view.)

@@ -23,18 +23,23 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
-        auth-provider: [oauth2_token]
+        auth-provider: [kubernetes]
      fail-fast: false # we want to run all tests regardless of failure

    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+          activate-environment: true

-      - name: Build Llama Stack
+      - name: Set Up Environment and Install Dependencies
        run: |
+          uv sync --extra dev --extra test
+          uv pip install -e .
          llama stack build --template ollama --image-type venv

      - name: Install minikube
@@ -42,53 +47,29 @@ jobs:
        uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19

      - name: Start minikube
-        if: ${{ matrix.auth-provider == 'oauth2_token' }}
+        if: ${{ matrix.auth-provider == 'kubernetes' }}
        run: |
          minikube start
          kubectl get pods -A

      - name: Configure Kube Auth
-        if: ${{ matrix.auth-provider == 'oauth2_token' }}
+        if: ${{ matrix.auth-provider == 'kubernetes' }}
        run: |
          kubectl create namespace llama-stack
          kubectl create serviceaccount llama-stack-auth -n llama-stack
          kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
          kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
-          cat <<EOF | kubectl apply -f -
-          apiVersion: rbac.authorization.k8s.io/v1
-          kind: ClusterRole
-          metadata:
-            name: allow-anonymous-openid
-          rules:
-          - nonResourceURLs: ["/openid/v1/jwks"]
-            verbs: ["get"]
-          ---
-          apiVersion: rbac.authorization.k8s.io/v1
-          kind: ClusterRoleBinding
-          metadata:
-            name: allow-anonymous-openid
-          roleRef:
-            apiGroup: rbac.authorization.k8s.io
-            kind: ClusterRole
-            name: allow-anonymous-openid
-          subjects:
-          - kind: User
-            name: system:anonymous
-            apiGroup: rbac.authorization.k8s.io
-          EOF

      - name: Set Kubernetes Config
-        if: ${{ matrix.auth-provider == 'oauth2_token' }}
+        if: ${{ matrix.auth-provider == 'kubernetes' }}
        run: |
-          echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV
+          echo "KUBERNETES_API_SERVER_URL=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.server}')" >> $GITHUB_ENV
          echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
-          echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
-          echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV

      - name: Set Kube Auth Config and run server
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
-        if: ${{ matrix.auth-provider == 'oauth2_token' }}
+        if: ${{ matrix.auth-provider == 'kubernetes' }}
        run: |
          run_dir=$(mktemp -d)
          cat <<'EOF' > $run_dir/run.yaml
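The ClusterRole and ClusterRoleBinding removed above existed so that an unauthenticated caller could fetch the cluster's OIDC signing keys, which the old oauth2_token configuration needed. A quick way to inspect that wiring against a cluster where the manifest is applied (a hedged sketch; the openid-configuration query appears in the removed lines, and the JWKS path is the one named in the ClusterRole) is:

# discover the issuer and JWKS endpoint the oauth2_token provider relied on
kubectl get --raw /.well-known/openid-configuration | jq -r '.issuer, .jwks_uri'

# the removed binding grants system:anonymous GET on this non-resource URL
kubectl get --raw /openid/v1/jwks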
@@ -100,10 +81,10 @@ jobs:
            port: 8321
          EOF
          yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
-          yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
-          yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
+          yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml
          cat $run_dir/run.yaml

+          source .venv/bin/activate
          nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
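After the yq edits on the right-hand side of the hunk above, the server section of the generated run.yaml ends up roughly like the sketch below; the URL and path are placeholders that the KUBERNETES_* environment variables fill in at runtime:

server:
  port: 8321
  auth:
    provider_type: kubernetes
    config:
      api_server_url: https://<cluster-api-server>   # from KUBERNETES_API_SERVER_URL
      ca_cert_path: /path/to/cluster-ca.crt          # from KUBERNETES_CA_CERT_PATH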
@@ -24,7 +24,7 @@ jobs:
      matrix:
        # Listing tests manually since some of them currently fail
        # TODO: generate matrix list from tests/integration when fixed
-        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
+        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
        client-type: [library, http]
      fail-fast: false # we want to run all tests regardless of failure
@@ -32,14 +32,24 @@ jobs:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+          activate-environment: true

      - name: Setup ollama
        uses: ./.github/actions/setup-ollama

-      - name: Build Llama Stack
+      - name: Set Up Environment and Install Dependencies
        run: |
+          uv sync --extra dev --extra test
+          uv pip install ollama faiss-cpu
+          # always test against the latest version of the client
+          # TODO: this is not necessarily a good idea. we need to test against both published and latest
+          # to find out backwards compatibility issues.
+          uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
+          uv pip install -e .
          llama stack build --template ollama --image-type venv

      - name: Start Llama Stack server in background
@@ -47,7 +57,8 @@ jobs:
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
-          LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &
+          source .venv/bin/activate
+          nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        if: matrix.client-type == 'http'
@@ -75,12 +86,6 @@ jobs:
            exit 1
          fi

-      - name: Check Storage and Memory Available Before Tests
-        if: ${{ always() }}
-        run: |
-          free -h
-          df -h
-
      - name: Run Integration Tests
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
@@ -90,24 +95,17 @@ jobs:
          else
            stack_config="http://localhost:8321"
          fi
-          uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
+          uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
            -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
            --text-model="meta-llama/Llama-3.2-3B-Instruct" \
            --embedding-model=all-MiniLM-L6-v2

-      - name: Check Storage and Memory Available After Tests
-        if: ${{ always() }}
-        run: |
-          free -h
-          df -h
-
      - name: Write ollama logs to file
-        if: ${{ always() }}
        run: |
          sudo journalctl -u ollama.service > ollama.log

      - name: Upload all logs to artifacts
-        if: ${{ always() }}
+        if: always()
        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
        with:
          name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}
@@ -29,7 +29,6 @@ jobs:
      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
        env:
          SKIP: no-commit-to-branch
-          RUFF_OUTPUT_FORMAT: github

      - name: Verify if there are any diff files after pre-commit
        run: |
@@ -50,8 +50,21 @@ jobs:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .

      - name: Print build dependencies
        run: |
@@ -66,6 +79,7 @@ jobs:
      - name: Print dependencies in the image
        if: matrix.image-type == 'venv'
        run: |
+          source test/bin/activate
          uv pip list

  build-single-provider:
@@ -74,8 +88,21 @@ jobs:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .

      - name: Build a single provider
        run: |
@@ -87,8 +114,21 @@
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .

      - name: Build a single provider
        run: |
@@ -112,8 +152,21 @@
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .

      - name: Pin template to UBI9 base
        run: |
@@ -25,8 +25,15 @@ jobs:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Set Up Environment and Install Dependencies
+        run: |
+          uv sync --extra dev --extra test
+          uv pip install -e .

      - name: Apply image type to config file
        run: |
@@ -52,6 +59,7 @@ jobs:
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
+          source ci-test/bin/activate
          uv run pip list
          nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
@@ -30,11 +30,17 @@ jobs:
          - "3.12"
          - "3.13"
    steps:
-      - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python ${{ matrix.python }}
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: ${{ matrix.python }}
+
+      - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: ${{ matrix.python }}
+          enable-cache: false

      - name: Run unit tests
        run: |
@@ -37,8 +37,16 @@ jobs:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.11'
+
+      - name: Install the latest version of uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+
+      - name: Sync with uv
+        run: uv sync --extra docs

      - name: Build HTML
        run: |
.gitignore (vendored, 2 lines changed)

@@ -6,7 +6,6 @@ dev_requirements.txt
build
.DS_Store
llama_stack/configs/*
-.cursor/
xcuserdata/
*.hmap
.DS_Store

@@ -24,4 +23,3 @@ venv/
pytest-report.xml
.coverage
.python-version
-data
(pre-commit configuration file; file header not captured in this view)

@@ -53,7 +53,7 @@ repos:
          - black==24.3.0

  - repo: https://github.com/astral-sh/uv-pre-commit
-    rev: 0.7.8
+    rev: 0.6.3
    hooks:
      - id: uv-lock
      - id: uv-export

@@ -61,7 +61,6 @@ repos:
          "--frozen",
          "--no-hashes",
          "--no-emit-project",
-          "--no-default-groups",
          "--output-file=requirements.txt"
        ]

@@ -89,17 +88,20 @@ repos:
      - id: distro-codegen
        name: Distribution Template Codegen
        additional_dependencies:
-          - uv==0.7.8
-        entry: uv run --group codegen ./scripts/distro_codegen.py
+          - uv==0.6.0
+        entry: uv run --extra codegen ./scripts/distro_codegen.py
        language: python
        pass_filenames: false
        require_serial: true
        files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$

+  - repo: local
+    hooks:
      - id: openapi-codegen
        name: API Spec Codegen
        additional_dependencies:
-          - uv==0.7.8
-        entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
+          - uv==0.6.2
+        entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
        language: python
        pass_filenames: false
        require_serial: true
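Both codegen hooks touched above can be exercised locally without waiting for CI; a minimal sketch, assuming pre-commit is installed in the environment, is:

# run only the two codegen hooks from this hunk against the whole tree
pre-commit run distro-codegen --all-files
pre-commit run openapi-codegen --all-files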
(Read the Docs configuration file; file header not captured in this view)

@@ -5,21 +5,28 @@
# Required
version: 2

-# Build documentation in the "docs/" directory with Sphinx
-sphinx:
-  configuration: docs/source/conf.py
-
# Set the OS, Python version and other tools you might need
build:
  os: ubuntu-22.04
  tools:
    python: "3.12"
-  jobs:
-    pre_create_environment:
-      - asdf plugin add uv
-      - asdf install uv latest
-      - asdf global uv latest
-    create_environment:
-      - uv venv "${READTHEDOCS_VIRTUALENV_PATH}"
-    install:
-      - UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --group docs
+    # You can also specify other tool versions:
+    # nodejs: "19"
+    # rust: "1.64"
+    # golang: "1.19"
+
+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+  configuration: docs/source/conf.py
+
+# Optionally build your docs in additional formats such as PDF and ePub
+# formats:
+#   - pdf
+#   - epub
+
+# Optional but recommended, declare the Python requirements required
+# to build your documentation
+# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
+python:
+  install:
+    - requirements: docs/requirements.txt
627
CHANGELOG.md
627
CHANGELOG.md
|
@ -3,14 +3,14 @@
|
||||||
# v0.2.7
|
# v0.2.7
|
||||||
Published on: 2025-05-16T20:38:10Z
|
Published on: 2025-05-16T20:38:10Z
|
||||||
|
|
||||||
## Highlights
|
## Highlights
|
||||||
|
|
||||||
This is a small update. But a couple highlights:
|
This is a small update. But a couple highlights:
|
||||||
|
|
||||||
* feat: function tools in OpenAI Responses by @bbrowning in https://github.com/meta-llama/llama-stack/pull/2094, getting closer to ready. Streaming is the next missing piece.
|
* feat: function tools in OpenAI Responses by @bbrowning in https://github.com/meta-llama/llama-stack/pull/2094, getting closer to ready. Streaming is the next missing piece.
|
||||||
* feat: Adding support for customizing chunk context in RAG insertion and querying by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2134
|
* feat: Adding support for customizing chunk context in RAG insertion and querying by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2134
|
||||||
* feat: scaffolding for Llama Stack UI by @ehhuang in https://github.com/meta-llama/llama-stack/pull/2149, more to come in the coming releases.
|
* feat: scaffolding for Llama Stack UI by @ehhuang in https://github.com/meta-llama/llama-stack/pull/2149, more to come in the coming releases.
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
@ -31,42 +31,42 @@ Published on: 2025-05-04T20:16:49Z
|
||||||
# v0.2.4
|
# v0.2.4
|
||||||
Published on: 2025-04-29T17:26:01Z
|
Published on: 2025-04-29T17:26:01Z
|
||||||
|
|
||||||
## Highlights
|
## Highlights
|
||||||
|
|
||||||
* One-liner to install and run Llama Stack yay! by @reluctantfuturist in https://github.com/meta-llama/llama-stack/pull/1383
|
* One-liner to install and run Llama Stack yay! by @reluctantfuturist in https://github.com/meta-llama/llama-stack/pull/1383
|
||||||
* support for NVIDIA NeMo datastore by @raspawar in https://github.com/meta-llama/llama-stack/pull/1852
|
* support for NVIDIA NeMo datastore by @raspawar in https://github.com/meta-llama/llama-stack/pull/1852
|
||||||
* (yuge!) Kubernetes authentication by @leseb in https://github.com/meta-llama/llama-stack/pull/1778
|
* (yuge!) Kubernetes authentication by @leseb in https://github.com/meta-llama/llama-stack/pull/1778
|
||||||
* (yuge!) OpenAI Responses API by @bbrowning in https://github.com/meta-llama/llama-stack/pull/1989
|
* (yuge!) OpenAI Responses API by @bbrowning in https://github.com/meta-llama/llama-stack/pull/1989
|
||||||
* add api.llama provider, llama-guard-4 model by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2058
|
* add api.llama provider, llama-guard-4 model by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2058
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
# v0.2.3
|
# v0.2.3
|
||||||
Published on: 2025-04-25T22:46:21Z
|
Published on: 2025-04-25T22:46:21Z
|
||||||
|
|
||||||
## Highlights
|
## Highlights
|
||||||
|
|
||||||
* OpenAI compatible inference endpoints and client-SDK support. `client.chat.completions.create()` now works.
|
* OpenAI compatible inference endpoints and client-SDK support. `client.chat.completions.create()` now works.
|
||||||
* significant improvements and functionality added to the nVIDIA distribution
|
* significant improvements and functionality added to the nVIDIA distribution
|
||||||
* many improvements to the test verification suite.
|
* many improvements to the test verification suite.
|
||||||
* new inference providers: Ramalama, IBM WatsonX
|
* new inference providers: Ramalama, IBM WatsonX
|
||||||
* many improvements to the Playground UI
|
* many improvements to the Playground UI
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
# v0.2.2
|
# v0.2.2
|
||||||
Published on: 2025-04-13T01:19:49Z
|
Published on: 2025-04-13T01:19:49Z
|
||||||
|
|
||||||
## Main changes
|
## Main changes
|
||||||
|
|
||||||
- Bring Your Own Provider (@leseb) - use out-of-tree provider code to execute the distribution server
|
- Bring Your Own Provider (@leseb) - use out-of-tree provider code to execute the distribution server
|
||||||
- OpenAI compatible inference API in progress (@bbrowning)
|
- OpenAI compatible inference API in progress (@bbrowning)
|
||||||
- Provider verifications (@ehhuang)
|
- Provider verifications (@ehhuang)
|
||||||
- Many updates and fixes to playground
|
- Many updates and fixes to playground
|
||||||
- Several llama4 related fixes
|
- Several llama4 related fixes
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
@ -80,10 +80,10 @@ Published on: 2025-04-05T23:13:00Z
|
||||||
# v0.2.0
|
# v0.2.0
|
||||||
Published on: 2025-04-05T19:04:29Z
|
Published on: 2025-04-05T19:04:29Z
|
||||||
|
|
||||||
## Llama 4 Support
|
## Llama 4 Support
|
||||||
|
|
||||||
Checkout more at https://www.llama.com
|
Checkout more at https://www.llama.com
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
@ -91,58 +91,58 @@ Checkout more at https://www.llama.com
|
||||||
# v0.1.9
|
# v0.1.9
|
||||||
Published on: 2025-03-29T00:52:23Z
|
Published on: 2025-03-29T00:52:23Z
|
||||||
|
|
||||||
### Build and Test Agents
|
### Build and Test Agents
|
||||||
* Agents: Entire document context with attachments
|
* Agents: Entire document context with attachments
|
||||||
* RAG: Documentation with sqlite-vec faiss comparison
|
* RAG: Documentation with sqlite-vec faiss comparison
|
||||||
* Getting started: Fixes to getting started notebook.
|
* Getting started: Fixes to getting started notebook.
|
||||||
|
|
||||||
### Agent Evals and Model Customization
|
### Agent Evals and Model Customization
|
||||||
* (**New**) Post-training: Add nemo customizer
|
* (**New**) Post-training: Add nemo customizer
|
||||||
|
|
||||||
### Better Engineering
|
### Better Engineering
|
||||||
* Moved sqlite-vec to non-blocking calls
|
* Moved sqlite-vec to non-blocking calls
|
||||||
* Don't return a payload on file delete
|
* Don't return a payload on file delete
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
# v0.1.8
|
# v0.1.8
|
||||||
Published on: 2025-03-24T01:28:50Z
|
Published on: 2025-03-24T01:28:50Z
|
||||||
|
|
||||||
# v0.1.8 Release Notes
|
# v0.1.8 Release Notes
|
||||||
|
|
||||||
### Build and Test Agents
|
### Build and Test Agents
|
||||||
* Safety: Integrated NVIDIA as a safety provider.
|
* Safety: Integrated NVIDIA as a safety provider.
|
||||||
* VectorDB: Added Qdrant as an inline provider.
|
* VectorDB: Added Qdrant as an inline provider.
|
||||||
* Agents: Added support for multiple tool groups in agents.
|
* Agents: Added support for multiple tool groups in agents.
|
||||||
* Agents: Simplified imports for Agents in client package
|
* Agents: Simplified imports for Agents in client package
|
||||||
|
|
||||||
|
|
||||||
### Agent Evals and Model Customization
|
### Agent Evals and Model Customization
|
||||||
* Introduced DocVQA and IfEval benchmarks.
|
* Introduced DocVQA and IfEval benchmarks.
|
||||||
|
|
||||||
### Deploying and Monitoring Agents
|
### Deploying and Monitoring Agents
|
||||||
* Introduced a Containerfile and image workflow for the Playground.
|
* Introduced a Containerfile and image workflow for the Playground.
|
||||||
* Implemented support for Bearer (API Key) authentication.
|
* Implemented support for Bearer (API Key) authentication.
|
||||||
* Added attribute-based access control for resources.
|
* Added attribute-based access control for resources.
|
||||||
* Fixes on docker deployments: use --pull always and standardized the default port to 8321
|
* Fixes on docker deployments: use --pull always and standardized the default port to 8321
|
||||||
* Deprecated: /v1/inspect/providers use /v1/providers/ instead
|
* Deprecated: /v1/inspect/providers use /v1/providers/ instead
|
||||||
|
|
||||||
### Better Engineering
|
### Better Engineering
|
||||||
* Consolidated scripts under the ./scripts directory.
|
* Consolidated scripts under the ./scripts directory.
|
||||||
* Addressed mypy violations in various modules.
|
* Addressed mypy violations in various modules.
|
||||||
* Added Dependabot scans for Python dependencies.
|
* Added Dependabot scans for Python dependencies.
|
||||||
* Implemented a scheduled workflow to update the changelog automatically.
|
* Implemented a scheduled workflow to update the changelog automatically.
|
||||||
* Enforced concurrency to reduce CI loads.
|
* Enforced concurrency to reduce CI loads.
|
||||||
|
|
||||||
|
|
||||||
### New Contributors
|
### New Contributors
|
||||||
* @cmodi-meta made their first contribution in https://github.com/meta-llama/llama-stack/pull/1650
|
* @cmodi-meta made their first contribution in https://github.com/meta-llama/llama-stack/pull/1650
|
||||||
* @jeffmaury made their first contribution in https://github.com/meta-llama/llama-stack/pull/1671
|
* @jeffmaury made their first contribution in https://github.com/meta-llama/llama-stack/pull/1671
|
||||||
* @derekhiggins made their first contribution in https://github.com/meta-llama/llama-stack/pull/1698
|
* @derekhiggins made their first contribution in https://github.com/meta-llama/llama-stack/pull/1698
|
||||||
* @Bobbins228 made their first contribution in https://github.com/meta-llama/llama-stack/pull/1745
|
* @Bobbins228 made their first contribution in https://github.com/meta-llama/llama-stack/pull/1745
|
||||||
|
|
||||||
**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.7...v0.1.8
|
**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.7...v0.1.8
|
||||||
|
|
||||||
---
|
---
|
||||||
|
@ -150,73 +150,73 @@ Published on: 2025-03-24T01:28:50Z
|
||||||
# v0.1.7
|
# v0.1.7
|
||||||
Published on: 2025-03-14T22:30:51Z
|
Published on: 2025-03-14T22:30:51Z
|
||||||
|
|
||||||
## 0.1.7 Release Notes
|
## 0.1.7 Release Notes
|
||||||
|
|
||||||
### Build and Test Agents
|
### Build and Test Agents
|
||||||
* Inference: ImageType is now refactored to LlamaStackImageType
|
* Inference: ImageType is now refactored to LlamaStackImageType
|
||||||
* Inference: Added tests to measure TTFT
|
* Inference: Added tests to measure TTFT
|
||||||
* Inference: Bring back usage metrics
|
* Inference: Bring back usage metrics
|
||||||
* Agents: Added endpoint for get agent, list agents and list sessions
|
* Agents: Added endpoint for get agent, list agents and list sessions
|
||||||
* Agents: Automated conversion of type hints in client tool for lite llm format
|
* Agents: Automated conversion of type hints in client tool for lite llm format
|
||||||
* Agents: Deprecated ToolResponseMessage in agent.resume API
|
* Agents: Deprecated ToolResponseMessage in agent.resume API
|
||||||
* Added Provider API for listing and inspecting provider info
|
* Added Provider API for listing and inspecting provider info
|
||||||
|
|
||||||
### Agent Evals and Model Customization
|
### Agent Evals and Model Customization
|
||||||
* Eval: Added new eval benchmarks Math 500 and BFCL v3
|
* Eval: Added new eval benchmarks Math 500 and BFCL v3
|
||||||
* Deploy and Monitoring of Agents
|
* Deploy and Monitoring of Agents
|
||||||
* Telemetry: Fix tracing to work across coroutines
|
* Telemetry: Fix tracing to work across coroutines
|
||||||
|
|
||||||
### Better Engineering
|
### Better Engineering
|
||||||
* Display code coverage for unit tests
|
* Display code coverage for unit tests
|
||||||
* Updated call sites (inference, tool calls, agents) to move to async non blocking calls
|
* Updated call sites (inference, tool calls, agents) to move to async non blocking calls
|
||||||
* Unit tests also run on Python 3.11, 3.12, and 3.13
|
* Unit tests also run on Python 3.11, 3.12, and 3.13
|
||||||
* Added ollama inference to Integration tests CI
|
* Added ollama inference to Integration tests CI
|
||||||
* Improved documentation across examples, testing, CLI, updated providers table )
|
* Improved documentation across examples, testing, CLI, updated providers table )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
# v0.1.6
|
# v0.1.6
|
||||||
Published on: 2025-03-08T04:35:08Z
|
Published on: 2025-03-08T04:35:08Z
|
||||||
|
|
||||||
## 0.1.6 Release Notes
|
## 0.1.6 Release Notes
|
||||||
|
|
||||||
### Build and Test Agents
|
### Build and Test Agents
|
||||||
* Inference: Fixed support for inline vllm provider
|
* Inference: Fixed support for inline vllm provider
|
||||||
* (**New**) Agent: Build & Monitor Agent Workflows with Llama Stack + Anthropic's Best Practice [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb)
|
* (**New**) Agent: Build & Monitor Agent Workflows with Llama Stack + Anthropic's Best Practice [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb)
|
||||||
* (**New**) Agent: Revamped agent [documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html) with more details and examples
|
* (**New**) Agent: Revamped agent [documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html) with more details and examples
|
||||||
* Agent: Unify tools and Python SDK Agents API
|
* Agent: Unify tools and Python SDK Agents API
|
||||||
* Agent: AsyncAgent Python SDK wrapper supporting async client tool calls
|
* Agent: AsyncAgent Python SDK wrapper supporting async client tool calls
|
||||||
* Agent: Support python functions without @client_tool decorator as client tools
|
* Agent: Support python functions without @client_tool decorator as client tools
|
||||||
* Agent: deprecation for allow_resume_turn flag, and remove need to specify tool_prompt_format
|
* Agent: deprecation for allow_resume_turn flag, and remove need to specify tool_prompt_format
|
||||||
* VectorIO: MilvusDB support added
|
* VectorIO: MilvusDB support added
|
||||||
|
|
||||||
### Agent Evals and Model Customization
|
### Agent Evals and Model Customization
|
||||||
* (**New**) Agent: Llama Stack RAG Lifecycle [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb)
|
* (**New**) Agent: Llama Stack RAG Lifecycle [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb)
|
||||||
* Eval: Documentation for eval, scoring, adding new benchmarks
|
* Eval: Documentation for eval, scoring, adding new benchmarks
|
||||||
* Eval: Distribution template to run benchmarks on llama & non-llama models
|
* Eval: Distribution template to run benchmarks on llama & non-llama models
|
||||||
* Eval: Ability to register new custom LLM-as-judge scoring functions
|
* Eval: Ability to register new custom LLM-as-judge scoring functions
|
||||||
* (**New**) Looking for contributors for open benchmarks. See [documentation](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) for details.
|
* (**New**) Looking for contributors for open benchmarks. See [documentation](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) for details.
|
||||||
|
|
||||||
### Deploy and Monitoring of Agents
|
### Deploy and Monitoring of Agents
|
||||||
* Better support for different log levels across all components for better monitoring
|
* Better support for different log levels across all components for better monitoring
|
||||||
|
|
||||||
### Better Engineering
|
### Better Engineering
|
||||||
* Enhance OpenAPI spec to include Error types across all APIs
|
* Enhance OpenAPI spec to include Error types across all APIs
|
||||||
* Moved all tests to /tests and created unit tests to run on each PR
|
* Moved all tests to /tests and created unit tests to run on each PR
|
||||||
* Removed all dependencies on llama-models repo
|
* Removed all dependencies on llama-models repo
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
# v0.1.5.1
|
# v0.1.5.1
|
||||||
Published on: 2025-02-28T22:37:44Z
|
Published on: 2025-02-28T22:37:44Z
|
||||||
|
|
||||||
## 0.1.5.1 Release Notes
|
## 0.1.5.1 Release Notes
|
||||||
* Fixes for security risk in https://github.com/meta-llama/llama-stack/pull/1327 and https://github.com/meta-llama/llama-stack/pull/1328
|
* Fixes for security risk in https://github.com/meta-llama/llama-stack/pull/1327 and https://github.com/meta-llama/llama-stack/pull/1328
|
||||||
|
|
||||||
**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.5...v0.1.5.1
|
**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.5...v0.1.5.1
|
||||||
|
|
||||||
---
|
---
|
||||||
|
@ -224,176 +224,176 @@ Published on: 2025-02-28T22:37:44Z
|
||||||
# v0.1.5
|
# v0.1.5
|
||||||
Published on: 2025-02-28T18:14:01Z
|
Published on: 2025-02-28T18:14:01Z
|
||||||
|
|
||||||
## 0.1.5 Release Notes
|
## 0.1.5 Release Notes
|
||||||
### Build Agents
|
### Build Agents
|
||||||
* Inference: Support more non-llama models (openai, anthropic, gemini)
|
* Inference: Support more non-llama models (openai, anthropic, gemini)
|
||||||
* Inference: Can use the provider's model name in addition to the HF alias
|
* Inference: Can use the provider's model name in addition to the HF alias
|
||||||
* Inference: Fixed issues with calling tools that weren't specified in the prompt
|
* Inference: Fixed issues with calling tools that weren't specified in the prompt
|
||||||
* RAG: Improved system prompt for RAG and no more need for hard-coded rag-tool calling
|
* RAG: Improved system prompt for RAG and no more need for hard-coded rag-tool calling
|
||||||
* Embeddings: Added support for Nemo retriever embedding models
|
* Embeddings: Added support for Nemo retriever embedding models
|
||||||
* Tools: Added support for MCP tools in Ollama Distribution
|
* Tools: Added support for MCP tools in Ollama Distribution
|
||||||
* Distributions: Added new Groq distribution
|
* Distributions: Added new Groq distribution
|
||||||
|
|
||||||
### Customize Models
|
### Customize Models
|
||||||
* Save post-trained checkpoint in SafeTensor format to allow Ollama inference provider to use the post-trained model
|
* Save post-trained checkpoint in SafeTensor format to allow Ollama inference provider to use the post-trained model
|
||||||
|
|
||||||
### Monitor agents
|
### Monitor agents
|
||||||
* More comprehensive logging of agent steps including client tools
|
* More comprehensive logging of agent steps including client tools
|
||||||
* Telemetry inputs/outputs are now structured and queryable
|
* Telemetry inputs/outputs are now structured and queryable
|
||||||
* Ability to retrieve agents session, turn, step by ids
|
* Ability to retrieve agents session, turn, step by ids
|
||||||
|
|
||||||
### Better Engineering
|
### Better Engineering
|
||||||
* Moved executorch Swift code out of this repo into the llama-stack-client-swift repo, similar to kotlin
|
* Moved executorch Swift code out of this repo into the llama-stack-client-swift repo, similar to kotlin
|
||||||
* Move most logging to use logger instead of prints
|
* Move most logging to use logger instead of prints
|
||||||
* Completed text /chat-completion and /completion tests
|
* Completed text /chat-completion and /completion tests
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
# v0.1.4
|
# v0.1.4
|
||||||
Published on: 2025-02-25T00:02:43Z
|
Published on: 2025-02-25T00:02:43Z
|
||||||
|
|
||||||
## v0.1.4 Release Notes
|
## v0.1.4 Release Notes
|
||||||
Here are the key changes coming as part of this release:
|
Here are the key changes coming as part of this release:
|
||||||
|
|
||||||
### Build and Test Agents
|
### Build and Test Agents
|
||||||
* Inference: Added support for non-llama models
|
* Inference: Added support for non-llama models
|
||||||
* Inference: Added option to list all downloaded models and remove models
|
* Inference: Added option to list all downloaded models and remove models
|
||||||
* Agent: Introduce new api agents.resume_turn to include client side tool execution in the same turn
|
* Agent: Introduce new api agents.resume_turn to include client side tool execution in the same turn
|
||||||
* Agent: AgentConfig introduces new variable “tool_config” that allows for better tool configuration and system prompt overrides
|
* Agent: AgentConfig introduces new variable “tool_config” that allows for better tool configuration and system prompt overrides
|
||||||
* Agent: Added logging for agent step start and completion times
|
* Agent: Added logging for agent step start and completion times
|
||||||
* Agent: Added support for logging for tool execution metadata
|
* Agent: Added support for logging for tool execution metadata
|
||||||
* Embedding: Updated /inference/embeddings to support asymmetric models, truncation and variable sized outputs
|
* Embedding: Updated /inference/embeddings to support asymmetric models, truncation and variable sized outputs
|
||||||
* Embedding: Updated embedding models for Ollama, Together, and Fireworks with available defaults
|
* Embedding: Updated embedding models for Ollama, Together, and Fireworks with available defaults
|
||||||
* VectorIO: Improved performance of sqlite-vec using chunked writes
|
* VectorIO: Improved performance of sqlite-vec using chunked writes
|
||||||
### Agent Evals and Model Customization
|
### Agent Evals and Model Customization
|
||||||
* Deprecated api /eval-tasks. Use /eval/benchmark instead
|
* Deprecated api /eval-tasks. Use /eval/benchmark instead
|
||||||
* Added CPU training support for TorchTune
|
* Added CPU training support for TorchTune
|
||||||
### Deploy and Monitoring of Agents
|
### Deploy and Monitoring of Agents
|
||||||
* Consistent view of client and server tool calls in telemetry
|
* Consistent view of client and server tool calls in telemetry
|
||||||
### Better Engineering
|
### Better Engineering
|
* Made tests more data-driven for consistent evaluation
* Fixed documentation links and improved API reference generation
* Various small fixes for build scripts and system reliability

---

# v0.1.3
Published on: 2025-02-14T20:24:32Z

## v0.1.3 Release

Here are some key changes that are coming as part of this release.

### Build and Test Agents
Streamlined the initial development experience
- Added support for llama stack run --image-type venv
- Enhanced vector store options with new sqlite-vec provider and improved Qdrant integration
- vLLM improvements for tool calling and logprobs
- Better handling of sporadic code_interpreter tool calls

### Agent Evals
Better benchmarking and Agent performance assessment
- Renamed eval API /eval-task to /benchmarks
- Improved documentation and notebooks for RAG and evals

### Deploy and Monitoring of Agents
Improved production readiness
- Added usage metrics collection for chat completions
- CLI improvements for provider information
- Improved error handling and system reliability
- Better model endpoint handling and accessibility
- Improved signal handling on distro server

### Better Engineering
Infrastructure and code quality improvements
- Faster text-based chat completion tests
- Improved testing for non-streaming agent APIs
- Standardized import formatting with the ruff linter
- Added conventional commits standard
- Fixed documentation parsing issues

---

# v0.1.2
Published on: 2025-02-07T22:06:49Z

# TL;DR
- Several stabilizations to development flows after the switch to `uv`
- Migrated CI workflows to the new OSS repo - [llama-stack-ops](https://github.com/meta-llama/llama-stack-ops)
- Added automated rebuilds for ReadTheDocs
- Llama Stack server supports HTTPS (see the sketch after this list)
- Added system prompt overrides support
- Several bug fixes and improvements to documentation (check out the Kubernetes deployment guide by @terrytangyuan)
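To enable HTTPS, the server section of a run configuration points at a certificate and key. A minimal sketch, assuming the `tls_certfile` / `tls_keyfile` options described in the server configuration section later in this document; the file paths are illustrative placeholders:

```yaml
server:
  port: 8321                         # default port
  tls_certfile: "/path/to/cert.pem"  # serve HTTPS with this certificate
  tls_keyfile: "/path/to/key.pem"    # and this private key
```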
---

# v0.1.1
Published on: 2025-02-02T02:29:24Z

A bunch of small / big improvements everywhere, including support for Windows, the switch to `uv`, and many provider improvements.

---

# v0.1.0
Published on: 2025-01-24T17:47:47Z

We are excited to announce a stable API release of Llama Stack, which enables developers to build RAG applications and Agents using tools and safety shields, monitor those agents with telemetry, and evaluate them with scoring functions.

## Context
GenAI application developers need more than just an LLM - they need to integrate tools, connect with their data sources, establish guardrails, and ground the LLM responses effectively. Currently, developers must piece together various tools and APIs, complicating the development lifecycle and increasing costs. The result is that developers are spending more time on these integrations rather than focusing on the application logic itself. The bespoke coupling of components also makes it challenging to adopt state-of-the-art solutions in the rapidly evolving GenAI space. This is particularly difficult for open models like Llama, as best practices are not widely established in the open.

Llama Stack was created to provide developers with a comprehensive and coherent interface that simplifies AI application development and codifies best practices across the Llama ecosystem. Since our launch in September 2024, we have seen a huge uptick in interest in Llama Stack APIs by both AI developers and from partners building AI services with Llama models. Partners like Nvidia, Fireworks, and Ollama have collaborated with us to develop implementations across various APIs, including inference, memory, and safety.

With Llama Stack, you can easily build a RAG agent which can also search the web, do complex math, and call custom tools. You can use telemetry to inspect those traces, and convert telemetry into evals datasets. And with Llama Stack’s plugin architecture and prepackaged distributions, you can choose to run your agent anywhere - in the cloud with our partners, deploy your own environment using virtualenv, conda, or Docker, operate locally with Ollama, or even run on mobile devices with our SDKs. Llama Stack offers unprecedented flexibility while also simplifying the developer experience.

## Release
After iterating on the APIs for the last 3 months, today we’re launching a stable release (V1) of the Llama Stack APIs and the corresponding llama-stack server and client packages (v0.1.0). We now have automated tests for providers. These tests make sure that all provider implementations are verified. Developers can now easily and reliably select distributions or providers based on their specific requirements.

There are example standalone apps in llama-stack-apps.

## Key Features of this release

- **Unified API Layer**
  - Inference: Run LLM models
  - RAG: Store and retrieve knowledge for RAG
  - Agents: Build multi-step agentic workflows
  - Tools: Register tools that can be called by the agent
  - Safety: Apply content filtering and safety policies
  - Evaluation: Test model and agent quality
  - Telemetry: Collect and analyze usage data and complex agentic traces
  - Post Training (Coming Soon): Fine tune models for specific use cases

- **Rich Provider Ecosystem**
  - Local Development: Meta's Reference, Ollama
  - Cloud: Fireworks, Together, Nvidia, AWS Bedrock, Groq, Cerebras
  - On-premises: Nvidia NIM, vLLM, TGI, Dell-TGI
  - On-device: iOS and Android support

- **Built for Production**
  - Pre-packaged distributions for common deployment scenarios
  - Backwards compatibility across model versions
  - Comprehensive evaluation capabilities
  - Full observability and monitoring

- **Multiple developer interfaces**
  - CLI: Command line interface
  - Python SDK (see the sketch after this list)
  - Swift iOS SDK
  - Kotlin Android SDK

- **Sample llama stack applications**
  - Python
  - iOS
  - Android
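As a quick illustration of the unified API layer, here is a minimal sketch using the Python SDK. It assumes a Llama Stack server already running locally on port 8321; the client class and base URL mirror the examples later in this document, and the `identifier` field is an assumption about the returned model objects:

```python
from llama_stack_client import LlamaStackClient

# Connect to a locally running Llama Stack server (assumed to be on port 8321).
client = LlamaStackClient(base_url="http://localhost:8321")

# Every provider is exposed through the same unified API surface,
# so listing models looks identical regardless of the backing provider.
for model in client.models.list():
    print(model.identifier)
```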
---

@@ -407,8 +407,8 @@ Published on: 2025-01-22T22:24:01Z

# v0.0.63
Published on: 2024-12-18T07:17:43Z

A small but important bug-fix release to update the URL datatype for the client SDKs. The issue especially affected multimodal agentic turns.

**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.0.62...v0.0.63

---

@@ -444,39 +444,40 @@ Published on: 2024-11-22T00:36:09Z

# v0.0.53
Published on: 2024-11-20T22:18:00Z

🚀 Initial Release Notes for Llama Stack!

### Added
- Resource-oriented design for models, shields, memory banks, datasets and eval tasks
- Persistence for registered objects with distribution
- Ability to persist memory banks created for FAISS
- PostgreSQL KVStore implementation
- Environment variable placeholder support in run.yaml files (see the sketch after this list)
- Comprehensive Zero-to-Hero notebooks and quickstart guides
- Support for quantized models in Ollama
- Vision model support for Together, Fireworks, Meta-Reference, Ollama, and vLLM
- Bedrock distribution with safety shields support
- Evals API with task registration and scoring functions
- MMLU and SimpleQA benchmark scoring functions
- Huggingface dataset provider integration for benchmarks
- Support for custom dataset registration from local paths
- Benchmark evaluation CLI tools with visualization tables
- RAG evaluation scoring functions and metrics
- Local persistence for datasets and eval tasks
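A minimal sketch of what such a placeholder looks like in a run.yaml. The `${env.VAR}` syntax is the run.yaml convention; the `OLLAMA_URL` and `INFERENCE_MODEL` variable names and the `remote::ollama` provider are illustrative choices drawn from examples elsewhere in this document, not a complete configuration:

```yaml
providers:
  inference:
    - provider_id: ollama
      provider_type: remote::ollama
      config:
        # Resolved from the environment when the server starts.
        url: ${env.OLLAMA_URL}

models:
  - model_id: ${env.INFERENCE_MODEL}
    provider_id: ollama
```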
### Changed
- Split safety into distinct providers (llama-guard, prompt-guard, code-scanner)
- Changed provider naming convention (`impls` → `inline`, `adapters` → `remote`)
- Updated API signatures for dataset and eval task registration
- Restructured folder organization for providers
- Enhanced Docker build configuration
- Added version prefixing for REST API routes
- Enhanced evaluation task registration workflow
- Improved benchmark evaluation output formatting
- Restructured evals folder organization for better modularity

### Removed
- `llama stack configure` command

---
@@ -167,11 +167,14 @@ If you have made changes to a provider's configuration in any form (introducing

If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.

```diff
+cd docs
+uv sync --extra docs

 # This rebuilds the documentation pages.
-uv run --group docs make -C docs/ html
+uv run make html

 # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
-uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
+uv run sphinx-autobuild source build/html --write-all
```

### Update API Documentation

@@ -179,7 +182,7 @@ uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all

If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:

```diff
-uv run ./docs/openapi_generator/run_openapi_generator.sh
+uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
```

The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
@@ -1,4 +1,5 @@
 include pyproject.toml
+include llama_stack/templates/dependencies.json
 include llama_stack/models/llama/llama3/tokenizer.model
 include llama_stack/models/llama/llama4/tokenizer.model
 include llama_stack/distribution/*.sh
43	README.md

@@ -107,29 +107,26 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on

### API Providers
Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.

| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:|
| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | |
| SambaNova | Hosted | | ✅ | | ✅ | | |
| Cerebras | Hosted | | ✅ | | | | |
| Fireworks | Hosted | ✅ | ✅ | ✅ | | | |
| AWS Bedrock | Hosted | | ✅ | | ✅ | | |
| Together | Hosted | ✅ | ✅ | | ✅ | | |
| Groq | Hosted | | ✅ | | | | |
| Ollama | Single Node | | ✅ | | | | |
| TGI | Hosted and Single Node | | ✅ | | | | |
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | |
| Chroma | Single Node | | | ✅ | | | |
| PG Vector | Single Node | | | ✅ | | | |
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | |
| vLLM | Hosted and Single Node | | ✅ | | | | |
| OpenAI | Hosted | | ✅ | | | | |
| Anthropic | Hosted | | ✅ | | | | |
| Gemini | Hosted | | ✅ | | | | |
| watsonx | Hosted | | ✅ | | | | |
| HuggingFace | Single Node | | | | | | ✅ |
| TorchTune | Single Node | | | | | | ✅ |
| NVIDIA NEMO | Hosted | | | | | | ✅ |

In this diff, the updated table drops the **Post Training** column (and with it the HuggingFace, TorchTune, and NVIDIA NEMO rows) and removes the Safety check from the SambaNova row.

### Distributions
1996	docs/_static/llama-stack-spec.html (vendored): file diff suppressed because it is too large
1479	docs/_static/llama-stack-spec.yaml (vendored): file diff suppressed because it is too large
@@ -759,7 +759,7 @@ class Generator:

```diff
         )

         return Operation(
-            tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)],
+            tags=[op.defining_class.__name__],
             summary=None,
             # summary=doc_string.short_description,
             description=description,
```

@@ -805,8 +805,6 @@ class Generator:

```diff
         operation_tags: List[Tag] = []
         for cls in endpoint_classes:
             doc_string = parse_type(cls)
-            if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__:
-                continue
             operation_tags.append(
                 Tag(
                     name=cls.__name__,
```
@@ -3,10 +3,10 @@

Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html).

## Render locally

From the llama-stack root directory, run the following command to render the docs locally:

```diff
-uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
+pip install -r requirements.txt
+cd docs
+python -m sphinx_autobuild source _build
```

You can open up the docs in your browser at http://localhost:8000

16	docs/requirements.txt (new file)

@@ -0,0 +1,16 @@
```
linkify
myst-parser
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx==8.1.3
sphinx-copybutton
sphinx-design
sphinx-pdj-theme
sphinx-rtd-theme>=1.0.0
sphinx-tabs
sphinx_autobuild
sphinx_rtd_dark_mode
sphinxcontrib-mermaid
sphinxcontrib-openapi
sphinxcontrib-redoc
sphinxcontrib-video
tomli
```
@@ -57,31 +57,6 @@ chunks = [

```python
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
```

#### Using Precomputed Embeddings

If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you want to customize the ingestion process.

```python
chunks_with_embeddings = [
    {
        "content": "First chunk of text",
        "mime_type": "text/plain",
        "embedding": [0.1, 0.2, 0.3, ...],  # Your precomputed embedding vector
        "metadata": {"document_id": "doc1", "section": "introduction"},
    },
    {
        "content": "Second chunk of text",
        "mime_type": "text/plain",
        "embedding": [0.2, 0.3, 0.4, ...],  # Your precomputed embedding vector
        "metadata": {"document_id": "doc1", "section": "methodology"},
    },
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
```

When providing precomputed embeddings, ensure the embedding dimension matches the `embedding_dimension` specified when registering the vector database.
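As a reminder of where that dimension comes from, here is a minimal sketch of registering a vector database with an explicit `embedding_dimension`. The model name and dimension below are illustrative assumptions, not required values:

```python
# Register the vector database the chunks will be inserted into.
# The embedding_dimension must equal the length of each precomputed
# "embedding" vector supplied in the chunks above.
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model="all-MiniLM-L6-v2",  # assumed example embedding model
    embedding_dimension=384,             # must match len(chunk["embedding"])
)
```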
### Retrieval

You can query the vector database to retrieve documents based on their embeddings.
@@ -22,11 +22,7 @@ from docutils import nodes

```diff
 # Read version from pyproject.toml
 with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:

 pypi_url = "https://pypi.org/pypi/llama-stack/json"
-headers = {
-    'User-Agent': 'pip/23.0.1 (python 3.11)',  # Mimic pip's user agent
-    'Accept': 'application/json'
-}
-version_tag = json.loads(requests.get(pypi_url, headers=headers).text)["info"]["version"]
+version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
 print(f"{version_tag=}")

 # generate the full link including text and url here
```

@@ -57,6 +53,14 @@ myst_enable_extensions = ["colon_fence"]

```diff
 html_theme = "sphinx_rtd_theme"
 html_use_relative_paths = True

+# html_theme = "sphinx_pdj_theme"
+# html_theme_path = [sphinx_pdj_theme.get_html_theme_path()]

+# html_theme = "pytorch_sphinx_theme"
+# html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]

 templates_path = ["_templates"]
 exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
```
@@ -338,48 +338,6 @@ INFO: Application startup complete.

```
INFO:     Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
INFO:     2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK
```

### Listing Distributions

Using the list command, you can view all existing Llama Stack distributions, including stacks built from templates, from scratch, or using custom configuration files.

```
llama stack list -h
usage: llama stack list [-h]

list the build stacks

options:
  -h, --help  show this help message and exit
```

Example Usage

```
llama stack list
```

### Removing a Distribution

Use the remove command to delete a distribution you've previously built.

```
llama stack rm -h
usage: llama stack rm [-h] [--all] [name]

Remove the build stack

positional arguments:
  name        Name of the stack to delete (default: None)

options:
  -h, --help  show this help message and exit
  --all, -a   Delete all stacks (use with caution) (default: False)
```

Example

```
llama stack rm llamastack-test
```

To keep your environment organized and avoid clutter, consider using `llama stack list` to review old or unused distributions and `llama stack rm <name>` to delete them when they’re no longer needed.

### Troubleshooting
@@ -118,6 +118,11 @@ server:

```yaml
  port: 8321 # Port to listen on (default: 8321)
  tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
  tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
  auth: # Optional: Authentication configuration
    provider_type: "kubernetes" # Type of auth provider
    config: # Provider-specific configuration
      api_server_url: "https://kubernetes.default.svc"
      ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate
```

### Authentication Configuration

@@ -130,7 +135,7 @@ Authorization: Bearer <token>

The server supports multiple authentication providers:

#### Kubernetes Provider

The Kubernetes cluster must be configured to use a service account for authentication.

@@ -141,67 +146,14 @@ kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --se

```bash
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
```
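Once the token file exists, clients pass it in the `Authorization: Bearer <token>` header shown above. A minimal sketch with `curl`, assuming the server is listening locally on the default port 8321; the `/models/list` path simply mirrors the request visible in the server log earlier in this section:

```bash
# Read the service-account token created above.
TOKEN=$(cat llama-stack-auth-token)

# Send it as a Bearer token on an ordinary request to the server.
curl -H "Authorization: Bearer $TOKEN" http://localhost:8321/models/list
```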
Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests and that the correct RoleBinding is created to allow the service account to access the necessary resources. If that is not the case, you can create a RoleBinding for the service account to access the necessary resources:

```yaml
# allow-anonymous-openid.yaml
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
  name: allow-anonymous-openid
rules:
- nonResourceURLs: ["/openid/v1/jwks"]
  verbs: ["get"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
  name: allow-anonymous-openid
roleRef:
  apiGroup: rbac.authorization.k8s.io
  kind: ClusterRole
  name: allow-anonymous-openid
subjects:
- kind: User
  name: system:anonymous
  apiGroup: rbac.authorization.k8s.io
```

And then apply the configuration:
```bash
kubectl apply -f allow-anonymous-openid.yaml
```

Validates tokens against the Kubernetes API server through the OIDC provider:

```diff
 server:
   auth:
-    provider_type: "oauth2_token"
+    provider_type: "kubernetes"
     config:
-      jwks:
-        uri: "https://kubernetes.default.svc"
-        key_recheck_period: 3600
-      tls_cafile: "/path/to/ca.crt"
-      issuer: "https://kubernetes.default.svc"
-      audience: "https://kubernetes.default.svc"
+      api_server_url: "https://kubernetes.default.svc" # URL of the Kubernetes API server
+      ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate
```

To find your cluster's audience, run:
```bash
kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
```

For the issuer, you can use the OIDC provider's URL:
```bash
kubectl get --raw /.well-known/openid-configuration | jq .issuer
```

For the `tls_cafile`, you can use the CA certificate of the OIDC provider:
```bash
kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
```

The provider extracts user information from the JWT token:

@@ -256,80 +208,6 @@ And must respond with:

If no access attributes are returned, the token is used as a namespace.

### Quota Configuration

The `quota` section allows you to enable server-side request throttling for both authenticated and anonymous clients. This is useful for preventing abuse, enforcing fairness across tenants, and controlling infrastructure costs without requiring client-side rate limiting or external proxies.

Quotas are disabled by default. When enabled, each client is tracked using either:

* Their authenticated `client_id` (derived from the Bearer token), or
* Their IP address (fallback for anonymous requests)

Quota state is stored in a SQLite-backed key-value store, and rate limits are applied within a configurable time window (currently only `day` is supported).

#### Example

```yaml
server:
  quota:
    kvstore:
      type: sqlite
      db_path: ./quotas.db
    anonymous_max_requests: 100
    authenticated_max_requests: 1000
    period: day
```

#### Configuration Options

| Field | Description |
| ---------------------------- | -------------------------------------------------------------------------- |
| `kvstore` | Required. Backend storage config for tracking request counts. |
| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. |
| `kvstore.db_path` | File path to the SQLite database. |
| `anonymous_max_requests` | Max requests per period for unauthenticated clients. |
| `authenticated_max_requests` | Max requests per period for authenticated clients. |
| `period` | Time window for quota enforcement. Only `"day"` is supported. |

> Note: if `authenticated_max_requests` is set but no authentication provider is configured, the server will fall back to applying `anonymous_max_requests` to all clients.

#### Example with Authentication Enabled

```yaml
server:
  port: 8321
  auth:
    provider_type: custom
    config:
      endpoint: https://auth.example.com/validate
  quota:
    kvstore:
      type: sqlite
      db_path: ./quotas.db
    anonymous_max_requests: 100
    authenticated_max_requests: 1000
    period: day
```

If a client exceeds their limit, the server responds with:

```http
HTTP/1.1 429 Too Many Requests
Content-Type: application/json

{
  "error": {
    "message": "Quota exceeded"
  }
}
```
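Clients should treat this response as a signal to back off and retry later. A minimal client-side sketch, assuming the plain `requests` library and the local server URL used elsewhere in this guide; the retry delay is an arbitrary illustrative choice:

```python
import time

import requests


def get_with_quota_retry(url: str, token: str | None = None, max_attempts: int = 3):
    """Issue a GET request, backing off when the server answers 429 (quota exceeded)."""
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    response = None
    for _ in range(max_attempts):
        response = requests.get(url, headers=headers)
        if response.status_code != 429:
            return response
        # Quota exceeded: wait before retrying (60 s is an arbitrary example).
        time.sleep(60)
    return response


# Example: resp = get_with_quota_retry("http://localhost:8321/models/list", token="<token>")
```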
## Extending to handle Safety

Configuring Safety can be a little involved so it is instructive to go through an example.
@@ -172,7 +172,7 @@ spec:

```diff
       - name: llama-stack
         image: localhost/llama-stack-run-k8s:latest
         imagePullPolicy: IfNotPresent
-        command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/app/config.yaml"]
+        command: ["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]
         ports:
           - containerPort: 5000
         volumeMounts:
```
@@ -70,7 +70,7 @@ docker run \

```diff
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-watsonx \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env WATSONX_API_KEY=$WATSONX_API_KEY \
   --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
```

@@ -52,7 +52,7 @@ docker run \

```diff
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-cerebras \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```

@@ -155,7 +155,7 @@ docker run \

```diff
   -v $HOME/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-dell \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env DEH_URL=$DEH_URL \
```

@@ -143,7 +143,7 @@ docker run \

```diff
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-nvidia \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env NVIDIA_API_KEY=$NVIDIA_API_KEY
```
@@ -19,7 +19,6 @@ The `llamastack/distribution-ollama` distribution consists of the following prov

```diff
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::ollama` |
-| post_training | `inline::huggingface` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
```

@@ -98,7 +97,7 @@ docker run \

```diff
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-ollama \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env SAFETY_MODEL=$SAFETY_MODEL \
```
@@ -233,7 +233,7 @@ docker run \

```diff
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \
   llamastack/distribution-remote-vllm \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1
```

@@ -255,7 +255,7 @@ docker run \

```diff
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-remote-vllm \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \
```
@@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p

```diff
 |-----|-------------|
 | agents | `inline::meta-reference` |
 | inference | `remote::sambanova`, `inline::sentence-transformers` |
-| safety | `remote::sambanova` |
+| safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
 | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
```

@@ -48,44 +48,33 @@ The following models are available by default:

### Prerequisite: API Keys

Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup).

## Running Llama Stack with SambaNova

You can do this via Conda (build code) or Docker which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```diff
 LLAMA_STACK_PORT=8321
-llama stack build --template sambanova --image-type container
 docker run \
   -it \
+  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-  -v ~/.llama:/root/.llama \
-  distribution-sambanova \
+  llamastack/distribution-sambanova \
   --port $LLAMA_STACK_PORT \
   --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```

### Via Venv

```bash
llama stack build --template sambanova --image-type venv
llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \
  --port $LLAMA_STACK_PORT \
  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```

### Via Conda

```diff
 llama stack build --template sambanova --image-type conda
-llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \
+llama stack run ./run.yaml \
   --port $LLAMA_STACK_PORT \
   --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
@@ -117,7 +117,7 @@ docker run \

```diff
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-tgi \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \
```
@@ -30,18 +30,6 @@ Runs inference with an LLM.

## Post Training
Fine-tunes a model.

#### Post Training Providers
The following providers are available for Post Training (a minimal configuration sketch follows this list):

```{toctree}
:maxdepth: 1

external
post_training/huggingface
post_training/torchtune
post_training/nvidia_nemo
```
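As a rough illustration of how one of these providers is wired into a stack, here is a minimal sketch of a `post_training` entry in a run.yaml. It mirrors the inline TorchTune configuration shown later in this diff and is not a complete file:

```yaml
post_training:
  - provider_id: torchtune
    provider_type: inline::torchtune
    config: {}
```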
## Safety
Applies safety policies to the output at a Systems (not only model) level.
@@ -1,122 +0,0 @@

---
orphan: true
---
# HuggingFace SFTTrainer

[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine tuning on a variety of models using many datasets.

## Features

- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support, CPU support, and MPS support (MacOS Metal Performance Shaders)

## Usage

To use the HF SFTTrainer in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Kick off an SFT job using the Llama Stack post_training API.

## Setup

You can access the HuggingFace trainer via the `ollama` distribution:

```bash
llama stack build --template ollama --image-type venv
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
```

## Run Training

You can access the provider and the `supervised_fine_tune` method via the post_training API:

```python
import time
import uuid


from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(  # this config is also currently mandatory but should not be
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)


# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
@@ -1,163 +0,0 @@

---
orphan: true
---
# NVIDIA NEMO

[NVIDIA NEMO](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service.

## Features

- Enterprise-grade fine-tuning capabilities
- Support for LoRA and SFT fine-tuning
- Integration with NVIDIA's NeMo Customizer service
- Support for various NVIDIA-optimized models
- Efficient training with NVIDIA hardware acceleration

## Usage

To use NVIDIA NEMO in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Set up your NVIDIA API credentials.
3. Kick off a fine-tuning job using the Llama Stack post_training API.

## Setup

You'll need to set the following environment variables:

```bash
export NVIDIA_API_KEY="your-api-key"
export NVIDIA_DATASET_NAMESPACE="default"
export NVIDIA_CUSTOMIZER_URL="your-customizer-url"
export NVIDIA_PROJECT_ID="your-project-id"
export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir"
```

## Run Training

You can access the provider and the `supervised_fine_tune` method via the post_training API:

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=8,  # Default batch size for NEMO
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    n_epochs=50,  # Default epochs for NEMO
    optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
        lr=0.0001,  # Default learning rate
        weight_decay=0.01,  # NEMO-specific parameter
    ),
    # NEMO-specific parameters
    log_every_n_steps=None,
    val_check_interval=0.25,
    sequence_packing_enabled=False,
    hidden_dropout=None,
    attention_dropout=None,
    ffn_dropout=None,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=16,  # Default alpha for NEMO
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model - must be a supported NEMO model
training_model = "meta/llama-3.1-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```

## Supported Models

Currently supports the following models:
- meta/llama-3.1-8b-instruct
- meta/llama-3.2-1b-instruct

## Supported Parameters

### TrainingConfig
- n_epochs (default: 50)
- data_config
- optimizer_config
- log_every_n_steps
- val_check_interval (default: 0.25)
- sequence_packing_enabled (default: False)
- hidden_dropout (0.0-1.0)
- attention_dropout (0.0-1.0)
- ffn_dropout (0.0-1.0)

### DataConfig
- dataset_id
- batch_size (default: 8)

### OptimizerConfig
- lr (default: 0.0001)
- weight_decay (default: 0.01)

### LoRA Config
- alpha (default: 16)
- type (must be "LoRA")

Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning.
@@ -1,125 +0,0 @@

---
orphan: true
---
# TorchTune

[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.

## Features

- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support and single device capabilities
- Support for LoRA

## Usage

To use TorchTune in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Kick off a fine-tuning job using the Llama Stack post_training API.

## Setup

You can access the TorchTune trainer by writing your own yaml pointing to the provider:

```yaml
post_training:
  - provider_id: torchtune
    provider_type: inline::torchtune
    config: {}
```

You can then build and run your own stack with this provider.

## Run Training

You can access the provider and the `supervised_fine_tune` method via the post_training API:

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "meta-llama/Llama-2-7b-hf"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
|
|
||||||
|
|
||||||
print("Artifacts:")
|
|
||||||
print(client.post_training.job.artifacts(job_uuid=job_uuid))
|
|
||||||
```
|
|
|
@ -66,25 +66,6 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
|
||||||
2. Configure your Llama Stack project to use SQLite-Vec.
|
2. Configure your Llama Stack project to use SQLite-Vec.
|
||||||
3. Start storing and querying vectors.
|
3. Start storing and querying vectors.
|
||||||
|
|
||||||
## Supported Search Modes
|
|
||||||
|
|
||||||
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
|
|
||||||
|
|
||||||
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
|
|
||||||
`RAGQueryConfig`. For example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
|
|
||||||
|
|
||||||
query_config = RAGQueryConfig(max_chunks=6, mode="vector")
|
|
||||||
|
|
||||||
results = client.tool_runtime.rag_tool.query(
|
|
||||||
vector_db_ids=[vector_db_id],
|
|
||||||
content="what is torchtune",
|
|
||||||
query_config=query_config,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
You can install SQLite-Vec using pip:
|
You can install SQLite-Vec using pip:
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
export USE_COPY_NOT_MOUNT=true
|
|
||||||
export LLAMA_STACK_DIR=.
|
|
||||||
|
|
||||||
uvx --from . llama stack build --template kvant --image-type container --image-name kvant
|
|
|
@ -1,17 +0,0 @@
|
||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
export LLAMA_STACK_PORT=8321
|
|
||||||
# VLLM_API_TOKEN= env file
|
|
||||||
# KEYCLOAK_CLIENT_SECRET= env file
|
|
||||||
|
|
||||||
|
|
||||||
docker run -it \
|
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
|
||||||
-v $(pwd)/data:/root/.llama \
|
|
||||||
--mount type=bind,source="$(pwd)"/llama_stack/templates/kvant/run.yaml,target=/root/.llama/config.yaml,readonly \
|
|
||||||
--entrypoint python \
|
|
||||||
--env-file ./.env \
|
|
||||||
distribution-kvant:dev \
|
|
||||||
-m llama_stack.distribution.server.server --config /root/.llama/config.yaml \
|
|
||||||
--port $LLAMA_STACK_PORT \
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||||
from llama_stack.apis.common.responses import Order, PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
|
@ -31,8 +31,6 @@ from llama_stack.apis.tools import ToolDef
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
from .openai_responses import (
|
from .openai_responses import (
|
||||||
ListOpenAIResponseInputItem,
|
|
||||||
ListOpenAIResponseObject,
|
|
||||||
OpenAIResponseInput,
|
OpenAIResponseInput,
|
||||||
OpenAIResponseInputTool,
|
OpenAIResponseInputTool,
|
||||||
OpenAIResponseObject,
|
OpenAIResponseObject,
|
||||||
|
@ -581,14 +579,14 @@ class Agents(Protocol):
|
||||||
#
|
#
|
||||||
# Both of these APIs are inherently stateful.
|
# Both of these APIs are inherently stateful.
|
||||||
|
|
||||||
@webmethod(route="/openai/v1/responses/{response_id}", method="GET")
|
@webmethod(route="/openai/v1/responses/{id}", method="GET")
|
||||||
async def get_openai_response(
|
async def get_openai_response(
|
||||||
self,
|
self,
|
||||||
response_id: str,
|
id: str,
|
||||||
) -> OpenAIResponseObject:
|
) -> OpenAIResponseObject:
|
||||||
"""Retrieve an OpenAI response by its ID.
|
"""Retrieve an OpenAI response by its ID.
|
||||||
|
|
||||||
:param response_id: The ID of the OpenAI response to retrieve.
|
:param id: The ID of the OpenAI response to retrieve.
|
||||||
:returns: An OpenAIResponseObject.
|
:returns: An OpenAIResponseObject.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
@ -598,7 +596,6 @@ class Agents(Protocol):
|
||||||
self,
|
self,
|
||||||
input: str | list[OpenAIResponseInput],
|
input: str | list[OpenAIResponseInput],
|
||||||
model: str,
|
model: str,
|
||||||
instructions: str | None = None,
|
|
||||||
previous_response_id: str | None = None,
|
previous_response_id: str | None = None,
|
||||||
store: bool | None = True,
|
store: bool | None = True,
|
||||||
stream: bool | None = False,
|
stream: bool | None = False,
|
||||||
|
@ -613,43 +610,3 @@ class Agents(Protocol):
|
||||||
:returns: An OpenAIResponseObject.
|
:returns: An OpenAIResponseObject.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/openai/v1/responses", method="GET")
|
|
||||||
async def list_openai_responses(
|
|
||||||
self,
|
|
||||||
after: str | None = None,
|
|
||||||
limit: int | None = 50,
|
|
||||||
model: str | None = None,
|
|
||||||
order: Order | None = Order.desc,
|
|
||||||
) -> ListOpenAIResponseObject:
|
|
||||||
"""List all OpenAI responses.
|
|
||||||
|
|
||||||
:param after: The ID of the last response to return.
|
|
||||||
:param limit: The number of responses to return.
|
|
||||||
:param model: The model to filter responses by.
|
|
||||||
:param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
|
|
||||||
:returns: A ListOpenAIResponseObject.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET")
|
|
||||||
async def list_openai_response_input_items(
|
|
||||||
self,
|
|
||||||
response_id: str,
|
|
||||||
after: str | None = None,
|
|
||||||
before: str | None = None,
|
|
||||||
include: list[str] | None = None,
|
|
||||||
limit: int | None = 20,
|
|
||||||
order: Order | None = Order.desc,
|
|
||||||
) -> ListOpenAIResponseInputItem:
|
|
||||||
"""List input items for a given OpenAI response.
|
|
||||||
|
|
||||||
:param response_id: The ID of the response to retrieve input items for.
|
|
||||||
:param after: An item ID to list items after, used for pagination.
|
|
||||||
:param before: An item ID to list items before, used for pagination.
|
|
||||||
:param include: Additional fields to include in the response.
|
|
||||||
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
|
|
||||||
:param order: The order to return the input items in. Default is desc.
|
|
||||||
:returns: An ListOpenAIResponseInputItem.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
|
@ -10,9 +10,6 @@ from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||||
|
|
||||||
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
|
|
||||||
# take their YAML and generate this file automatically. Their YAML is available.
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseError(BaseModel):
|
class OpenAIResponseError(BaseModel):
|
||||||
|
@ -82,45 +79,16 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
||||||
|
arguments: str
|
||||||
call_id: str
|
call_id: str
|
||||||
name: str
|
name: str
|
||||||
arguments: str
|
|
||||||
type: Literal["function_call"] = "function_call"
|
type: Literal["function_call"] = "function_call"
|
||||||
id: str | None = None
|
|
||||||
status: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class OpenAIResponseOutputMessageMCPCall(BaseModel):
|
|
||||||
id: str
|
id: str
|
||||||
type: Literal["mcp_call"] = "mcp_call"
|
status: str
|
||||||
arguments: str
|
|
||||||
name: str
|
|
||||||
server_label: str
|
|
||||||
error: str | None = None
|
|
||||||
output: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class MCPListToolsTool(BaseModel):
|
|
||||||
input_schema: dict[str, Any]
|
|
||||||
name: str
|
|
||||||
description: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
|
|
||||||
id: str
|
|
||||||
type: Literal["mcp_list_tools"] = "mcp_list_tools"
|
|
||||||
server_label: str
|
|
||||||
tools: list[MCPListToolsTool]
|
|
||||||
|
|
||||||
|
|
||||||
OpenAIResponseOutput = Annotated[
|
OpenAIResponseOutput = Annotated[
|
||||||
OpenAIResponseMessage
|
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
|
||||||
| OpenAIResponseOutputMessageWebSearchToolCall
|
|
||||||
| OpenAIResponseOutputMessageFunctionToolCall
|
|
||||||
| OpenAIResponseOutputMessageMCPCall
|
|
||||||
| OpenAIResponseOutputMessageMCPListTools,
|
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||||
|
@ -149,16 +117,6 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
||||||
type: Literal["response.created"] = "response.created"
|
type: Literal["response.created"] = "response.created"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
|
||||||
content_index: int
|
|
||||||
delta: str
|
|
||||||
item_id: str
|
|
||||||
output_index: int
|
|
||||||
sequence_number: int
|
|
||||||
type: Literal["response.output_text.delta"] = "response.output_text.delta"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||||
response: OpenAIResponseObject
|
response: OpenAIResponseObject
|
||||||
|
@ -166,9 +124,7 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
OpenAIResponseObjectStream = Annotated[
|
OpenAIResponseObjectStream = Annotated[
|
||||||
OpenAIResponseObjectStreamResponseCreated
|
OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
|
||||||
| OpenAIResponseObjectStreamResponseOutputTextDelta
|
|
||||||
| OpenAIResponseObjectStreamResponseCompleted,
|
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
||||||
|
@ -230,50 +186,13 @@ class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||||
# TODO: add filters
|
# TODO: add filters
|
||||||
|
|
||||||
|
|
||||||
class ApprovalFilter(BaseModel):
|
|
||||||
always: list[str] | None = None
|
|
||||||
never: list[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class AllowedToolsFilter(BaseModel):
|
|
||||||
tool_names: list[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class OpenAIResponseInputToolMCP(BaseModel):
|
|
||||||
type: Literal["mcp"] = "mcp"
|
|
||||||
server_label: str
|
|
||||||
server_url: str
|
|
||||||
headers: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
|
|
||||||
allowed_tools: list[str] | AllowedToolsFilter | None = None
|
|
||||||
|
|
||||||
|
|
||||||
OpenAIResponseInputTool = Annotated[
|
OpenAIResponseInputTool = Annotated[
|
||||||
OpenAIResponseInputToolWebSearch
|
OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
|
||||||
| OpenAIResponseInputToolFileSearch
|
|
||||||
| OpenAIResponseInputToolFunction
|
|
||||||
| OpenAIResponseInputToolMCP,
|
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||||
|
|
||||||
|
|
||||||
class ListOpenAIResponseInputItem(BaseModel):
|
class OpenAIResponseInputItemList(BaseModel):
|
||||||
data: list[OpenAIResponseInput]
|
data: list[OpenAIResponseInput]
|
||||||
object: Literal["list"] = "list"
|
object: Literal["list"] = "list"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
|
||||||
input: list[OpenAIResponseInput]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ListOpenAIResponseObject(BaseModel):
|
|
||||||
data: list[OpenAIResponseObjectWithInput]
|
|
||||||
has_more: bool
|
|
||||||
first_id: str
|
|
||||||
last_id: str
|
|
||||||
object: Literal["list"] = "list"
|
|
||||||
|
|
30
llama_stack/apis/common/deployment_types.py
Normal file
30
llama_stack/apis/common/deployment_types.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RestAPIMethod(Enum):
|
||||||
|
GET = "GET"
|
||||||
|
POST = "POST"
|
||||||
|
PUT = "PUT"
|
||||||
|
DELETE = "DELETE"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RestAPIExecutionConfig(BaseModel):
|
||||||
|
url: URL
|
||||||
|
method: RestAPIMethod
|
||||||
|
params: dict[str, Any] | None = None
|
||||||
|
headers: dict[str, Any] | None = None
|
||||||
|
body: dict[str, Any] | None = None
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -12,11 +11,6 @@ from pydantic import BaseModel
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
class Order(Enum):
|
|
||||||
asc = "asc"
|
|
||||||
desc = "desc"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PaginatedResponse(BaseModel):
|
class PaginatedResponse(BaseModel):
|
||||||
"""A generic paginated response that follows a simple format.
|
"""A generic paginated response that follows a simple format.
|
||||||
|
|
|
@ -19,7 +19,6 @@ from pydantic import BaseModel, Field, field_validator
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||||
from llama_stack.apis.common.responses import Order
|
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
@ -783,48 +782,6 @@ class OpenAICompletion(BaseModel):
|
||||||
object: Literal["text_completion"] = "text_completion"
|
object: Literal["text_completion"] = "text_completion"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class OpenAIEmbeddingData(BaseModel):
|
|
||||||
"""A single embedding data object from an OpenAI-compatible embeddings response.
|
|
||||||
|
|
||||||
:param object: The object type, which will be "embedding"
|
|
||||||
:param embedding: The embedding vector as a list of floats (when encoding_format="float") or as a base64-encoded string (when encoding_format="base64")
|
|
||||||
:param index: The index of the embedding in the input list
|
|
||||||
"""
|
|
||||||
|
|
||||||
object: Literal["embedding"] = "embedding"
|
|
||||||
embedding: list[float] | str
|
|
||||||
index: int
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class OpenAIEmbeddingUsage(BaseModel):
|
|
||||||
"""Usage information for an OpenAI-compatible embeddings response.
|
|
||||||
|
|
||||||
:param prompt_tokens: The number of tokens in the input
|
|
||||||
:param total_tokens: The total number of tokens used
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompt_tokens: int
|
|
||||||
total_tokens: int
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class OpenAIEmbeddingsResponse(BaseModel):
|
|
||||||
"""Response from an OpenAI-compatible embeddings request.
|
|
||||||
|
|
||||||
:param object: The object type, which will be "list"
|
|
||||||
:param data: List of embedding data objects
|
|
||||||
:param model: The model that was used to generate the embeddings
|
|
||||||
:param usage: Usage information
|
|
||||||
"""
|
|
||||||
|
|
||||||
object: Literal["list"] = "list"
|
|
||||||
data: list[OpenAIEmbeddingData]
|
|
||||||
model: str
|
|
||||||
usage: OpenAIEmbeddingUsage
|
|
||||||
|
|
||||||
|
|
||||||
class ModelStore(Protocol):
|
class ModelStore(Protocol):
|
||||||
async def get_model(self, identifier: str) -> Model: ...
|
async def get_model(self, identifier: str) -> Model: ...
|
||||||
|
|
||||||
|
@ -863,27 +820,15 @@ class BatchChatCompletionResponse(BaseModel):
|
||||||
batch: list[ChatCompletionResponse]
|
batch: list[ChatCompletionResponse]
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
|
|
||||||
input_messages: list[OpenAIMessageParam]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ListOpenAIChatCompletionResponse(BaseModel):
|
|
||||||
data: list[OpenAICompletionWithInputMessages]
|
|
||||||
has_more: bool
|
|
||||||
first_id: str
|
|
||||||
last_id: str
|
|
||||||
object: Literal["list"] = "list"
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
@trace_protocol
|
@trace_protocol
|
||||||
class InferenceProvider(Protocol):
|
class Inference(Protocol):
|
||||||
"""
|
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||||
This protocol defines the interface that should be implemented by all inference providers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
API_NAMESPACE: str = "Inference"
|
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||||
|
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||||
|
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||||
|
"""
|
||||||
|
|
||||||
model_store: ModelStore | None = None
|
model_store: ModelStore | None = None
|
||||||
|
|
||||||
|
@ -1117,59 +1062,3 @@ class InferenceProvider(Protocol):
|
||||||
:returns: An OpenAIChatCompletion.
|
:returns: An OpenAIChatCompletion.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/openai/v1/embeddings", method="POST")
|
|
||||||
async def openai_embeddings(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
input: str | list[str],
|
|
||||||
encoding_format: str | None = "float",
|
|
||||||
dimensions: int | None = None,
|
|
||||||
user: str | None = None,
|
|
||||||
) -> OpenAIEmbeddingsResponse:
|
|
||||||
"""Generate OpenAI-compatible embeddings for the given input using the specified model.
|
|
||||||
|
|
||||||
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
|
||||||
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
|
|
||||||
:param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
|
|
||||||
:param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
|
|
||||||
:param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
|
|
||||||
:returns: An OpenAIEmbeddingsResponse containing the embeddings.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class Inference(InferenceProvider):
|
|
||||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
|
||||||
|
|
||||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
|
||||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
|
||||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@webmethod(route="/openai/v1/chat/completions", method="GET")
|
|
||||||
async def list_chat_completions(
|
|
||||||
self,
|
|
||||||
after: str | None = None,
|
|
||||||
limit: int | None = 20,
|
|
||||||
model: str | None = None,
|
|
||||||
order: Order | None = Order.desc,
|
|
||||||
) -> ListOpenAIChatCompletionResponse:
|
|
||||||
"""List all chat completions.
|
|
||||||
|
|
||||||
:param after: The ID of the last chat completion to return.
|
|
||||||
:param limit: The maximum number of chat completions to return.
|
|
||||||
:param model: The model to filter by.
|
|
||||||
:param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
|
|
||||||
:returns: A ListOpenAIChatCompletionResponse.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("List chat completions is not implemented")
|
|
||||||
|
|
||||||
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
|
|
||||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
|
||||||
"""Describe a chat completion by its ID.
|
|
||||||
|
|
||||||
:param completion_id: ID of the chat completion.
|
|
||||||
:returns: A OpenAICompletionWithInputMessages.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Get chat completion is not implemented")
|
|
||||||
|
|
|
@ -76,7 +76,6 @@ class RAGQueryConfig(BaseModel):
|
||||||
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
||||||
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
||||||
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
||||||
:param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector".
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This config defines how a query is generated using the messages
|
# This config defines how a query is generated using the messages
|
||||||
|
@ -85,7 +84,6 @@ class RAGQueryConfig(BaseModel):
|
||||||
max_tokens_in_context: int = 4096
|
max_tokens_in_context: int = 4096
|
||||||
max_chunks: int = 5
|
max_chunks: int = 5
|
||||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||||
mode: str | None = None
|
|
||||||
|
|
||||||
@field_validator("chunk_template")
|
@field_validator("chunk_template")
|
||||||
def validate_chunk_template(cls, v: str) -> str:
|
def validate_chunk_template(cls, v: str) -> str:
|
||||||
|
|
|
@ -27,10 +27,18 @@ class ToolParameter(BaseModel):
|
||||||
default: Any | None = None
|
default: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ToolHost(Enum):
|
||||||
|
distribution = "distribution"
|
||||||
|
client = "client"
|
||||||
|
model_context_protocol = "model_context_protocol"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Tool(Resource):
|
class Tool(Resource):
|
||||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||||
toolgroup_id: str
|
toolgroup_id: str
|
||||||
|
tool_host: ToolHost
|
||||||
description: str
|
description: str
|
||||||
parameters: list[ToolParameter]
|
parameters: list[ToolParameter]
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
|
@ -68,8 +76,8 @@ class ToolInvocationResult(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ToolStore(Protocol):
|
class ToolStore(Protocol):
|
||||||
async def get_tool(self, tool_name: str) -> Tool: ...
|
def get_tool(self, tool_name: str) -> Tool: ...
|
||||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
||||||
|
|
||||||
|
|
||||||
class ListToolGroupsResponse(BaseModel):
|
class ListToolGroupsResponse(BaseModel):
|
||||||
|
|
|
@ -19,16 +19,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class Chunk(BaseModel):
|
class Chunk(BaseModel):
|
||||||
"""
|
|
||||||
A chunk of content that can be inserted into a vector database.
|
|
||||||
:param content: The content of the chunk, which can be interleaved text, images, or other types.
|
|
||||||
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
|
|
||||||
:param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
|
|
||||||
"""
|
|
||||||
|
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
embedding: list[float] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -58,10 +50,7 @@ class VectorIO(Protocol):
|
||||||
"""Insert chunks into a vector database.
|
"""Insert chunks into a vector database.
|
||||||
|
|
||||||
:param vector_db_id: The identifier of the vector database to insert the chunks into.
|
:param vector_db_id: The identifier of the vector database to insert the chunks into.
|
||||||
:param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
|
:param chunks: The chunks to insert.
|
||||||
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
|
|
||||||
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
|
|
||||||
If `embedding` is not provided, it will be computed later.
|
|
||||||
:param ttl_seconds: The time to live of the chunks.
|
:param ttl_seconds: The time to live of the chunks.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
|
@ -9,7 +9,6 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -378,15 +377,14 @@ def _meta_download(
|
||||||
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
||||||
asyncio.run(downloader.download_all(tasks))
|
asyncio.run(downloader.download_all(tasks))
|
||||||
|
|
||||||
cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
|
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
|
||||||
cprint(
|
cprint(
|
||||||
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
|
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
|
||||||
file=sys.stderr,
|
"white",
|
||||||
)
|
)
|
||||||
cprint(
|
cprint(
|
||||||
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
|
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
|
||||||
color="yellow",
|
"yellow",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,6 @@ import shutil
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from importlib.abc import Traversable
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -79,7 +78,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
|
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
|
||||||
color="red",
|
color="red",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
build_config = available_templates[args.template]
|
build_config = available_templates[args.template]
|
||||||
|
@ -89,7 +87,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
|
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
|
||||||
color="red",
|
color="red",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
elif args.providers:
|
elif args.providers:
|
||||||
|
@ -99,7 +96,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||||
color="red",
|
color="red",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
api, provider = api_provider.split("=")
|
api, provider = api_provider.split("=")
|
||||||
|
@ -108,7 +104,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"{api} is not a valid API.",
|
f"{api} is not a valid API.",
|
||||||
color="red",
|
color="red",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
if provider in providers_for_api:
|
if provider in providers_for_api:
|
||||||
|
@ -117,7 +112,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"{provider} is not a valid provider for the {api} API.",
|
f"{provider} is not a valid provider for the {api} API.",
|
||||||
color="red",
|
color="red",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
distribution_spec = DistributionSpec(
|
distribution_spec = DistributionSpec(
|
||||||
|
@ -128,7 +122,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Please specify a image-type (container | conda | venv) for {args.template}",
|
f"Please specify a image-type (container | conda | venv) for {args.template}",
|
||||||
color="red",
|
color="red",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
@ -157,14 +150,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
|
f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
|
||||||
color="yellow",
|
color="yellow",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
image_name = f"llamastack-{name}"
|
image_name = f"llamastack-{name}"
|
||||||
else:
|
else:
|
||||||
cprint(
|
cprint(
|
||||||
f"Using conda environment {image_name}",
|
f"Using conda environment {image_name}",
|
||||||
color="green",
|
color="green",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image_name = f"llamastack-{name}"
|
image_name = f"llamastack-{name}"
|
||||||
|
@ -177,10 +168,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
color="green",
|
color="green",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
|
print("Tip: use <TAB> to see options for the providers.\n")
|
||||||
|
|
||||||
providers = dict()
|
providers = dict()
|
||||||
for api, providers_for_api in get_provider_registry().items():
|
for api, providers_for_api in get_provider_registry().items():
|
||||||
|
@ -216,13 +206,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
contents = yaml.safe_load(f)
|
contents = yaml.safe_load(f)
|
||||||
contents = replace_env_vars(contents)
|
contents = replace_env_vars(contents)
|
||||||
build_config = BuildConfig(**contents)
|
build_config = BuildConfig(**contents)
|
||||||
if args.image_type:
|
|
||||||
build_config.image_type = args.image_type
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cprint(
|
cprint(
|
||||||
f"Could not parse config file {args.config}: {e}",
|
f"Could not parse config file {args.config}: {e}",
|
||||||
color="red",
|
color="red",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
@ -249,27 +236,25 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Error building stack: {exc}",
|
f"Error building stack: {exc}",
|
||||||
color="red",
|
color="red",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
cprint("Stack trace:", color="red", file=sys.stderr)
|
cprint("Stack trace:", color="red")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if run_config is None:
|
if run_config is None:
|
||||||
cprint(
|
cprint(
|
||||||
"Run config path is empty",
|
"Run config path is empty",
|
||||||
color="red",
|
color="red",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if args.run:
|
if args.run:
|
||||||
|
run_config = Path(run_config)
|
||||||
config_dict = yaml.safe_load(run_config.read_text())
|
config_dict = yaml.safe_load(run_config.read_text())
|
||||||
config = parse_and_maybe_upgrade_config(config_dict)
|
config = parse_and_maybe_upgrade_config(config_dict)
|
||||||
if config.external_providers_dir and not config.external_providers_dir.exists():
|
if not os.path.exists(str(config.external_providers_dir)):
|
||||||
config.external_providers_dir.mkdir(exist_ok=True)
|
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
|
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
|
||||||
run_command(run_args)
|
run_command(run_args)
|
||||||
|
|
||||||
|
|
||||||
|
@ -277,7 +262,7 @@ def _generate_run_config(
|
||||||
build_config: BuildConfig,
|
build_config: BuildConfig,
|
||||||
build_dir: Path,
|
build_dir: Path,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> Path:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a run.yaml template file for user to edit from a build.yaml file
|
Generate a run.yaml template file for user to edit from a build.yaml file
|
||||||
"""
|
"""
|
||||||
|
@ -317,7 +302,6 @@ def _generate_run_config(
|
||||||
cprint(
|
cprint(
|
||||||
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
|
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
|
||||||
color="yellow",
|
color="yellow",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
# Set config_type to None to avoid UnboundLocalError
|
# Set config_type to None to avoid UnboundLocalError
|
||||||
config_type = None
|
config_type = None
|
||||||
|
@ -345,7 +329,10 @@ def _generate_run_config(
|
||||||
# For non-container builds, the run.yaml is generated at the very end of the build process so it
|
# For non-container builds, the run.yaml is generated at the very end of the build process so it
|
||||||
# makes sense to display this message
|
# makes sense to display this message
|
||||||
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
|
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
|
||||||
cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
|
cprint(
|
||||||
|
f"You can now run your stack with `llama stack run {run_config_file}`",
|
||||||
|
color="green",
|
||||||
|
)
|
||||||
return run_config_file
|
return run_config_file
|
||||||
|
|
||||||
|
|
||||||
|
@ -354,7 +341,7 @@ def _run_stack_build_command_from_build_config(
|
||||||
image_name: str | None = None,
|
image_name: str | None = None,
|
||||||
template_name: str | None = None,
|
template_name: str | None = None,
|
||||||
config_path: str | None = None,
|
config_path: str | None = None,
|
||||||
) -> Path | Traversable:
|
) -> str:
|
||||||
image_name = image_name or build_config.image_name
|
image_name = image_name or build_config.image_name
|
||||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||||
if template_name:
|
if template_name:
|
||||||
|
@ -383,7 +370,7 @@ def _run_stack_build_command_from_build_config(
|
||||||
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
|
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
|
||||||
# Only do this if we're building a container image and we're not using a template
|
# Only do this if we're building a container image and we're not using a template
|
||||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
|
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
|
||||||
cprint("Generating run.yaml file", color="yellow", file=sys.stderr)
|
cprint("Generating run.yaml file", color="green")
|
||||||
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
||||||
|
|
||||||
with open(build_file_path, "w") as f:
|
with open(build_file_path, "w") as f:
|
||||||
|
@ -407,13 +394,11 @@ def _run_stack_build_command_from_build_config(
|
||||||
run_config_file = build_dir / f"{template_name}-run.yaml"
|
run_config_file = build_dir / f"{template_name}-run.yaml"
|
||||||
shutil.copy(path, run_config_file)
|
shutil.copy(path, run_config_file)
|
||||||
|
|
||||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
cprint("Build Successful!", color="green")
|
||||||
cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr)
|
cprint("You can find the newly-built template here: " + colored(template_path, "light_blue"))
|
||||||
cprint(
|
cprint(
|
||||||
"You can run the new Llama Stack distro via: "
|
"You can run the new Llama Stack distro via: "
|
||||||
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"),
|
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue")
|
||||||
color="green",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
return template_path
|
return template_path
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -49,7 +49,7 @@ class StackBuild(Subcommand):
|
||||||
type=str,
|
type=str,
|
||||||
help="Image Type to use for the build. If not specified, will use the image type from the template config.",
|
help="Image Type to use for the build. If not specified, will use the image type from the template config.",
|
||||||
choices=[e.value for e in ImageType],
|
choices=[e.value for e in ImageType],
|
||||||
default=None, # no default so we can detect if a user specified --image-type and override image_type in the config
|
default=ImageType.CONDA.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
|
|
|
@ -1,56 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
|
||||||
from llama_stack.cli.table import print_table
|
|
||||||
|
|
||||||
|
|
||||||
class StackListBuilds(Subcommand):
|
|
||||||
"""List built stacks in .llama/distributions directory"""
|
|
||||||
|
|
||||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
|
||||||
super().__init__()
|
|
||||||
self.parser = subparsers.add_parser(
|
|
||||||
"list",
|
|
||||||
prog="llama stack list",
|
|
||||||
description="list the build stacks",
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
||||||
)
|
|
||||||
self._add_arguments()
|
|
||||||
self.parser.set_defaults(func=self._list_stack_command)
|
|
||||||
|
|
||||||
def _get_distribution_dirs(self) -> dict[str, Path]:
|
|
||||||
"""Return a dictionary of distribution names and their paths"""
|
|
||||||
distributions = {}
|
|
||||||
dist_dir = Path.home() / ".llama" / "distributions"
|
|
||||||
|
|
||||||
if dist_dir.exists():
|
|
||||||
for stack_dir in dist_dir.iterdir():
|
|
||||||
if stack_dir.is_dir():
|
|
||||||
distributions[stack_dir.name] = stack_dir
|
|
||||||
return distributions
|
|
||||||
|
|
||||||
def _list_stack_command(self, args: argparse.Namespace) -> None:
|
|
||||||
distributions = self._get_distribution_dirs()
|
|
||||||
|
|
||||||
if not distributions:
|
|
||||||
print("No stacks found in ~/.llama/distributions")
|
|
||||||
return
|
|
||||||
|
|
||||||
headers = ["Stack Name", "Path"]
|
|
||||||
headers.extend(["Build Config", "Run Config"])
|
|
||||||
rows = []
|
|
||||||
for name, path in distributions.items():
|
|
||||||
row = [name, str(path)]
|
|
||||||
# Check for build and run config files
|
|
||||||
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
|
|
||||||
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
|
|
||||||
row.extend([build_config, run_config])
|
|
||||||
rows.append(row)
|
|
||||||
print_table(rows, headers, separate_rows=True)
|
|
|
@ -1,115 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
|
||||||
from llama_stack.cli.table import print_table
|
|
||||||
|
|
||||||
|
|
||||||
class StackRemove(Subcommand):
|
|
||||||
"""Remove the build stack"""
|
|
||||||
|
|
||||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
|
||||||
super().__init__()
|
|
||||||
self.parser = subparsers.add_parser(
|
|
||||||
"rm",
|
|
||||||
prog="llama stack rm",
|
|
||||||
description="Remove the build stack",
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
||||||
)
|
|
||||||
self._add_arguments()
|
|
||||||
self.parser.set_defaults(func=self._remove_stack_build_command)
|
|
||||||
|
|
||||||
def _add_arguments(self) -> None:
|
|
||||||
self.parser.add_argument(
|
|
||||||
"name",
|
|
||||||
type=str,
|
|
||||||
nargs="?",
|
|
||||||
help="Name of the stack to delete",
|
|
||||||
)
|
|
||||||
self.parser.add_argument(
|
|
||||||
"--all",
|
|
||||||
"-a",
|
|
||||||
action="store_true",
|
|
||||||
help="Delete all stacks (use with caution)",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_distribution_dirs(self) -> dict[str, Path]:
|
|
||||||
"""Return a dictionary of distribution names and their paths"""
|
|
||||||
distributions = {}
|
|
||||||
dist_dir = Path.home() / ".llama" / "distributions"
|
|
||||||
|
|
||||||
if dist_dir.exists():
|
|
||||||
for stack_dir in dist_dir.iterdir():
|
|
||||||
if stack_dir.is_dir():
|
|
||||||
distributions[stack_dir.name] = stack_dir
|
|
||||||
return distributions
|
|
||||||
|
|
||||||
def _list_stacks(self) -> None:
|
|
||||||
"""Display available stacks in a table"""
|
|
||||||
distributions = self._get_distribution_dirs()
|
|
||||||
if not distributions:
|
|
||||||
cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
headers = ["Stack Name", "Path"]
|
|
||||||
rows = [[name, str(path)] for name, path in distributions.items()]
|
|
||||||
print_table(rows, headers, separate_rows=True)
|
|
||||||
|
|
||||||
def _remove_stack_build_command(self, args: argparse.Namespace) -> None:
|
|
||||||
distributions = self._get_distribution_dirs()
|
|
||||||
|
|
||||||
if args.all:
|
|
||||||
confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
|
|
||||||
if confirm != "yes-i-really-want":
|
|
||||||
cprint("Deletion cancelled.", color="green", file=sys.stderr)
|
|
||||||
return
|
|
||||||
|
|
||||||
for name, path in distributions.items():
|
|
||||||
try:
|
|
||||||
shutil.rmtree(path)
|
|
||||||
cprint(f"Deleted stack: {name}", color="green", file=sys.stderr)
|
|
||||||
except Exception as e:
|
|
||||||
cprint(
|
|
||||||
f"Failed to delete stack {name}: {e}",
|
|
||||||
color="red",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not args.name:
|
|
||||||
self._list_stacks()
|
|
||||||
if not args.name:
|
|
||||||
return
|
|
||||||
|
|
||||||
if args.name not in distributions:
|
|
||||||
self._list_stacks()
|
|
||||||
cprint(
|
|
||||||
f"Stack not found: {args.name}",
|
|
||||||
color="red",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
stack_path = distributions[args.name]
|
|
||||||
|
|
||||||
confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
|
|
||||||
if confirm != "y":
|
|
||||||
cprint("Deletion cancelled.", color="green", file=sys.stderr)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
shutil.rmtree(stack_path)
|
|
||||||
cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr)
|
|
||||||
except Exception as e:
|
|
||||||
cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from llama_stack.cli.stack.utils import ImageType
|
from llama_stack.cli.stack.utils import ImageType
|
||||||
|
@ -61,11 +60,6 @@ class StackRun(Subcommand):
|
||||||
help="Image Type used during the build. This can be either conda or container or venv.",
|
help="Image Type used during the build. This can be either conda or container or venv.",
|
||||||
choices=[e.value for e in ImageType],
|
choices=[e.value for e in ImageType],
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
|
||||||
"--enable-ui",
|
|
||||||
action="store_true",
|
|
||||||
help="Start the UI server",
|
|
||||||
)
|
|
||||||
|
|
||||||
# If neither image type nor image name is provided, but at the same time
|
# If neither image type nor image name is provided, but at the same time
|
||||||
# the current environment has conda breadcrumbs, then assume what the user
|
# the current environment has conda breadcrumbs, then assume what the user
|
||||||
|
@ -89,8 +83,6 @@ class StackRun(Subcommand):
|
||||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||||
|
|
||||||
if args.enable_ui:
|
|
||||||
self._start_ui_development_server(args.port)
|
|
||||||
image_type, image_name = self._get_image_type_and_name(args)
|
image_type, image_name = self._get_image_type_and_name(args)
|
||||||
|
|
||||||
# Check if config is required based on image type
|
# Check if config is required based on image type
|
||||||
|
@ -178,44 +170,3 @@ class StackRun(Subcommand):
|
||||||
run_args.extend(["--env", f"{key}={value}"])
|
run_args.extend(["--env", f"{key}={value}"])
|
||||||
|
|
||||||
run_command(run_args)
|
run_command(run_args)
|
||||||
|
|
||||||
def _start_ui_development_server(self, stack_server_port: int):
|
|
||||||
logger.info("Attempting to start UI development server...")
|
|
||||||
# Check if npm is available
|
|
||||||
npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
|
|
||||||
if npm_check.returncode != 0:
|
|
||||||
logger.warning(
|
|
||||||
f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
ui_dir = REPO_ROOT / "llama_stack" / "ui"
|
|
||||||
logs_dir = Path("~/.llama/ui/logs").expanduser()
|
|
||||||
try:
|
|
||||||
# Create logs directory if it doesn't exist
|
|
||||||
logs_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
ui_stdout_log_path = logs_dir / "stdout.log"
|
|
||||||
ui_stderr_log_path = logs_dir / "stderr.log"
|
|
||||||
|
|
||||||
# Open log files in append mode
|
|
||||||
stdout_log_file = open(ui_stdout_log_path, "a")
|
|
||||||
stderr_log_file = open(ui_stderr_log_path, "a")
|
|
||||||
|
|
||||||
process = subprocess.Popen(
|
|
||||||
["npm", "run", "dev"],
|
|
||||||
cwd=str(ui_dir),
|
|
||||||
stdout=stdout_log_file,
|
|
||||||
stderr=stderr_log_file,
|
|
||||||
env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
|
|
||||||
)
|
|
||||||
logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
|
|
||||||
logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
|
|
||||||
logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.error(
|
|
||||||
"Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
|
|
||||||
|
|
|
@@ -7,14 +7,12 @@
 import argparse
 from importlib.metadata import version

-from llama_stack.cli.stack.list_stacks import StackListBuilds
 from llama_stack.cli.stack.utils import print_subcommand_description
 from llama_stack.cli.subcommand import Subcommand

 from .build import StackBuild
 from .list_apis import StackListApis
 from .list_providers import StackListProviders
-from .remove import StackRemove
 from .run import StackRun

@@ -43,6 +41,5 @@ class StackParser(Subcommand):
         StackListApis.create(subparsers)
         StackListProviders.create(subparsers)
         StackRun.create(subparsers)
-        StackRemove.create(subparsers)
-        StackListBuilds.create(subparsers)
         print_subcommand_description(self.parser, subparsers)
@@ -6,7 +6,6 @@

 import importlib.resources
 import logging
-import sys
 from pathlib import Path

 from pydantic import BaseModel

@@ -44,20 +43,8 @@ def get_provider_dependencies(
     # Extract providers based on config type
     if isinstance(config, DistributionTemplate):
         providers = config.providers
-
-        # TODO: This is a hack to get the dependencies for internal APIs into build
-        # We should have a better way to do this by formalizing the concept of "internal" APIs
-        # and providers, with a way to specify dependencies for them.
-        run_configs = config.run_configs
-        additional_pip_packages: list[str] = []
-        if run_configs:
-            for run_config in run_configs.values():
-                run_config_ = run_config.run_config(name="", providers={}, container_image=None)
-                if run_config_.inference_store:
-                    additional_pip_packages.extend(run_config_.inference_store.pip_packages)
     elif isinstance(config, BuildConfig):
         providers = config.distribution_spec.providers
-        additional_pip_packages = config.additional_pip_packages
     deps = []
     registry = get_provider_registry(config)
     for api_str, provider_or_providers in providers.items():

@@ -85,9 +72,6 @@ def get_provider_dependencies(
         else:
             normal_deps.append(package)

-    if additional_pip_packages:
-        normal_deps.extend(additional_pip_packages)
-
     return list(set(normal_deps)), list(set(special_deps))


@@ -96,11 +80,10 @@ def print_pip_install_help(config: BuildConfig):

     cprint(
         f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
-        color="yellow",
-        file=sys.stderr,
+        "yellow",
     )
     for special_dep in special_deps:
-        cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
+        cprint(f"uv pip install {special_dep}", "yellow")
     print()
@@ -6,7 +6,6 @@

 import inspect
 import json
-import sys
 from collections.abc import AsyncIterator
 from enum import Enum
 from typing import Any, Union, get_args, get_origin

@@ -97,13 +96,13 @@ def create_api_client_class(protocol) -> type:
                 try:
                     data = json.loads(data)
                     if "error" in data:
-                        cprint(data, color="red", file=sys.stderr)
+                        cprint(data, "red")
                         continue
                     yield parse_obj_as(return_type, data)
                 except Exception as e:
-                    cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
-                    cprint(data, color="red", file=sys.stderr)
+                    print(f"Error with parsing or validation: {e}")
+                    print(data)

     def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
         webmethod, sig = self.routes[method_name]
@@ -25,8 +25,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.datatypes import Api, ProviderSpec
-from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig

 LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
 LLAMA_STACK_RUN_CONFIG_VERSION = "2"

@@ -221,38 +220,21 @@ class LoggingConfig(BaseModel):
 class AuthProviderType(str, Enum):
     """Supported authentication provider types."""

-    OAUTH2_TOKEN = "oauth2_token"
+    KUBERNETES = "kubernetes"
     CUSTOM = "custom"


 class AuthenticationConfig(BaseModel):
     provider_type: AuthProviderType = Field(
         ...,
-        description="Type of authentication provider",
+        description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
     )
-    config: dict[str, Any] = Field(
+    config: dict[str, str] = Field(
         ...,
         description="Provider-specific configuration",
     )
-
-
-class AuthenticationRequiredError(Exception):
-    pass
-
-
-class QuotaPeriod(str, Enum):
-    DAY = "day"
-
-
-class QuotaConfig(BaseModel):
-    kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
-    anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
-    authenticated_max_requests: int = Field(
-        default=1000, description="Max requests for authenticated clients per period"
-    )
-    period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")


 class ServerConfig(BaseModel):
     port: int = Field(
         default=8321,

@@ -280,10 +262,6 @@ class ServerConfig(BaseModel):
         default=None,
         description="The host the server should listen on",
     )
-    quota: QuotaConfig | None = Field(
-        default=None,
-        description="Per client quota request configuration",
-    )


 class StackRunConfig(BaseModel):

@@ -319,13 +297,6 @@ Configuration for the persistence store used by the distribution registry. If no
 a default SQLite store will be used.""",
     )
-
-    inference_store: SqlStoreConfig | None = Field(
-        default=None,
-        description="""
-Configuration for the persistence store used by the inference API. If not specified,
-a default SQLite store will be used.""",
-    )

     # registry of "resources" in the distribution
     models: list[ModelInput] = Field(default_factory=list)
     shields: list[ShieldInput] = Field(default_factory=list)

@@ -369,21 +340,8 @@ class BuildConfig(BaseModel):
         default=None,
         description="Name of the distribution to build",
     )
-    external_providers_dir: Path | None = Field(
+    external_providers_dir: str | None = Field(
         default=None,
         description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
         "pip_packages MUST contain the provider package name.",
     )
-    additional_pip_packages: list[str] = Field(
-        default_factory=list,
-        description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
-    )
-
-    @field_validator("external_providers_dir")
-    @classmethod
-    def validate_external_providers_dir(cls, v):
-        if v is None:
-            return None
-        if isinstance(v, str):
-            return Path(v)
-        return v
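For context, the validator removed above normalizes a string value from YAML into a Path. Below is a minimal, self-contained sketch of the same pattern, assuming pydantic v2; the model name is a stand-in, not the real BuildConfig.

# Illustrative stand-in model showing the str -> Path normalization pattern
# used by the validator removed above; not part of the diff.
from pathlib import Path

from pydantic import BaseModel, field_validator


class ExampleBuildConfig(BaseModel):
    external_providers_dir: Path | None = None

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):
        # Accept plain strings (e.g. from a YAML run config) and normalize to Path.
        if v is None:
            return None
        if isinstance(v, str):
            return Path(v)
        return v


print(ExampleBuildConfig(external_providers_dir="~/providers").external_providers_dir)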
@@ -16,7 +16,7 @@ from llama_stack.apis.inspect import (
     VersionInfo,
 )
 from llama_stack.distribution.datatypes import StackRunConfig
-from llama_stack.distribution.server.routes import get_all_api_routes
+from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.providers.datatypes import HealthStatus


@@ -31,7 +31,7 @@ async def get_provider_impl(config, deps):


 class DistributionInspectImpl(Inspect):
-    def __init__(self, config: DistributionInspectConfig, deps):
+    def __init__(self, config, deps):
         self.config = config
         self.deps = deps

@@ -39,36 +39,22 @@ class DistributionInspectImpl(Inspect):
         pass

     async def list_routes(self) -> ListRoutesResponse:
-        run_config: StackRunConfig = self.config.run_config
+        run_config = self.config.run_config

         ret = []
-        all_endpoints = get_all_api_routes()
+        all_endpoints = get_all_api_endpoints()
         for api, endpoints in all_endpoints.items():
-            # Always include provider and inspect APIs, filter others based on run config
-            if api.value in ["providers", "inspect"]:
-                ret.extend(
-                    [
-                        RouteInfo(
-                            route=e.path,
-                            method=next(iter([m for m in e.methods if m != "HEAD"])),
-                            provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
-                        )
-                        for e in endpoints
-                    ]
-                )
-            else:
-                providers = run_config.providers.get(api.value, [])
-                if providers:  # Only process if there are providers for this API
-                    ret.extend(
-                        [
-                            RouteInfo(
-                                route=e.path,
-                                method=next(iter([m for m in e.methods if m != "HEAD"])),
-                                provider_types=[p.provider_type for p in providers],
-                            )
-                            for e in endpoints
-                        ]
-                    )
+            providers = run_config.providers.get(api.value, [])
+            ret.extend(
+                [
+                    RouteInfo(
+                        route=e.route,
+                        method=e.method,
+                        provider_types=[p.provider_type for p in providers],
+                    )
+                    for e in endpoints
+                ]
+            )

         return ListRoutesResponse(data=ret)
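For context, the removed branch above picks the first non-HEAD HTTP method for each route with `next(iter([m for m in e.methods if m != "HEAD"]))`. A minimal sketch of that selection, using a hypothetical stand-in Route type rather than any real framework class:

# Illustrative sketch of the "first non-HEAD method" selection; not part of the diff.
from dataclasses import dataclass


@dataclass
class Route:
    path: str
    methods: set[str]


def primary_method(route: Route) -> str:
    # Equivalent to next(iter([m for m in route.methods if m != "HEAD"]))
    return next(m for m in route.methods if m != "HEAD")


print(primary_method(Route(path="/v1/models", methods={"GET", "HEAD"})))  # GET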
@@ -9,7 +9,6 @@ import inspect
 import json
 import logging
 import os
-import sys
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path

@@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import ProviderRegistry
-from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
+from llama_stack.distribution.server.endpoints import (
+    find_matching_endpoint,
+    initialize_endpoint_impls,
+)
 from llama_stack.distribution.stack import (
     construct_stack,
     get_stack_run_config_from_template,

@@ -205,14 +207,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

     async def initialize(self) -> bool:
         try:
-            self.route_impls = None
+            self.endpoint_impls = None
             self.impls = await construct_stack(self.config, self.custom_provider_registry)
         except ModuleNotFoundError as _e:
-            cprint(_e.msg, color="red", file=sys.stderr)
+            cprint(_e.msg, "red")
             cprint(
                 "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
-                color="yellow",
-                file=sys.stderr,
+                "yellow",
             )
             if self.config_path_or_template_name.endswith(".yaml"):
                 # Convert Provider objects to their types

@@ -225,7 +226,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                     distribution_spec=DistributionSpec(
                         providers=provider_types,
                     ),
-                    external_providers_dir=self.config.external_providers_dir,
                 )
                 print_pip_install_help(build_config)
             else:

@@ -233,13 +233,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 cprint(
                     f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
                     "yellow",
-                    file=sys.stderr,
                 )
-                cprint(
-                    "Please check your internet connection and try again.",
-                    "red",
-                    file=sys.stderr,
-                )
             raise _e

         if Api.telemetry in self.impls:

@@ -251,7 +245,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             safe_config = redact_sensitive_fields(self.config.model_dump())
             console.print(yaml.dump(safe_config, indent=2))

-        self.route_impls = initialize_route_impls(self.impls)
+        self.endpoint_impls = initialize_endpoint_impls(self.impls)
         return True

     async def request(

@@ -262,15 +256,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         stream=False,
         stream_cls=None,
     ):
-        if not self.route_impls:
+        if not self.endpoint_impls:
             raise ValueError("Client not initialized")

         # Create headers with provider data if available
-        headers = options.headers or {}
+        headers = {}
         if self.provider_data:
-            keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
-            if all(key not in headers for key in keys):
-                headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
+            headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)

         # Use context manager for provider data
         with request_provider_data_context(headers):

@@ -293,14 +285,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         cast_to: Any,
         options: Any,
     ):
-        if self.route_impls is None:
-            raise ValueError("Client not initialized")
-
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}

-        matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
+        matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
         body |= path_params
         body = self._convert_body(path, options.method, body)
         await start_trace(route, {"__location__": "library_client"})

@@ -342,13 +331,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         options: Any,
         stream_cls: Any,
     ):
-        if self.route_impls is None:
-            raise ValueError("Client not initialized")
-
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}
-        func, path_params, route = find_matching_route(options.method, path, self.route_impls)
+        func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
         body |= path_params

         body = self._convert_body(path, options.method, body)

@@ -400,10 +386,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not body:
             return {}

-        if self.route_impls is None:
-            raise ValueError("Client not initialized")
-
-        func, _, _ = find_matching_route(method, path, self.route_impls)
+        func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
         sig = inspect.signature(func)

         # Strip NOT_GIVENs to use the defaults in signature
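For context, the removed side above only injects the provider-data header when the caller has not already supplied it under either casing. A minimal sketch of that case-insensitive guard, with illustrative header values:

# Illustrative sketch of the case-insensitive provider-data header guard; not part of the diff.
import json


def add_provider_data_header(headers: dict[str, str], provider_data: dict) -> dict[str, str]:
    keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
    # Only inject the header if the caller did not already set it in either casing.
    if all(key not in headers for key in keys):
        headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
    return headers


print(add_provider_data_header({}, {"api_key": "example"}))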
@@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.files import Files
-from llama_stack.apis.inference import Inference, InferenceProvider
+from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining

@@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
     RemoteProviderSpec,
     ScoringFunctionsProtocolPrivate,
     ShieldsProtocolPrivate,
-    ToolGroupsProtocolPrivate,
+    ToolsProtocolPrivate,
     VectorDBsProtocolPrivate,
 )

@@ -83,17 +83,10 @@ def api_protocol_map() -> dict[Api, Any]:
     }


-def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
-    return {
-        **api_protocol_map(),
-        Api.inference: InferenceProvider,
-    }
-
-
 def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
-        Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
+        Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
         Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),

@@ -140,7 +133,7 @@ async def resolve_impls(

     sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

-    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
+    return await instantiate_providers(sorted_providers, router_apis, dist_registry)


 def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:

@@ -243,10 +236,7 @@ def sort_providers_by_deps(


 async def instantiate_providers(
-    sorted_providers: list[tuple[str, ProviderWithSpec]],
-    router_apis: set[Api],
-    dist_registry: DistributionRegistry,
-    run_config: StackRunConfig,
+    sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
 ) -> dict:
     """Instantiates providers asynchronously while managing dependencies."""
     impls: dict[Api, Any] = {}

@@ -261,7 +251,7 @@ async def instantiate_providers(
         if isinstance(provider.spec, RoutingTableProviderSpec):
             inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

-        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
+        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)

         if api_str.startswith("inner-"):
             inner_impls_by_provider_id[api_str][provider.provider_id] = impl

@@ -311,8 +301,10 @@ async def instantiate_provider(
     deps: dict[Api, Any],
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
-    run_config: StackRunConfig,
 ):
+    protocols = api_protocol_map()
+    additional_protocols = additional_protocols_map()
+
     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

@@ -331,7 +323,7 @@ async def instantiate_provider(
         method = "get_auto_router_impl"

         config = None
-        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
+        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
     elif isinstance(provider_spec, RoutingTableProviderSpec):
         method = "get_routing_table_impl"

@@ -350,8 +342,6 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

-    protocols = api_protocol_map_for_compliance_check()
-    additional_protocols = additional_protocols_map()
     # TODO: check compliance for special tool groups
     # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
@@ -7,10 +7,18 @@
 from typing import Any

 from llama_stack.distribution.datatypes import RoutedProtocol
-from llama_stack.distribution.stack import StackRunConfig
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable
-from llama_stack.providers.utils.inference.inference_store import InferenceStore
+from .routing_tables import (
+    BenchmarksRoutingTable,
+    DatasetsRoutingTable,
+    ModelsRoutingTable,
+    ScoringFunctionsRoutingTable,
+    ShieldsRoutingTable,
+    ToolGroupsRoutingTable,
+    VectorDBsRoutingTable,
+)


 async def get_routing_table_impl(

@@ -19,14 +27,6 @@ async def get_routing_table_impl(
     _deps,
     dist_registry: DistributionRegistry,
 ) -> Any:
-    from ..routing_tables.benchmarks import BenchmarksRoutingTable
-    from ..routing_tables.datasets import DatasetsRoutingTable
-    from ..routing_tables.models import ModelsRoutingTable
-    from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
-    from ..routing_tables.shields import ShieldsRoutingTable
-    from ..routing_tables.toolgroups import ToolGroupsRoutingTable
-    from ..routing_tables.vector_dbs import VectorDBsRoutingTable
-
     api_to_tables = {
         "vector_dbs": VectorDBsRoutingTable,
         "models": ModelsRoutingTable,

@@ -45,15 +45,16 @@ async def get_routing_table_impl(
     return impl


-async def get_auto_router_impl(
-    api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
-) -> Any:
-    from .datasets import DatasetIORouter
-    from .eval_scoring import EvalRouter, ScoringRouter
-    from .inference import InferenceRouter
-    from .safety import SafetyRouter
-    from .tool_runtime import ToolRuntimeRouter
-    from .vector_io import VectorIORouter
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
+    from .routers import (
+        DatasetIORouter,
+        EvalRouter,
+        InferenceRouter,
+        SafetyRouter,
+        ScoringRouter,
+        ToolRuntimeRouter,
+        VectorIORouter,
+    )

     api_to_routers = {
         "vector_io": VectorIORouter,

@@ -75,12 +76,6 @@ async def get_auto_router_impl(
         if dep_api in deps:
             api_to_dep_impl[dep_name] = deps[dep_api]

-    # TODO: move pass configs to routers instead
-    if api == Api.inference and run_config.inference_store:
-        inference_store = InferenceStore(run_config.inference_store)
-        await inference_store.initialize()
-        api_to_dep_impl["store"] = inference_store
-
     impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
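For context, the block removed above conditionally built an inference store from the run config and handed it to the router as a `store` dependency. A minimal sketch of that wiring pattern, using hypothetical stand-in classes rather than the real llama_stack providers:

# Illustrative sketch of conditional "store" dependency wiring; not part of the diff.
import asyncio


class FakeInferenceStore:
    def __init__(self, config: dict):
        self.config = config

    async def initialize(self) -> None:
        # A real store would open its backing database here.
        print(f"initialized store with {self.config}")


class FakeInferenceRouter:
    def __init__(self, routing_table, store=None):
        self.routing_table = routing_table
        self.store = store


async def build_inference_router(routing_table, run_config: dict) -> FakeInferenceRouter:
    deps = {}
    # Only wire the store in when the run config declares one.
    if run_config.get("inference_store"):
        store = FakeInferenceStore(run_config["inference_store"])
        await store.initialize()
        deps["store"] = store
    return FakeInferenceRouter(routing_table, **deps)


asyncio.run(build_inference_router(routing_table=None, run_config={"inference_store": {"type": "sqlite"}}))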
@ -1,71 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
|
||||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
|
||||||
from llama_stack.log import get_logger
|
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetIORouter(DatasetIO):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
routing_table: RoutingTable,
|
|
||||||
) -> None:
|
|
||||||
logger.debug("Initializing DatasetIORouter")
|
|
||||||
self.routing_table = routing_table
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
logger.debug("DatasetIORouter.initialize")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
logger.debug("DatasetIORouter.shutdown")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def register_dataset(
|
|
||||||
self,
|
|
||||||
purpose: DatasetPurpose,
|
|
||||||
source: DataSource,
|
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
dataset_id: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
logger.debug(
|
|
||||||
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
|
|
||||||
)
|
|
||||||
await self.routing_table.register_dataset(
|
|
||||||
purpose=purpose,
|
|
||||||
source=source,
|
|
||||||
metadata=metadata,
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def iterrows(
|
|
||||||
self,
|
|
||||||
dataset_id: str,
|
|
||||||
start_index: int | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
) -> PaginatedResponse:
|
|
||||||
logger.debug(
|
|
||||||
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
|
||||||
)
|
|
||||||
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
start_index=start_index,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
|
||||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
|
||||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
rows=rows,
|
|
||||||
)
|
|
|
@ -1,148 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
|
|
||||||
from llama_stack.apis.scoring import (
|
|
||||||
ScoreBatchResponse,
|
|
||||||
ScoreResponse,
|
|
||||||
Scoring,
|
|
||||||
ScoringFnParams,
|
|
||||||
)
|
|
||||||
from llama_stack.log import get_logger
|
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
|
||||||
|
|
||||||
|
|
||||||
class ScoringRouter(Scoring):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
routing_table: RoutingTable,
|
|
||||||
) -> None:
|
|
||||||
logger.debug("Initializing ScoringRouter")
|
|
||||||
self.routing_table = routing_table
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
logger.debug("ScoringRouter.initialize")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
logger.debug("ScoringRouter.shutdown")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def score_batch(
|
|
||||||
self,
|
|
||||||
dataset_id: str,
|
|
||||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
|
||||||
save_results_dataset: bool = False,
|
|
||||||
) -> ScoreBatchResponse:
|
|
||||||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
|
||||||
res = {}
|
|
||||||
for fn_identifier in scoring_functions.keys():
|
|
||||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
|
||||||
)
|
|
||||||
res.update(score_response.results)
|
|
||||||
|
|
||||||
if save_results_dataset:
|
|
||||||
raise NotImplementedError("Save results dataset not implemented yet")
|
|
||||||
|
|
||||||
return ScoreBatchResponse(
|
|
||||||
results=res,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def score(
|
|
||||||
self,
|
|
||||||
input_rows: list[dict[str, Any]],
|
|
||||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
|
||||||
) -> ScoreResponse:
|
|
||||||
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
|
||||||
res = {}
|
|
||||||
# look up and map each scoring function to its provider impl
|
|
||||||
for fn_identifier in scoring_functions.keys():
|
|
||||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
|
|
||||||
input_rows=input_rows,
|
|
||||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
|
||||||
)
|
|
||||||
res.update(score_response.results)
|
|
||||||
|
|
||||||
return ScoreResponse(results=res)
|
|
||||||
|
|
||||||
|
|
||||||
class EvalRouter(Eval):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
routing_table: RoutingTable,
|
|
||||||
) -> None:
|
|
||||||
logger.debug("Initializing EvalRouter")
|
|
||||||
self.routing_table = routing_table
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
logger.debug("EvalRouter.initialize")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
logger.debug("EvalRouter.shutdown")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def run_eval(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
benchmark_config: BenchmarkConfig,
|
|
||||||
) -> Job:
|
|
||||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
|
||||||
benchmark_id=benchmark_id,
|
|
||||||
benchmark_config=benchmark_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def evaluate_rows(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
input_rows: list[dict[str, Any]],
|
|
||||||
scoring_functions: list[str],
|
|
||||||
benchmark_config: BenchmarkConfig,
|
|
||||||
) -> EvaluateResponse:
|
|
||||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
|
||||||
benchmark_id=benchmark_id,
|
|
||||||
input_rows=input_rows,
|
|
||||||
scoring_functions=scoring_functions,
|
|
||||||
benchmark_config=benchmark_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def job_status(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
job_id: str,
|
|
||||||
) -> Job:
|
|
||||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
|
||||||
|
|
||||||
async def job_cancel(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
job_id: str,
|
|
||||||
) -> None:
|
|
||||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
|
||||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
|
||||||
benchmark_id,
|
|
||||||
job_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def job_result(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
job_id: str,
|
|
||||||
) -> EvaluateResponse:
|
|
||||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
|
||||||
benchmark_id,
|
|
||||||
job_id,
|
|
||||||
)
|
|
|
@ -14,9 +14,14 @@ from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToo
|
||||||
from pydantic import Field, TypeAdapter
|
from pydantic import Field, TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
URL,
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
InterleavedContentItem,
|
InterleavedContentItem,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
|
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||||
|
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
BatchChatCompletionResponse,
|
BatchChatCompletionResponse,
|
||||||
BatchCompletionResponse,
|
BatchCompletionResponse,
|
||||||
|
@ -27,11 +32,8 @@ from llama_stack.apis.inference import (
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
EmbeddingTaskType,
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
ListOpenAIChatCompletionResponse,
|
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
OpenAICompletionWithInputMessages,
|
|
||||||
Order,
|
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
StopReason,
|
StopReason,
|
||||||
|
@ -45,23 +47,93 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIChatCompletion,
|
OpenAIChatCompletion,
|
||||||
OpenAIChatCompletionChunk,
|
OpenAIChatCompletionChunk,
|
||||||
OpenAICompletion,
|
OpenAICompletion,
|
||||||
OpenAIEmbeddingsResponse,
|
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
OpenAIResponseFormatParam,
|
OpenAIResponseFormatParam,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||||
|
from llama_stack.apis.scoring import (
|
||||||
|
ScoreBatchResponse,
|
||||||
|
ScoreResponse,
|
||||||
|
Scoring,
|
||||||
|
ScoringFnParams,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||||
|
from llama_stack.apis.tools import (
|
||||||
|
ListToolDefsResponse,
|
||||||
|
RAGDocument,
|
||||||
|
RAGQueryConfig,
|
||||||
|
RAGQueryResult,
|
||||||
|
RAGToolRuntime,
|
||||||
|
ToolRuntime,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
|
||||||
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
|
|
||||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
|
class VectorIORouter(VectorIO):
|
||||||
|
"""Routes to an provider based on the vector db identifier"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
routing_table: RoutingTable,
|
||||||
|
) -> None:
|
||||||
|
logger.debug("Initializing VectorIORouter")
|
||||||
|
self.routing_table = routing_table
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
logger.debug("VectorIORouter.initialize")
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
logger.debug("VectorIORouter.shutdown")
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_vector_db(
|
||||||
|
self,
|
||||||
|
vector_db_id: str,
|
||||||
|
embedding_model: str,
|
||||||
|
embedding_dimension: int | None = 384,
|
||||||
|
provider_id: str | None = None,
|
||||||
|
provider_vector_db_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||||
|
await self.routing_table.register_vector_db(
|
||||||
|
vector_db_id,
|
||||||
|
embedding_model,
|
||||||
|
embedding_dimension,
|
||||||
|
provider_id,
|
||||||
|
provider_vector_db_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def insert_chunks(
|
||||||
|
self,
|
||||||
|
vector_db_id: str,
|
||||||
|
chunks: list[Chunk],
|
||||||
|
ttl_seconds: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
logger.debug(
|
||||||
|
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||||
|
)
|
||||||
|
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||||
|
|
||||||
|
async def query_chunks(
|
||||||
|
self,
|
||||||
|
vector_db_id: str,
|
||||||
|
query: InterleavedContent,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||||
|
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||||
|
|
||||||
|
|
||||||
class InferenceRouter(Inference):
|
class InferenceRouter(Inference):
|
||||||
"""Routes to an provider based on the model"""
|
"""Routes to an provider based on the model"""
|
||||||
|
|
||||||
|
@ -69,12 +141,10 @@ class InferenceRouter(Inference):
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
telemetry: Telemetry | None = None,
|
telemetry: Telemetry | None = None,
|
||||||
store: InferenceStore | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing InferenceRouter")
|
logger.debug("Initializing InferenceRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
self.telemetry = telemetry
|
self.telemetry = telemetry
|
||||||
self.store = store
|
|
||||||
if self.telemetry:
|
if self.telemetry:
|
||||||
self.tokenizer = Tokenizer.get_instance()
|
self.tokenizer = Tokenizer.get_instance()
|
||||||
self.formatter = ChatFormat(self.tokenizer)
|
self.formatter = ChatFormat(self.tokenizer)
|
||||||
|
@ -537,59 +607,9 @@ class InferenceRouter(Inference):
|
||||||
|
|
||||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||||
if stream:
|
if stream:
|
||||||
response_stream = await provider.openai_chat_completion(**params)
|
return await provider.openai_chat_completion(**params)
|
||||||
if self.store:
|
|
||||||
return stream_and_store_openai_completion(response_stream, model, self.store, messages)
|
|
||||||
return response_stream
|
|
||||||
else:
|
else:
|
||||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
return await self._nonstream_openai_chat_completion(provider, params)
|
||||||
if self.store:
|
|
||||||
await self.store.store_chat_completion(response, messages)
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def openai_embeddings(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
input: str | list[str],
|
|
||||||
encoding_format: str | None = "float",
|
|
||||||
dimensions: int | None = None,
|
|
||||||
user: str | None = None,
|
|
||||||
) -> OpenAIEmbeddingsResponse:
|
|
||||||
logger.debug(
|
|
||||||
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
|
|
||||||
)
|
|
||||||
model_obj = await self.routing_table.get_model(model)
|
|
||||||
if model_obj is None:
|
|
||||||
raise ValueError(f"Model '{model}' not found")
|
|
||||||
if model_obj.model_type != ModelType.embedding:
|
|
||||||
raise ValueError(f"Model '{model}' is not an embedding model")
|
|
||||||
|
|
||||||
params = dict(
|
|
||||||
model=model_obj.identifier,
|
|
||||||
input=input,
|
|
||||||
encoding_format=encoding_format,
|
|
||||||
dimensions=dimensions,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
|
||||||
return await provider.openai_embeddings(**params)
|
|
||||||
|
|
||||||
async def list_chat_completions(
|
|
||||||
self,
|
|
||||||
after: str | None = None,
|
|
||||||
limit: int | None = 20,
|
|
||||||
model: str | None = None,
|
|
||||||
order: Order | None = Order.desc,
|
|
||||||
) -> ListOpenAIChatCompletionResponse:
|
|
||||||
if self.store:
|
|
||||||
return await self.store.list_chat_completions(after, limit, model, order)
|
|
||||||
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
|
|
||||||
|
|
||||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
|
||||||
if self.store:
|
|
||||||
return await self.store.get_chat_completion(completion_id)
|
|
||||||
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
|
|
||||||
|
|
||||||
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
|
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
|
||||||
response = await provider.openai_chat_completion(**params)
|
response = await provider.openai_chat_completion(**params)
|
||||||
|
@ -622,3 +642,295 @@ class InferenceRouter(Inference):
|
||||||
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||||
)
|
)
|
||||||
return health_statuses
|
return health_statuses
|
||||||
|
|
||||||
|
|
||||||
|
class SafetyRouter(Safety):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
routing_table: RoutingTable,
|
||||||
|
) -> None:
|
||||||
|
logger.debug("Initializing SafetyRouter")
|
||||||
|
self.routing_table = routing_table
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
logger.debug("SafetyRouter.initialize")
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
logger.debug("SafetyRouter.shutdown")
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_shield(
|
||||||
|
self,
|
||||||
|
shield_id: str,
|
||||||
|
provider_shield_id: str | None = None,
|
||||||
|
provider_id: str | None = None,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
) -> Shield:
|
||||||
|
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||||
|
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||||
|
|
||||||
|
async def run_shield(
|
||||||
|
self,
|
||||||
|
shield_id: str,
|
||||||
|
messages: list[Message],
|
||||||
|
params: dict[str, Any] = None,
|
||||||
|
) -> RunShieldResponse:
|
||||||
|
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||||
|
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||||
|
shield_id=shield_id,
|
||||||
|
messages=messages,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetIORouter(DatasetIO):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
routing_table: RoutingTable,
|
||||||
|
) -> None:
|
||||||
|
logger.debug("Initializing DatasetIORouter")
|
||||||
|
self.routing_table = routing_table
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
logger.debug("DatasetIORouter.initialize")
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
logger.debug("DatasetIORouter.shutdown")
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_dataset(
|
||||||
|
self,
|
||||||
|
purpose: DatasetPurpose,
|
||||||
|
source: DataSource,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
dataset_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
logger.debug(
|
||||||
|
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
|
||||||
|
)
|
||||||
|
await self.routing_table.register_dataset(
|
||||||
|
purpose=purpose,
|
||||||
|
source=source,
|
||||||
|
metadata=metadata,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def iterrows(
|
||||||
|
self,
|
||||||
|
dataset_id: str,
|
||||||
|
start_index: int | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> PaginatedResponse:
|
||||||
|
logger.debug(
|
||||||
|
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
||||||
|
)
|
||||||
|
        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
            dataset_id=dataset_id,
            start_index=start_index,
            limit=limit,
        )

    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
            dataset_id=dataset_id,
            rows=rows,
        )


class ScoringRouter(Scoring):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing ScoringRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("ScoringRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ScoringRouter.shutdown")
        pass

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: dict[str, ScoringFnParams | None] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
        res = {}
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                dataset_id=dataset_id,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None] = None,
    ) -> ScoreResponse:
        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                input_rows=input_rows,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        return ScoreResponse(results=res)


class EvalRouter(Eval):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing EvalRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("EvalRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("EvalRouter.shutdown")
        pass

    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
            benchmark_id=benchmark_id,
            benchmark_config=benchmark_config,
        )

    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: list[dict[str, Any]],
        scoring_functions: list[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
            benchmark_id=benchmark_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            benchmark_config=benchmark_config,
        )

    async def job_status(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> Job:
        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

    async def job_cancel(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> None:
        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
            benchmark_id,
            job_id,
        )

    async def job_result(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
            benchmark_id,
            job_id,
        )


class ToolRuntimeRouter(ToolRuntime):
    class RagToolImpl(RAGToolRuntime):
        def __init__(
            self,
            routing_table: RoutingTable,
        ) -> None:
            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
            self.routing_table = routing_table

        async def query(
            self,
            content: InterleavedContent,
            vector_db_ids: list[str],
            query_config: RAGQueryConfig | None = None,
        ) -> RAGQueryResult:
            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
            return await self.routing_table.get_provider_impl("knowledge_search").query(
                content, vector_db_ids, query_config
            )

        async def insert(
            self,
            documents: list[RAGDocument],
            vector_db_id: str,
            chunk_size_in_tokens: int = 512,
        ) -> None:
            logger.debug(
                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                documents, vector_db_id, chunk_size_in_tokens
            )

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing ToolRuntimeRouter")
        self.routing_table = routing_table

        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
        self.rag_tool = self.RagToolImpl(routing_table)
        for method in ("query", "insert"):
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

    async def initialize(self) -> None:
        logger.debug("ToolRuntimeRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:
        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
634 llama_stack/distribution/routers/routing_tables.py Normal file
@@ -0,0 +1,634 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import time
import uuid
from typing import Any

from pydantic import TypeAdapter

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import (
    Dataset,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
    ListToolGroupsResponse,
    ListToolsResponse,
    Tool,
    ToolGroup,
    ToolGroups,
    ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
    AccessAttributes,
    BenchmarkWithACL,
    DatasetWithACL,
    ModelWithACL,
    RoutableObject,
    RoutableObjectWithProvider,
    RoutedProtocol,
    ScoringFnWithACL,
    ShieldWithACL,
    ToolGroupWithACL,
    ToolWithACL,
    VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable

logger = logging.getLogger(__name__)


def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api


# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    api = get_impl_api(p)

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    if api == Api.inference:
        return await p.register_model(obj)
    elif api == Api.safety:
        return await p.register_shield(obj)
    elif api == Api.vector_io:
        return await p.register_vector_db(obj)
    elif api == Api.datasetio:
        return await p.register_dataset(obj)
    elif api == Api.scoring:
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_benchmark(obj)
    elif api == Api.tool_runtime:
        return await p.register_tool(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")


async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    api = get_impl_api(p)
    if api == Api.vector_io:
        return await p.unregister_vector_db(obj.identifier)
    elif api == Api.inference:
        return await p.unregister_model(obj.identifier)
    elif api == Api.datasetio:
        return await p.unregister_dataset(obj.identifier)
    elif api == Api.tool_runtime:
        return await p.unregister_tool(obj.identifier)
    else:
        raise ValueError(f"Unregister not supported for {api}")


Registry = dict[str, list[RoutableObjectWithProvider]]


class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        impls_by_provider_id: dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
    ) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.dist_registry = dist_registry

    async def initialize(self) -> None:
        async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
            for obj in objs:
                if cls is None:
                    obj.provider_id = provider_id
                else:
                    # Create a copy of the model data and explicitly set provider_id
                    model_data = obj.model_dump()
                    model_data["provider_id"] = provider_id
                    obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            if api == Api.inference:
                p.model_store = self
            elif api == Api.safety:
                p.shield_store = self
            elif api == Api.vector_io:
                p.vector_db_store = self
            elif api == Api.datasetio:
                p.dataset_store = self
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)
            elif api == Api.eval:
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        def apiname_object():
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorDBsRoutingTable):
                return ("VectorIO", "vector_db")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, BenchmarksRoutingTable):
                return ("Eval", "benchmark")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("Tools", "tool")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # Check if user has permission to access this object
        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        # If object supports access control but no attributes set, use creator's attributes
        if not obj.access_attributes:
            creator_attributes = get_auth_attributes()
            if creator_attributes:
                obj.access_attributes = AccessAttributes(**creator_attributes)
                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj

        else:
            await self.dist_registry.register(obj)
            return obj

    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
        objs = await self.dist_registry.get_all()
        filtered_objs = [obj for obj in objs if obj.type == type]

        # Apply attribute-based access control filtering
        if filtered_objs:
            filtered_objs = [
                obj
                for obj in filtered_objs
                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
            ]

        return filtered_objs


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> ListShieldsResponse:
        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

    async def get_shield(self, identifier: str) -> Shield:
        shield = await self.get_object_by_identifier("shield", identifier)
        if shield is None:
            raise ValueError(f"Shield '{identifier}' not found")
        return shield

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = ShieldWithACL(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield


class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
        }
        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        existing_vector_db = await self.get_vector_db(vector_db_id)
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

    async def get_dataset(self, dataset_id: str) -> Dataset:
        dataset = await self.get_object_by_identifier("dataset", dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset '{dataset_id}' not found")
        return dataset

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        if isinstance(source, dict):
            if source["type"] == "uri":
                source = URIDataSource.parse_obj(source)
            elif source["type"] == "rows":
                source = RowsDataSource.parse_obj(source)

        if not dataset_id:
            dataset_id = f"dataset-{str(uuid.uuid4())}"

        provider_dataset_id = dataset_id

        # infer provider from source
        if metadata:
            if metadata.get("provider_id"):
                provider_id = metadata.get("provider_id")  # pass through from nvidia datasetio
        elif source.type == DatasetType.rows.value:
            provider_id = "localfs"
        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
            else:
                provider_id = "localfs"
        else:
            raise ValueError(f"Unknown data source type: {source.type}")

        if metadata is None:
            metadata = {}

        dataset = DatasetWithACL(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            purpose=purpose,
            source=source,
            metadata=metadata,
        )

        await self.register_object(dataset)
        return dataset

    async def unregister_dataset(self, dataset_id: str) -> None:
        dataset = await self.get_dataset(dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset {dataset_id} not found")
        await self.unregister_object(dataset)


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithACL(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)


class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
        if benchmark is None:
            raise ValueError(f"Benchmark '{benchmark_id}' not found")
        return benchmark

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        metadata: dict[str, Any] | None = None,
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = BenchmarkWithACL(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        tools = await self.get_all_with_type("tool")
        if toolgroup_id:
            tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
        return ListToolsResponse(data=tools)

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group '{toolgroup_id}' not found")
        return tool_group

    async def get_tool(self, tool_name: str) -> Tool:
        return await self.get_object_by_identifier("tool", tool_name)

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        tools = []
        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

        for tool_def in tool_defs.data:
            tools.append(
                ToolWithACL(
                    identifier=tool_def.name,
                    toolgroup_id=toolgroup_id,
                    description=tool_def.description or "",
                    parameters=tool_def.parameters or [],
                    provider_id=provider_id,
                    provider_resource_id=tool_def.name,
                    metadata=tool_def.metadata,
                    tool_host=tool_host,
                )
            )
        for tool in tools:
            existing_tool = await self.get_tool(tool.identifier)
            # Compare existing and new object if one exists
            if existing_tool:
                existing_dict = existing_tool.model_dump()
                new_dict = tool.model_dump()

                if existing_dict != new_dict:
                    raise ValueError(
                        f"Object {tool.identifier} already exists in registry. Please use a different identifier."
                    )
            await self.register_object(tool)

        await self.dist_registry.register(
            ToolGroupWithACL(
                identifier=toolgroup_id,
                provider_id=provider_id,
                provider_resource_id=toolgroup_id,
                mcp_endpoint=mcp_endpoint,
                args=args,
            )
        )

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
        tools = await self.list_tools(toolgroup_id)
        for tool in getattr(tools, "data", []):
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        pass
@@ -1,57 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.inference import (
    Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class SafetyRouter(Safety):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("SafetyRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("SafetyRouter.shutdown")
        pass

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any] = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        return await self.routing_table.get_provider_impl(shield_id).run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )
@@ -1,92 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,
)
from llama_stack.apis.tools import (
    ListToolsResponse,
    RAGDocument,
    RAGQueryConfig,
    RAGQueryResult,
    RAGToolRuntime,
    ToolRuntime,
)
from llama_stack.log import get_logger

from ..routing_tables.toolgroups import ToolGroupsRoutingTable

logger = get_logger(name=__name__, category="core")


class ToolRuntimeRouter(ToolRuntime):
    class RagToolImpl(RAGToolRuntime):
        def __init__(
            self,
            routing_table: ToolGroupsRoutingTable,
        ) -> None:
            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
            self.routing_table = routing_table

        async def query(
            self,
            content: InterleavedContent,
            vector_db_ids: list[str],
            query_config: RAGQueryConfig | None = None,
        ) -> RAGQueryResult:
            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
            return await self.routing_table.get_provider_impl("knowledge_search").query(
                content, vector_db_ids, query_config
            )

        async def insert(
            self,
            documents: list[RAGDocument],
            vector_db_id: str,
            chunk_size_in_tokens: int = 512,
        ) -> None:
            logger.debug(
                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                documents, vector_db_id, chunk_size_in_tokens
            )

    def __init__(
        self,
        routing_table: ToolGroupsRoutingTable,
    ) -> None:
        logger.debug("Initializing ToolRuntimeRouter")
        self.routing_table = routing_table

        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
        self.rag_tool = self.RagToolImpl(routing_table)
        for method in ("query", "insert"):
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

    async def initialize(self) -> None:
        logger.debug("ToolRuntimeRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolsResponse:
        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.list_tools(tool_group_id)
@@ -1,72 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class VectorIORouter(VectorIO):
    """Routes to an provider based on the vector db identifier"""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("VectorIORouter.shutdown")
        pass

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> None:
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
        )
        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,58 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import (
    BenchmarkWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
        if benchmark is None:
            raise ValueError(f"Benchmark '{benchmark_id}' not found")
        return benchmark

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        metadata: dict[str, Any] | None = None,
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = BenchmarkWithACL(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)
@@ -1,218 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
    AccessAttributes,
    RoutableObject,
    RoutableObjectWithProvider,
    RoutedProtocol,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable

logger = get_logger(name=__name__, category="core")


def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api


# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    api = get_impl_api(p)

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    if api == Api.inference:
        return await p.register_model(obj)
    elif api == Api.safety:
        return await p.register_shield(obj)
    elif api == Api.vector_io:
        return await p.register_vector_db(obj)
    elif api == Api.datasetio:
        return await p.register_dataset(obj)
    elif api == Api.scoring:
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_benchmark(obj)
    elif api == Api.tool_runtime:
        return await p.register_toolgroup(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")


async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    api = get_impl_api(p)
    if api == Api.vector_io:
        return await p.unregister_vector_db(obj.identifier)
    elif api == Api.inference:
        return await p.unregister_model(obj.identifier)
    elif api == Api.datasetio:
        return await p.unregister_dataset(obj.identifier)
    elif api == Api.tool_runtime:
        return await p.unregister_toolgroup(obj.identifier)
    else:
        raise ValueError(f"Unregister not supported for {api}")


Registry = dict[str, list[RoutableObjectWithProvider]]


class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        impls_by_provider_id: dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
    ) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.dist_registry = dist_registry

    async def initialize(self) -> None:
        async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
            for obj in objs:
                if cls is None:
                    obj.provider_id = provider_id
                else:
                    # Create a copy of the model data and explicitly set provider_id
                    model_data = obj.model_dump()
                    model_data["provider_id"] = provider_id
                    obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            if api == Api.inference:
                p.model_store = self
            elif api == Api.safety:
                p.shield_store = self
            elif api == Api.vector_io:
                p.vector_db_store = self
            elif api == Api.datasetio:
                p.dataset_store = self
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)
            elif api == Api.eval:
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        from .benchmarks import BenchmarksRoutingTable
        from .datasets import DatasetsRoutingTable
        from .models import ModelsRoutingTable
        from .scoring_functions import ScoringFunctionsRoutingTable
        from .shields import ShieldsRoutingTable
        from .toolgroups import ToolGroupsRoutingTable
        from .vector_dbs import VectorDBsRoutingTable

        def apiname_object():
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorDBsRoutingTable):
                return ("VectorIO", "vector_db")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, BenchmarksRoutingTable):
                return ("Eval", "benchmark")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("ToolGroups", "tool_group")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # Check if user has permission to access this object
        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        # If object supports access control but no attributes set, use creator's attributes
        if not obj.access_attributes:
            creator_attributes = get_auth_attributes()
            if creator_attributes:
                obj.access_attributes = AccessAttributes(**creator_attributes)
                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj

        else:
            await self.dist_registry.register(obj)
            return obj

    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
        objs = await self.dist_registry.get_all()
        filtered_objs = [obj for obj in objs if obj.type == type]

        # Apply attribute-based access control filtering
        if filtered_objs:
            filtered_objs = [
                obj
                for obj in filtered_objs
                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
            ]

        return filtered_objs
@@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import Any

from llama_stack.apis.datasets import (
    Dataset,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import (
    DatasetWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

    async def get_dataset(self, dataset_id: str) -> Dataset:
        dataset = await self.get_object_by_identifier("dataset", dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset '{dataset_id}' not found")
        return dataset

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        if isinstance(source, dict):
            if source["type"] == "uri":
                source = URIDataSource.parse_obj(source)
            elif source["type"] == "rows":
                source = RowsDataSource.parse_obj(source)

        if not dataset_id:
            dataset_id = f"dataset-{str(uuid.uuid4())}"

        provider_dataset_id = dataset_id

        # infer provider from source
        if metadata:
            if metadata.get("provider_id"):
                provider_id = metadata.get("provider_id")  # pass through from nvidia datasetio
        elif source.type == DatasetType.rows.value:
            provider_id = "localfs"
        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
            else:
                provider_id = "localfs"
        else:
            raise ValueError(f"Unknown data source type: {source.type}")

        if metadata is None:
            metadata = {}

        dataset = DatasetWithACL(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            purpose=purpose,
            source=source,
            metadata=metadata,
        )

        await self.register_object(dataset)
        return dataset

    async def unregister_dataset(self, dataset_id: str) -> None:
        dataset = await self.get_dataset(dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset {dataset_id} not found")
        await self.unregister_object(dataset)
@@ -1,82 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from typing import Any

from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
    ModelWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)

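The embedding-dimension check in `register_model` above means an embedding model must carry its dimension in metadata; a short hedged sketch of the two cases (the dimension value is illustrative):

```python
# Illustration of the metadata rule enforced by register_model above.
# The dimension value is a placeholder, not a value shipped with any provider.
def check_embedding_metadata(model_type: str, metadata: dict) -> None:
    if model_type == "embedding" and "embedding_dimension" not in metadata:
        raise ValueError("Embedding model must have an embedding dimension in its metadata")


check_embedding_metadata("embedding", {"embedding_dimension": 384})  # passes
try:
    check_embedding_metadata("embedding", {})
except ValueError as e:
    print(e)  # rejected: dimension missing
```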
@@ -1,62 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.distribution.datatypes import (
    ScoringFnWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithACL(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)

@@ -1,57 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import (
    ShieldWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> ListShieldsResponse:
        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

    async def get_shield(self, identifier: str) -> Shield:
        shield = await self.get_object_by_identifier("shield", identifier)
        if shield is None:
            raise ValueError(f"Shield '{identifier}' not found")
        return shield

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = ShieldWithACL(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield

@@ -1,132 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
    # handle the funny case like "builtin::rag/knowledge_search"
    parts = toolgroup_name_with_maybe_tool_name.split("/")
    if len(parts) == 2:
        return parts[0]
    else:
        return None


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    toolgroups_to_tools: dict[str, list[Tool]] = {}
    tool_to_toolgroup: dict[str, str] = {}

    # overridden
    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
        # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?

        toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
        if toolgroup_id:
            routing_key = toolgroup_id

        if routing_key in self.tool_to_toolgroup:
            routing_key = self.tool_to_toolgroup[routing_key]
        return super().get_provider_impl(routing_key, provider_id)

    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        if toolgroup_id:
            if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
                toolgroup_id = group_id
            toolgroups = [await self.get_tool_group(toolgroup_id)]
        else:
            toolgroups = await self.get_all_with_type("tool_group")

        all_tools = []
        for toolgroup in toolgroups:
            if toolgroup.identifier not in self.toolgroups_to_tools:
                await self._index_tools(toolgroup)
            all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])

        return ListToolsResponse(data=all_tools)

    async def _index_tools(self, toolgroup: ToolGroup):
        provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
        tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)

        # TODO: kill this Tool vs ToolDef distinction
        tooldefs = tooldefs_response.data
        tools = []
        for t in tooldefs:
            tools.append(
                Tool(
                    identifier=t.name,
                    toolgroup_id=toolgroup.identifier,
                    description=t.description or "",
                    parameters=t.parameters or [],
                    metadata=t.metadata,
                    provider_id=toolgroup.provider_id,
                )
            )

        self.toolgroups_to_tools[toolgroup.identifier] = tools
        for tool in tools:
            self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group '{toolgroup_id}' not found")
        return tool_group

    async def get_tool(self, tool_name: str) -> Tool:
        if tool_name in self.tool_to_toolgroup:
            toolgroup_id = self.tool_to_toolgroup[tool_name]
            tools = self.toolgroups_to_tools[toolgroup_id]
            for tool in tools:
                if tool.identifier == tool_name:
                    return tool
        raise ValueError(f"Tool '{tool_name}' not found")

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        toolgroup = ToolGroupWithACL(
            identifier=toolgroup_id,
            provider_id=provider_id,
            provider_resource_id=toolgroup_id,
            mcp_endpoint=mcp_endpoint,
            args=args,
        )
        await self.register_object(toolgroup)

        # ideally, indexing of the tools should not be necessary because anyone using
        # the tools should first list the tools and then use them. but there are assumptions
        # baked in some of the code and tests right now.
        if not toolgroup.mcp_endpoint:
            await self._index_tools(toolgroup)
        return toolgroup

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        pass

@@ -1,74 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import TypeAdapter

from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
    VectorDBWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
        }
        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        existing_vector_db = await self.get_vector_db(vector_db_id)
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)

@@ -8,8 +8,7 @@ import json
 
 import httpx
 
-from llama_stack.distribution.datatypes import AuthenticationConfig
-from llama_stack.distribution.server.auth_providers import create_auth_provider
+from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
 from llama_stack.log import get_logger
 
 logger = get_logger(name=__name__, category="auth")
 
@@ -78,7 +77,7 @@ class AuthenticationMiddleware:
     access resources that don't have access_attributes defined.
     """
 
-    def __init__(self, app, auth_config: AuthenticationConfig):
+    def __init__(self, app, auth_config: AuthProviderConfig):
         self.app = app
         self.auth_provider = create_auth_provider(auth_config)
 
@@ -94,7 +93,7 @@ class AuthenticationMiddleware:
 
             # Validate token and get access attributes
             try:
-                validation_result = await self.auth_provider.validate_token(token, scope)
+                access_attributes = await self.auth_provider.validate_token(token, scope)
             except httpx.TimeoutException:
                 logger.exception("Authentication request timed out")
                 return await self._send_auth_error(send, "Authentication service timeout")
@@ -106,24 +105,17 @@ class AuthenticationMiddleware:
                 return await self._send_auth_error(send, "Authentication service error")
 
             # Store attributes in request scope for access control
-            if validation_result.access_attributes:
-                user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
+            if access_attributes:
+                user_attributes = access_attributes.model_dump(exclude_none=True)
             else:
                 logger.warning("No access attributes, setting namespace to token by default")
                 user_attributes = {
-                    "roles": [token],
+                    "namespaces": [token],
                 }
 
-            # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
-            # can identify the requester and enforce per-client rate limits.
-            scope["authenticated_client_id"] = token
-
             # Store attributes in request scope
             scope["user_attributes"] = user_attributes
-            scope["principal"] = validation_result.principal
-            logger.debug(
-                f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
-            )
+            logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
 
             return await self.app(scope, receive, send)
 

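For context, a hedged sketch of the attribute payload the middleware above works with, and of the fallback it applies when a provider returns no attributes (all values are invented examples):

```python
# Hypothetical access-attribute payload, in the shape the middleware above stores
# into the request scope; every concrete value here is an invented example.
example_user_attributes = {
    "roles": ["alice"],
    "teams": ["ml-infra"],
    "namespaces": ["project-a"],
}

# Fallback when no attributes are returned: the side being removed in this diff
# buckets the raw token under "roles", the incoming side under "namespaces".
token = "abc123"  # placeholder bearer token
fallback_attributes = {"namespaces": [token]}
```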
@@ -4,29 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import ssl
-import time
+import json
 from abc import ABC, abstractmethod
-from asyncio import Lock
-from pathlib import Path
+from enum import Enum
 from urllib.parse import parse_qs
 
 import httpx
-from jose import jwt
-from pydantic import BaseModel, Field, field_validator, model_validator
-from typing_extensions import Self
+from pydantic import BaseModel, Field
 
-from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
+from llama_stack.distribution.datatypes import AccessAttributes
 from llama_stack.log import get_logger
 
 logger = get_logger(name=__name__, category="auth")
 
 
-class TokenValidationResult(BaseModel):
-    principal: str | None = Field(
-        default=None,
-        description="The principal (username or persistent identifier) of the authenticated user",
-    )
+class AuthResponse(BaseModel):
+    """The format of the authentication response from the auth endpoint."""
+
     access_attributes: AccessAttributes | None = Field(
         default=None,
         description="""
@@ -49,10 +43,6 @@ class TokenValidationResult(BaseModel):
         """,
     )
 
-
-class AuthResponse(TokenValidationResult):
-    """The format of the authentication response from the auth endpoint."""
-
     message: str | None = Field(
         default=None, description="Optional message providing additional context about the authentication result."
     )
@@ -74,11 +64,25 @@ class AuthRequest(BaseModel):
     request: AuthRequestContext = Field(description="Context information about the request being authenticated")
 
 
+class AuthProviderType(str, Enum):
+    """Supported authentication provider types."""
+
+    KUBERNETES = "kubernetes"
+    CUSTOM = "custom"
+
+
+class AuthProviderConfig(BaseModel):
+    """Base configuration for authentication providers."""
+
+    provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
+    config: dict[str, str] = Field(..., description="Provider-specific configuration")
+
+
 class AuthProvider(ABC):
     """Abstract base class for authentication providers."""
 
     @abstractmethod
-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
         """Validate a token and return access attributes."""
         pass
 
@@ -88,219 +92,88 @@ class AuthProvider(ABC):
         pass
 
 
-def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
-    attributes = AccessAttributes()
-    for claim_key, attribute_key in mapping.items():
-        if claim_key not in claims or not hasattr(attributes, attribute_key):
-            continue
-        claim = claims[claim_key]
-        if isinstance(claim, list):
-            values = claim
-        else:
-            values = claim.split()
-
-        current = getattr(attributes, attribute_key)
-        if current:
-            current.extend(values)
-        else:
-            setattr(attributes, attribute_key, values)
-    return attributes
-
-
-class OAuth2JWKSConfig(BaseModel):
-    # The JWKS URI for collecting public keys
-    uri: str
-    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
-
-
-class OAuth2IntrospectionConfig(BaseModel):
-    url: str
-    client_id: str
-    client_secret: str
-    send_secret_in_body: bool = False
-
-
-class OAuth2TokenAuthProviderConfig(BaseModel):
-    audience: str = "llama-stack"
-    verify_tls: bool = True
-    tls_cafile: Path | None = None
-    issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
-    claims_mapping: dict[str, str] = Field(
-        default_factory=lambda: {
-            "sub": "roles",
-            "username": "roles",
-            "groups": "teams",
-            "team": "teams",
-            "project": "projects",
-            "tenant": "namespaces",
-            "namespace": "namespaces",
-        },
-    )
-    jwks: OAuth2JWKSConfig | None
-    introspection: OAuth2IntrospectionConfig | None = None
-
-    @classmethod
-    @field_validator("claims_mapping")
-    def validate_claims_mapping(cls, v):
-        for key, value in v.items():
-            if not value:
-                raise ValueError(f"claims_mapping value cannot be empty: {key}")
-            if value not in AccessAttributes.model_fields:
-                raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
-        return v
-
-    @model_validator(mode="after")
-    def validate_mode(self) -> Self:
-        if not self.jwks and not self.introspection:
-            raise ValueError("One of jwks or introspection must be configured")
-        if self.jwks and self.introspection:
-            raise ValueError("At present only one of jwks or introspection should be configured")
-        return self
-
-
-class OAuth2TokenAuthProvider(AuthProvider):
-    """
-    JWT token authentication provider that validates a JWT token and extracts access attributes.
-
-    This should be the standard authentication provider for most use cases.
-    """
-
-    def __init__(self, config: OAuth2TokenAuthProviderConfig):
-        self.config = config
-        self._jwks_at: float = 0.0
-        self._jwks: dict[str, str] = {}
-        self._jwks_lock = Lock()
-
-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
-        if self.config.jwks:
-            return await self.validate_jwt_token(token, scope)
-        if self.config.introspection:
-            return await self.introspect_token(token, scope)
-        raise ValueError("One of jwks or introspection must be configured")
-
-    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
-        """Validate a token using the JWT token."""
-        await self._refresh_jwks()
-
+class KubernetesAuthProvider(AuthProvider):
+    """Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
+
+    def __init__(self, config: dict[str, str]):
+        self.api_server_url = config["api_server_url"]
+        self.ca_cert_path = config.get("ca_cert_path")
+        self._client = None
+
+    async def _get_client(self):
+        """Get or create a Kubernetes client."""
+        if self._client is None:
+            # kubernetes-client has not async support, see:
+            # https://github.com/kubernetes-client/python/issues/323
+            from kubernetes import client
+            from kubernetes.client import ApiClient
+
+            # Configure the client
+            configuration = client.Configuration()
+            configuration.host = self.api_server_url
+            if self.ca_cert_path:
+                configuration.ssl_ca_cert = self.ca_cert_path
+            configuration.verify_ssl = bool(self.ca_cert_path)
+
+            # Create API client
+            self._client = ApiClient(configuration)
+        return self._client
+
+    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
+        """Validate a Kubernetes token and return access attributes."""
         try:
-            header = jwt.get_unverified_header(token)
-            kid = header["kid"]
-            if kid not in self._jwks:
-                raise ValueError(f"Unknown key ID: {kid}")
-            key_data = self._jwks[kid]
-            algorithm = header.get("alg", "RS256")
-            claims = jwt.decode(
-                token,
-                key_data,
-                algorithms=[algorithm],
-                audience=self.config.audience,
-                issuer=self.config.issuer,
+            client = await self._get_client()
+
+            # Set the token in the client
+            client.set_default_header("Authorization", f"Bearer {token}")
+
+            # Make a request to validate the token
+            # We use the /api endpoint which requires authentication
+            from kubernetes.client import CoreV1Api
+
+            api = CoreV1Api(client)
+            api.get_api_resources(_request_timeout=3.0)  # Set timeout for this specific request
+
+            # If we get here, the token is valid
+            # Extract user info from the token claims
+            import base64
+
+            # Decode the token (without verification since we've already validated it)
+            token_parts = token.split(".")
+            payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
+
+            # Extract user information from the token
+            username = payload.get("sub", "")
+            groups = payload.get("groups", [])
+
+            return AccessAttributes(
+                roles=[username],  # Use username as a role
+                teams=groups,  # Use Kubernetes groups as teams
             )
-        except Exception as exc:
-            raise ValueError(f"Invalid JWT token: {token}") from exc
-
-        # There are other standard claims, the most relevant of which is `scope`.
-        # We should incorporate these into the access attributes.
-        principal = claims["sub"]
-        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
-        return TokenValidationResult(
-            principal=principal,
-            access_attributes=access_attributes,
-        )
-
-    async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
-        """Validate a token using token introspection as defined by RFC 7662."""
-        form = {
-            "token": token,
-        }
-        if self.config.introspection is None:
-            raise ValueError("Introspection is not configured")
-
-        if self.config.introspection.send_secret_in_body:
-            form["client_id"] = self.config.introspection.client_id
-            form["client_secret"] = self.config.introspection.client_secret
-            auth = None
-        else:
-            auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
-        ssl_ctxt = None
-        if self.config.tls_cafile:
-            ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
-        try:
-            async with httpx.AsyncClient(verify=ssl_ctxt) as client:
-                response = await client.post(
-                    self.config.introspection.url,
-                    data=form,
-                    auth=auth,
-                    timeout=10.0,  # Add a reasonable timeout
-                )
-                if response.status_code != 200:
-                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
-                    raise ValueError(f"Token introspection failed: {response.status_code}")
-
-                fields = response.json()
-                if not fields["active"]:
-                    raise ValueError("Token not active")
-                principal = fields["sub"] or fields["username"]
-                access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
-                return TokenValidationResult(
-                    principal=principal,
-                    access_attributes=access_attributes,
-                )
-        except httpx.TimeoutException:
-            logger.exception("Token introspection request timed out")
-            raise
-        except ValueError:
-            # Re-raise ValueError exceptions to preserve their message
-            raise
         except Exception as e:
-            logger.exception("Error during token introspection")
-            raise ValueError("Token introspection error") from e
+            logger.exception("Failed to validate Kubernetes token")
+            raise ValueError("Invalid or expired token") from e
 
     async def close(self):
-        pass
-
-    async def _refresh_jwks(self) -> None:
-        """
-        Refresh the JWKS cache.
-
-        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
-        If the cache is expired, we refresh the JWKS from the JWKS URI.
-
-        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
-        * It doesn't have user authentication flows
-        * It doesn't have refresh tokens
-        """
-        async with self._jwks_lock:
-            if self.config.jwks is None:
-                raise ValueError("JWKS is not configured")
-            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
-                verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
-                async with httpx.AsyncClient(verify=verify) as client:
-                    res = await client.get(self.config.jwks.uri, timeout=5)
-                    res.raise_for_status()
-                    jwks_data = res.json()["keys"]
-                    updated = {}
-                    for k in jwks_data:
-                        kid = k["kid"]
-                        # Store the entire key object as it may be needed for different algorithms
-                        updated[kid] = k
-                    self._jwks = updated
-                    self._jwks_at = time.time()
-
-
-class CustomAuthProviderConfig(BaseModel):
-    endpoint: str
+        """Close the HTTP client."""
+        if self._client:
+            self._client.close()
+            self._client = None
 
 
 class CustomAuthProvider(AuthProvider):
     """Custom authentication provider that uses an external endpoint."""
 
-    def __init__(self, config: CustomAuthProviderConfig):
-        self.config = config
+    def __init__(self, config: dict[str, str]):
+        self.endpoint = config["endpoint"]
         self._client = None
 
-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
         """Validate a token using the custom authentication endpoint."""
+        if not self.endpoint:
+            raise ValueError("Authentication endpoint not configured")
+
         if scope is None:
             scope = {}
 
@@ -329,7 +202,7 @@ class CustomAuthProvider(AuthProvider):
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    self.config.endpoint,
+                    self.endpoint,
                     json=auth_request.model_dump(),
                     timeout=10.0,  # Add a reasonable timeout
                 )
@@ -341,7 +214,19 @@ class CustomAuthProvider(AuthProvider):
                 try:
                     response_data = response.json()
                     auth_response = AuthResponse(**response_data)
-                    return auth_response
+
+                    # Store attributes in request scope for access control
+                    if auth_response.access_attributes:
+                        return auth_response.access_attributes
+                    else:
+                        logger.warning("No access attributes, setting namespace to api_key by default")
+                        user_attributes = {
+                            "namespaces": [token],
+                        }
+
+                    scope["user_attributes"] = user_attributes
+                    logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
+                    return auth_response.access_attributes
                 except Exception as e:
                     logger.exception("Error parsing authentication response")
                     raise ValueError("Invalid authentication response format") from e
@@ -363,14 +248,14 @@ class CustomAuthProvider(AuthProvider):
         self._client = None
 
 
-def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
+def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
     """Factory function to create the appropriate auth provider."""
     provider_type = config.provider_type.lower()
 
-    if provider_type == "custom":
-        return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
-    elif provider_type == "oauth2_token":
-        return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
+    if provider_type == "kubernetes":
+        return KubernetesAuthProvider(config.config)
+    elif provider_type == "custom":
+        return CustomAuthProvider(config.config)
     else:
        supported_providers = ", ".join([t.value for t in AuthProviderType])
        raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")

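A small, hedged illustration of the claims-mapping logic on the removed side of this hunk: JWT claims are folded into access-attribute buckets according to a claim-to-attribute mapping. The claim values are invented, and the real helper works on the `AccessAttributes` model rather than a plain dict:

```python
# Simplified stand-in for get_attributes_from_claims on the removed side:
# claims are folded into attribute buckets according to a claim -> attribute mapping.
def claims_to_attributes(claims: dict, mapping: dict[str, str]) -> dict[str, list[str]]:
    attributes: dict[str, list[str]] = {}
    for claim_key, attribute_key in mapping.items():
        if claim_key not in claims:
            continue
        claim = claims[claim_key]
        values = claim if isinstance(claim, list) else claim.split()
        attributes.setdefault(attribute_key, []).extend(values)
    return attributes


mapping = {"sub": "roles", "groups": "teams", "namespace": "namespaces"}
claims = {"sub": "alice", "groups": ["ml-infra", "search"], "namespace": "project-a"}
print(claims_to_attributes(claims, mapping))
# {'roles': ['alice'], 'teams': ['ml-infra', 'search'], 'namespaces': ['project-a']}
```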
@@ -6,23 +6,20 @@
 
 import inspect
 import re
-from collections.abc import Callable
-from typing import Any
 
-from aiohttp import hdrs
-from starlette.routing import Route
+from pydantic import BaseModel
 
 from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
 from llama_stack.distribution.resolver import api_protocol_map
 from llama_stack.providers.datatypes import Api
 
-EndpointFunc = Callable[..., Any]
-PathParams = dict[str, str]
-RouteInfo = tuple[EndpointFunc, str]
-PathImpl = dict[str, RouteInfo]
-RouteImpls = dict[str, PathImpl]
-RouteMatch = tuple[EndpointFunc, PathParams, str]
+
+class ApiEndpoint(BaseModel):
+    route: str
+    method: str
+    name: str
+    descriptive_name: str | None = None
 
 
 def toolgroup_protocol_map():
@@ -31,13 +28,13 @@ def toolgroup_protocol_map():
     }
 
 
-def get_all_api_routes() -> dict[Api, list[Route]]:
+def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
     apis = {}
 
     protocols = api_protocol_map()
     toolgroup_protocols = toolgroup_protocol_map()
     for api, protocol in protocols.items():
-        routes = []
+        endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
 
         # HACK ALERT
@@ -54,28 +51,26 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
             if not hasattr(method, "__webmethod__"):
                 continue
 
-            # The __webmethod__ attribute is dynamically added by the @webmethod decorator
-            # mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
-            webmethod = method.__webmethod__  # type: ignore[attr-defined]
-            path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
-            if webmethod.method == hdrs.METH_GET:
-                http_method = hdrs.METH_GET
-            elif webmethod.method == hdrs.METH_DELETE:
-                http_method = hdrs.METH_DELETE
+            webmethod = method.__webmethod__
+            route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
+            if webmethod.method == "GET":
+                method = "get"
+            elif webmethod.method == "DELETE":
+                method = "delete"
             else:
-                http_method = hdrs.METH_POST
-            routes.append(
-                Route(path=path, methods=[http_method], name=name, endpoint=None)
-            )  # setting endpoint to None since don't use a Router object
+                method = "post"
+            endpoints.append(
+                ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
+            )
 
-        apis[api] = routes
+        apis[api] = endpoints
 
     return apis
 
 
-def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
-    routes = get_all_api_routes()
-    route_impls: RouteImpls = {}
+def initialize_endpoint_impls(impls):
+    endpoints = get_all_api_endpoints()
+    endpoint_impls = {}
 
     def _convert_path_to_regex(path: str) -> str:
         # Convert {param} to named capture groups
@@ -88,34 +83,29 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
 
         return f"^{pattern}$"
 
-    for api, api_routes in routes.items():
+    for api, api_endpoints in endpoints.items():
         if api not in impls:
             continue
-        for route in api_routes:
+        for endpoint in api_endpoints:
             impl = impls[api]
-            func = getattr(impl, route.name)
-            # Get the first (and typically only) method from the set, filtering out HEAD
-            available_methods = [m for m in route.methods if m != "HEAD"]
-            if not available_methods:
-                continue  # Skip if only HEAD method is available
-            method = available_methods[0].lower()
-            if method not in route_impls:
-                route_impls[method] = {}
-            route_impls[method][_convert_path_to_regex(route.path)] = (
+            func = getattr(impl, endpoint.name)
+            if endpoint.method not in endpoint_impls:
+                endpoint_impls[endpoint.method] = {}
+            endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
                 func,
-                route.path,
+                endpoint.descriptive_name or endpoint.route,
             )
 
-    return route_impls
+    return endpoint_impls
 
 
-def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
+def find_matching_endpoint(method, path, endpoint_impls):
     """Find the matching endpoint implementation for a given method and path.
 
     Args:
         method: HTTP method (GET, POST, etc.)
         path: URL path to match against
-        route_impls: A dictionary of endpoint implementations
+        endpoint_impls: A dictionary of endpoint implementations
 
     Returns:
         A tuple of (endpoint_function, path_params, descriptive_name)
@@ -123,7 +113,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
     Raises:
         ValueError: If no matching endpoint is found
     """
-    impls = route_impls.get(method.lower())
+    impls = endpoint_impls.get(method.lower())
     if not impls:
         raise ValueError(f"No endpoint found for {path}")
 

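A hedged sketch of the path-template matching these helpers rely on. The exact pattern produced by `_convert_path_to_regex` is not visible in this hunk, so the regex below is only illustrative:

```python
import re


# Illustrative conversion of "{param}" templates into named capture groups;
# the real _convert_path_to_regex may differ in detail.
def convert_path_to_regex(path: str) -> str:
    pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
    return f"^{pattern}$"


pattern = convert_path_to_regex("/v1/models/{model_id}")
match = re.match(pattern, "/v1/models/llama-3")
print(match.groupdict())  # {'model_id': 'llama-3'}
```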
@@ -1,110 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import time
from datetime import datetime, timedelta, timezone

from starlette.types import ASGIApp, Receive, Scope, Send

from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl

logger = get_logger(name=__name__, category="quota")


class QuotaMiddleware:
    """
    ASGI middleware that enforces separate quotas for authenticated and anonymous clients
    within a configurable time window.

    - For authenticated requests, it reads the client ID from the
      `Authorization: Bearer <client_id>` header.
    - For anonymous requests, it falls back to the IP address of the client.
    Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
    once a client exceeds its quota.
    """

    def __init__(
        self,
        app: ASGIApp,
        kv_config: KVStoreConfig,
        anonymous_max_requests: int,
        authenticated_max_requests: int,
        window_seconds: int = 86400,
    ):
        self.app = app
        self.kv_config = kv_config
        self.kv: KVStore | None = None
        self.anonymous_max_requests = anonymous_max_requests
        self.authenticated_max_requests = authenticated_max_requests
        self.window_seconds = window_seconds

        if isinstance(self.kv_config, SqliteKVStoreConfig):
            logger.warning(
                "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
                f"window_seconds={self.window_seconds}"
            )

    async def _get_kv(self) -> KVStore:
        if self.kv is None:
            self.kv = await kvstore_impl(self.kv_config)
        return self.kv

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] == "http":
            # pick key & limit based on auth
            auth_id = scope.get("authenticated_client_id")
            if auth_id:
                key_id = auth_id
                limit = self.authenticated_max_requests
            else:
                # fallback to IP
                client = scope.get("client")
                key_id = client[0] if client else "anonymous"
                limit = self.anonymous_max_requests

            current_window = int(time.time() // self.window_seconds)
            key = f"quota:{key_id}:{current_window}"

            try:
                kv = await self._get_kv()
                prev = await kv.get(key) or "0"
                count = int(prev) + 1

                if int(prev) == 0:
                    # Set with expiration datetime when it is the first request in the window.
                    expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
                    await kv.set(key, str(count), expiration=expiration)
                else:
                    await kv.set(key, str(count))
            except Exception:
                logger.exception("Failed to access KV store for quota")
                return await self._send_error(send, 500, "Quota service error")

            if count > limit:
                logger.warning(
                    "Quota exceeded for client %s: %d/%d",
                    key_id,
                    count,
                    limit,
                )
                return await self._send_error(send, 429, "Quota exceeded")

        return await self.app(scope, receive, send)

    async def _send_error(self, send: Send, status: int, message: str):
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        body = json.dumps({"error": {"message": message}}).encode()
        await send({"type": "http.response.body", "body": body})

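The quota key above buckets requests into fixed windows; a quick hedged illustration of the arithmetic (the client id is a placeholder):

```python
import time

# Fixed-window bucketing used by QuotaMiddleware above: all requests inside the
# same window share one counter key, and a new window starts a fresh count.
window_seconds = 86400  # one day, the default above
now = int(time.time())

current_window = now // window_seconds
key = f"quota:client-123:{current_window}"  # "client-123" is a placeholder client id
print(key)

# A request arriving window_seconds later lands in the next bucket:
print(f"quota:client-123:{(now + window_seconds) // window_seconds}")
```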
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
@ -14,7 +13,6 @@ import ssl
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Callable
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from importlib.metadata import version as parse_version
|
from importlib.metadata import version as parse_version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -22,26 +20,23 @@ from typing import Annotated, Any
|
||||||
|
|
||||||
import rich.pretty
|
import rich.pretty
|
||||||
import yaml
|
import yaml
|
||||||
from aiohttp import hdrs
|
|
||||||
from fastapi import Body, FastAPI, HTTPException, Request
|
from fastapi import Body, FastAPI, HTTPException, Request
|
||||||
from fastapi import Path as FastapiPath
|
from fastapi import Path as FastapiPath
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from openai import BadRequestError
|
from openai import BadRequestError
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
 from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import InvalidProviderError
-from llama_stack.distribution.server.routes import (
-    find_matching_route,
-    get_all_api_routes,
-    initialize_route_impls,
+from llama_stack.distribution.server.endpoints import (
+    find_matching_endpoint,
+    initialize_endpoint_impls,
 )
 from llama_stack.distribution.stack import (
     construct_stack,
@@ -64,7 +59,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )

 from .auth import AuthenticationMiddleware
-from .quota import QuotaMiddleware
+from .endpoints import get_all_api_endpoints

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

@@ -125,8 +120,6 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
         return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
         return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
-    elif isinstance(exc, AuthenticationRequiredError):
-        return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
     else:
         return HTTPException(
             status_code=500,
@@ -212,9 +205,8 @@ async def log_request_pre_validation(request: Request):
         logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")


-def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
-    @functools.wraps(func)
-    async def route_handler(request: Request, **kwargs):
+def create_dynamic_typed_route(func: Any, method: str, route: str):
+    async def endpoint(request: Request, **kwargs):
         # Get auth attributes from the request scope
         user_attributes = request.scope.get("user_attributes", {})

@@ -254,9 +246,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
         for param in new_params[1:]
     ]

-    route_handler.__signature__ = sig.replace(parameters=new_params)
+    endpoint.__signature__ = sig.replace(parameters=new_params)

-    return route_handler
+    return endpoint


 class TracingMiddleware:
@@ -278,28 +270,17 @@ class TracingMiddleware:
             logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
             return await self.app(scope, receive, send)

-        if not hasattr(self, "route_impls"):
-            self.route_impls = initialize_route_impls(self.impls)
+        if not hasattr(self, "endpoint_impls"):
+            self.endpoint_impls = initialize_endpoint_impls(self.impls)

         try:
-            _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
+            _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
         except ValueError:
             # If no matching endpoint is found, pass through to FastAPI
-            logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
+            logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
             return await self.app(scope, receive, send)

-        trace_attributes = {"__location__": "server", "raw_path": path}
-
-        # Extract W3C trace context headers and store as trace attributes
-        headers = dict(scope.get("headers", []))
-        traceparent = headers.get(b"traceparent", b"").decode()
-        if traceparent:
-            trace_attributes["traceparent"] = traceparent
-        tracestate = headers.get(b"tracestate", b"").decode()
-        if tracestate:
-            trace_attributes["tracestate"] = tracestate
-
-        trace_context = await start_trace(trace_path, trace_attributes)
+        trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})

         async def send_with_trace_id(message):
             if message["type"] == "http.response.start":
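An aside on the block removed in the hunk above: those lines read the W3C Trace Context headers straight off the raw ASGI scope, where header names and values arrive as bytes, and attach them to the trace attributes. A minimal, self-contained sketch of that extraction follows; the traceparent/tracestate values and the request path are illustrative examples, not taken from the diff.

```python
# Sketch of the W3C trace-context extraction that the removed lines performed.
# The header values and the path below are illustrative examples.
scope = {
    "headers": [
        (b"traceparent", b"00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"),
        (b"tracestate", b"vendor=opaque-value"),
    ]
}

headers = dict(scope.get("headers", []))
trace_attributes = {"__location__": "server", "raw_path": "/v1/models"}

traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
    trace_attributes["traceparent"] = traceparent  # version-traceid-parentid-flags

tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
    trace_attributes["tracestate"] = tracestate

print(trace_attributes)
```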
@@ -389,6 +370,14 @@ def main(args: argparse.Namespace | None = None):
     if args is None:
         args = parser.parse_args()

+    # Check for deprecated argument usage
+    if "--yaml-config" in sys.argv:
+        warnings.warn(
+            "The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     log_line = ""
     if args.config:
         # if the user provided a config file, use it, even if template was specified
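A quick illustration of the deprecation pattern introduced in the hunk above: the check inspects the argument vector directly, so the warning fires regardless of how argparse maps the two spellings. The helper function and the sample argument list below are hypothetical; only the warning call mirrors the diff.

```python
import warnings

def check_deprecated_flags(argv: list[str]) -> None:
    # Mirrors the added check: warn when the old flag spelling is present.
    if "--yaml-config" in argv:
        warnings.warn(
            "The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
            DeprecationWarning,
            stacklevel=2,
        )

warnings.simplefilter("default")  # make DeprecationWarning visible for this demo
check_deprecated_flags(["llama", "stack", "run", "--yaml-config", "run.yaml"])
```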
@@ -402,7 +391,7 @@ def main(args: argparse.Namespace | None = None):
             raise ValueError(f"Template {args.template} does not exist")
         log_line = f"Using template {args.template} config file: {config_file}"
     else:
-        raise ValueError("Either --config or --template must be provided")
+        raise ValueError("Either --yaml-config or --template must be provided")

     logger_config = None
     with open(config_file) as fp:
@@ -442,46 +431,6 @@ def main(args: argparse.Namespace | None = None):
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
         app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
-    else:
-        if config.server.quota:
-            quota = config.server.quota
-            logger.warning(
-                "Configured authenticated_max_requests (%d) but no auth is enabled; "
-                "falling back to anonymous_max_requests (%d) for all the requests",
-                quota.authenticated_max_requests,
-                quota.anonymous_max_requests,
-            )
-
-    if config.server.quota:
-        logger.info("Enabling quota middleware for authenticated and anonymous clients")
-
-        quota = config.server.quota
-        anonymous_max_requests = quota.anonymous_max_requests
-        # if auth is disabled, use the anonymous max requests
-        authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
-
-        kv_config = quota.kvstore
-        window_map = {"day": 86400}
-        window_seconds = window_map[quota.period.value]
-
-        app.add_middleware(
-            QuotaMiddleware,
-            kv_config=kv_config,
-            anonymous_max_requests=anonymous_max_requests,
-            authenticated_max_requests=authenticated_max_requests,
-            window_seconds=window_seconds,
-        )
-
-    # --- CORS middleware for local development ---
-    # TODO: move to reverse proxy
-    ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322)
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=[f"http://localhost:{ui_port}"],
-        allow_credentials=True,
-        allow_methods=["*"],
-        allow_headers=["*"],
-    )

     try:
         impls = asyncio.run(construct_stack(config))
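For readers skimming the removed quota wiring above: the window map expresses the quota period in seconds, and when auth is disabled the authenticated limit falls back to the anonymous one. A small sketch of that arithmetic follows; the request limits are made-up example values, not taken from any configuration.

```python
# Made-up example limits; only the period arithmetic and the fallback rule
# mirror the removed wiring.
window_map = {"day": 86400}          # 24 h * 60 min * 60 s = 86,400 seconds
window_seconds = window_map["day"]

auth_enabled = False
anonymous_max_requests = 100
authenticated_max_requests = 500 if auth_enabled else anonymous_max_requests

print(window_seconds)                # 86400
print(authenticated_max_requests)    # 100, same bucket as anonymous clients
```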
@@ -494,7 +443,7 @@ def main(args: argparse.Namespace | None = None):
     else:
         setup_logger(TelemetryAdapter(TelemetryConfig(), {}))

-    all_routes = get_all_api_routes()
+    all_endpoints = get_all_api_endpoints()

     if config.apis:
         apis_to_serve = set(config.apis)
@@ -512,29 +461,24 @@ def main(args: argparse.Namespace | None = None):
     for api_str in apis_to_serve:
         api = Api(api_str)

-        routes = all_routes[api]
+        endpoints = all_endpoints[api]
         impl = impls[api]

-        for route in routes:
-            if not hasattr(impl, route.name):
+        for endpoint in endpoints:
+            if not hasattr(impl, endpoint.name):
                 # ideally this should be a typing violation already
-                raise ValueError(f"Could not find method {route.name} on {impl}!")
+                raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")

-            impl_method = getattr(impl, route.name)
-            # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
-            available_methods = [m for m in route.methods if m != "HEAD"]
-            if not available_methods:
-                raise ValueError(f"No methods found for {route.name} on {impl}")
-            method = available_methods[0]
-            logger.debug(f"{method} {route.path}")
+            impl_method = getattr(impl, endpoint.name)
+            logger.debug(f"{endpoint.method.upper()} {endpoint.route}")

             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
-                getattr(app, method.lower())(route.path, response_model=None)(
+                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                     create_dynamic_typed_route(
                         impl_method,
-                        method.lower(),
-                        route.path,
+                        endpoint.method,
+                        endpoint.route,
                     )
                 )

@@ -54,7 +54,7 @@ other_args=""
 # Process remaining arguments
 while [[ $# -gt 0 ]]; do
     case "$1" in
-        --config)
+        --config|--yaml-config)
         if [[ -n "$2" ]]; then
             yaml_config="$2"
             shift 2
@@ -121,7 +121,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
     set -x

     if [ -n "$yaml_config" ]; then
-        yaml_config_arg="--config $yaml_config"
+        yaml_config_arg="--yaml-config $yaml_config"
     else
         yaml_config_arg=""
     fi
@@ -181,9 +181,9 @@ elif [[ "$env_type" == "container" ]]; then

     # Add yaml config if provided, otherwise use default
     if [ -n "$yaml_config" ]; then
-        cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml"
+        cmd="$cmd -v $yaml_config:/app/run.yaml --yaml-config /app/run.yaml"
     else
-        cmd="$cmd --config /app/run.yaml"
+        cmd="$cmd --yaml-config /app/run.yaml"
     fi

     # Add any other args

@@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v9"
+KEY_VERSION = "v8"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

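A short note on what the KEY_VERSION change above means in practice: every registry record is stored under a key built from KEY_FORMAT, so entries written under one version are simply not found once the code reads another. A minimal sketch follows; the "model" / "my-model" values are hypothetical examples, not entries from any real registry.

```python
# Reproduces KEY_FORMAT from the hunk above; the type/identifier values are made up.
REGISTER_PREFIX = "distributions:registry"

def registry_key(key_version: str, obj_type: str, identifier: str) -> str:
    key_format = f"{REGISTER_PREFIX}:{key_version}::" + "{type}:{identifier}"
    return key_format.format(type=obj_type, identifier=identifier)

print(registry_key("v9", "model", "my-model"))  # distributions:registry:v9::model:my-model
print(registry_key("v8", "model", "my-model"))  # distributions:registry:v8::model:my-model
```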
@@ -5,8 +5,7 @@ FROM python:3.12-slim
 WORKDIR /app
 COPY . /app/
 RUN /usr/local/bin/python -m pip install --upgrade pip && \
-    /usr/local/bin/pip3 install -r requirements.txt && \
-    /usr/local/bin/pip3 install -r llama_stack/distribution/ui/requirements.txt
+    /usr/local/bin/pip3 install -r requirements.txt
 EXPOSE 8501

-ENTRYPOINT ["streamlit", "run", "llama_stack/distribution/ui/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]

@@ -48,6 +48,3 @@ uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
 | TOGETHER_API_KEY | API key for Together provider | (empty string) |
 | SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
 | OPENAI_API_KEY | API key for OpenAI provider | (empty string) |
-| KEYCLOAK_URL | URL for keycloak authentication | (empty string) |
-| KEYCLOAK_REALM | Keycloak realm | default |
-| KEYCLOAK_CLIENT_ID | Client ID for keycloak auth | (empty string) |

@@ -50,42 +50,6 @@ def main():
     )
     pg.run()

-def main2():
-    from dataclasses import asdict
-    st.subheader(f"Welcome {keycloak.user_info['preferred_username']}!")
-    st.write(f"Here is your user information:")
-    st.write(asdict(keycloak))
-
-def get_access_token() -> str|None:
-    return st.session_state.get('access_token')
-
 if __name__ == "__main__":
-    from streamlit_keycloak import login
-    import os
-
-    keycloak_url = os.environ.get("KEYCLOAK_URL")
-    keycloak_realm = os.environ.get("KEYCLOAK_REALM", "default")
-    keycloak_client_id = os.environ.get("KEYCLOAK_CLIENT_ID")
-
-    if keycloak_url and keycloak_client_id:
-        keycloak = login(
-            url=keycloak_url,
-            realm=keycloak_realm,
-            client_id=keycloak_client_id,
-            custom_labels={
-                "labelButton": "Sign in to kvant",
-                "labelLogin": "Please sign in to your kvant account.",
-                "errorNoPopup": "Unable to open the authentication popup. Allow popups and refresh the page to proceed.",
-                "errorPopupClosed": "Authentication popup was closed manually.",
-                "errorFatal": "Unable to connect to Keycloak using the current configuration."
-            },
-            auto_refresh=True,
-        )
-
-        if keycloak.authenticated:
-            st.session_state['access_token'] = keycloak.access_token
-            main()
-            # TBD - add other authentications
-        else:
-            main()
+    main()

@@ -7,13 +7,11 @@
 import os

 from llama_stack_client import LlamaStackClient
-from llama_stack.distribution.ui.app import get_access_token


 class LlamaStackApi:
     def __init__(self):
         self.client = LlamaStackClient(
-            api_key=get_access_token(),
             base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
             provider_data={
                 "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
@@ -30,3 +28,5 @@ class LlamaStackApi:
         scoring_params = {fn_id: None for fn_id in scoring_function_ids}
         return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)

+
+llama_stack_api = LlamaStackApi()
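For context, a hedged sketch of how the module-level singleton added at the bottom of this file is typically consumed from a UI page; the page code and the choice of listing models are illustrative rather than part of the diff.

```python
# Hypothetical Streamlit page using the module-level client wrapper.
from llama_stack.distribution.ui.modules.api import llama_stack_api

models = llama_stack_api.client.models.list()  # talks to LLAMA_STACK_ENDPOINT
print([m.identifier for m in models])
```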
Some files were not shown because too many files have changed in this diff.