forked from phoenix-oss/llama-stack-mirror

Compare commits (1 commit): kvant ... create-pul

Commit: 075c5401f5

351 changed files with 9,284 additions and 25,653 deletions
.github/PULL_REQUEST_TEMPLATE.md (vendored): 10 changes

@@ -1,8 +1,10 @@
 # What does this PR do?
-<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->
+[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]

-<!-- If resolving an issue, uncomment and update the line below -->
+[//]: # (If resolving an issue, uncomment and update the line below)
-<!-- Closes #[issue-number] -->
+[//]: # (Closes #[issue-number])

 ## Test Plan
-<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->
+[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]

+[//]: # (## Documentation)
.github/actions/setup-runner/action.yml (vendored): 22 deletions

@@ -1,22 +0,0 @@
-name: Setup runner
-description: Prepare a runner for the tests (install uv, python, project dependencies, etc.)
-runs:
-  using: "composite"
-  steps:
-    - name: Install uv
-      uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
-      with:
-        python-version: "3.10"
-        activate-environment: true
-        version: 0.7.6
-
-    - name: Install dependencies
-      shell: bash
-      run: |
-        uv sync --all-groups
-        uv pip install ollama faiss-cpu
-        # always test against the latest version of the client
-        # TODO: this is not necessarily a good idea. we need to test against both published and latest
-        # to find out backwards compatibility issues.
-        uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
-        uv pip install -e .
.github/workflows/Dockerfile (vendored): 1 deletion

@@ -1 +0,0 @@
-FROM localhost:5000/distribution-kvant:dev
.github/workflows/ci-playground.yaml (vendored): 73 deletions

@@ -1,73 +0,0 @@
-name: Build and Push playground container
-run-name: Build and Push playground container
-on:
-  workflow_dispatch:
-  #schedule:
-  #  - cron: "0 10 * * *"
-  push:
-    branches:
-      - main
-      - kvant
-    tags:
-      - 'v*'
-  pull_request:
-    branches:
-      - main
-      - kvant
-env:
-  IMAGE: git.kvant.cloud/${{github.repository}}-playground
-jobs:
-  build-playground:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-
-      - name: Set current time
-        uses: https://github.com/gerred/actions/current-time@master
-        id: current_time
-
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-
-      - name: Login to git.kvant.cloud registry
-        uses: docker/login-action@v3
-        with:
-          registry: git.kvant.cloud
-          username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
-          password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}
-
-      - name: Docker meta
-        id: meta
-        uses: docker/metadata-action@v5
-        with:
-          # list of Docker images to use as base name for tags
-          images: |
-            ${{env.IMAGE}}
-          # generate Docker tags based on the following events/attributes
-          tags: |
-            type=schedule
-            type=ref,event=branch
-            type=ref,event=pr
-            type=ref,event=tag
-            type=semver,pattern={{version}}
-
-      - name: Build and push to gitea registry
-        uses: docker/build-push-action@v6
-        with:
-          push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
-          context: .
-          file: llama_stack/distribution/ui/Containerfile
-          provenance: mode=max
-          sbom: true
-          build-args: |
-            BUILD_DATE=${{ steps.current_time.outputs.time }}
-          cache-from: |
-            type=registry,ref=${{ env.IMAGE }}:buildcache
-            type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
-            type=registry,ref=${{ env.IMAGE }}:main
-          cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true
.github/workflows/ci.yaml (vendored): 98 deletions

@@ -1,98 +0,0 @@
-name: Build and Push container
-run-name: Build and Push container
-on:
-  workflow_dispatch:
-  #schedule:
-  #  - cron: "0 10 * * *"
-  push:
-    branches:
-      - main
-      - kvant
-    tags:
-      - 'v*'
-  pull_request:
-    branches:
-      - main
-      - kvant
-env:
-  IMAGE: git.kvant.cloud/${{github.repository}}
-jobs:
-  build:
-    runs-on: ubuntu-latest
-    services:
-      registry:
-        image: registry:2
-        ports:
-          - 5000:5000
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-
-      - name: Set current time
-        uses: https://github.com/gerred/actions/current-time@master
-        id: current_time
-
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-        with:
-          driver-opts: network=host
-
-      - name: Login to git.kvant.cloud registry
-        uses: docker/login-action@v3
-        with:
-          registry: git.kvant.cloud
-          username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
-          password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}
-
-      - name: Docker meta
-        id: meta
-        uses: docker/metadata-action@v5
-        with:
-          # list of Docker images to use as base name for tags
-          images: |
-            ${{env.IMAGE}}
-          # generate Docker tags based on the following events/attributes
-          tags: |
-            type=schedule
-            type=ref,event=branch
-            type=ref,event=pr
-            type=ref,event=tag
-            type=semver,pattern={{version}}
-
-      - name: Install uv
-        uses: https://github.com/astral-sh/setup-uv@v5
-        with:
-          # Install a specific version of uv.
-          version: "0.7.8"
-
-      - name: Build
-        env:
-          USE_COPY_NOT_MOUNT: true
-          LLAMA_STACK_DIR: .
-        run: |
-          uvx --from . llama stack build --template kvant --image-type container
-
-          # docker tag distribution-kvant:dev ${{env.IMAGE}}:kvant
-          # docker push ${{env.IMAGE}}:kvant
-
-          docker tag distribution-kvant:dev localhost:5000/distribution-kvant:dev
-          docker push localhost:5000/distribution-kvant:dev
-
-      - name: Build and push to gitea registry
-        uses: docker/build-push-action@v6
-        with:
-          push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
-          context: .github/workflows
-          provenance: mode=max
-          sbom: true
-          build-args: |
-            BUILD_DATE=${{ steps.current_time.outputs.time }}
-          cache-from: |
-            type=registry,ref=${{ env.IMAGE }}:buildcache
-            type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
-            type=registry,ref=${{ env.IMAGE }}:main
-          cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true
@@ -23,18 +23,23 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        auth-provider: [oauth2_token]
+        auth-provider: [kubernetes]
       fail-fast: false # we want to run all tests regardless of failure

     steps:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+          activate-environment: true

-      - name: Build Llama Stack
+      - name: Set Up Environment and Install Dependencies
         run: |
+          uv sync --extra dev --extra test
+          uv pip install -e .
           llama stack build --template ollama --image-type venv

       - name: Install minikube

@@ -42,53 +47,29 @@
         uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19

       - name: Start minikube
-        if: ${{ matrix.auth-provider == 'oauth2_token' }}
+        if: ${{ matrix.auth-provider == 'kubernetes' }}
         run: |
           minikube start
           kubectl get pods -A

       - name: Configure Kube Auth
-        if: ${{ matrix.auth-provider == 'oauth2_token' }}
+        if: ${{ matrix.auth-provider == 'kubernetes' }}
         run: |
           kubectl create namespace llama-stack
           kubectl create serviceaccount llama-stack-auth -n llama-stack
           kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
           kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
-          cat <<EOF | kubectl apply -f -
-          apiVersion: rbac.authorization.k8s.io/v1
-          kind: ClusterRole
-          metadata:
-            name: allow-anonymous-openid
-          rules:
-          - nonResourceURLs: ["/openid/v1/jwks"]
-            verbs: ["get"]
-          ---
-          apiVersion: rbac.authorization.k8s.io/v1
-          kind: ClusterRoleBinding
-          metadata:
-            name: allow-anonymous-openid
-          roleRef:
-            apiGroup: rbac.authorization.k8s.io
-            kind: ClusterRole
-            name: allow-anonymous-openid
-          subjects:
-          - kind: User
-            name: system:anonymous
-            apiGroup: rbac.authorization.k8s.io
-          EOF

       - name: Set Kubernetes Config
-        if: ${{ matrix.auth-provider == 'oauth2_token' }}
+        if: ${{ matrix.auth-provider == 'kubernetes' }}
         run: |
-          echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV
+          echo "KUBERNETES_API_SERVER_URL=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.server}')" >> $GITHUB_ENV
           echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
-          echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
-          echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV

       - name: Set Kube Auth Config and run server
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
-        if: ${{ matrix.auth-provider == 'oauth2_token' }}
+        if: ${{ matrix.auth-provider == 'kubernetes' }}
         run: |
           run_dir=$(mktemp -d)
           cat <<'EOF' > $run_dir/run.yaml

@@ -100,10 +81,10 @@
             port: 8321
           EOF
           yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
-          yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
-          yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
+          yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml
           cat $run_dir/run.yaml

+          source .venv/bin/activate
           nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &

       - name: Wait for Llama Stack server to be ready
@@ -24,7 +24,7 @@
       matrix:
         # Listing tests manually since some of them currently fail
         # TODO: generate matrix list from tests/integration when fixed
-        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
+        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
         client-type: [library, http]
       fail-fast: false # we want to run all tests regardless of failure

@@ -32,14 +32,24 @@
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+          activate-environment: true

       - name: Setup ollama
         uses: ./.github/actions/setup-ollama

-      - name: Build Llama Stack
+      - name: Set Up Environment and Install Dependencies
         run: |
+          uv sync --extra dev --extra test
+          uv pip install ollama faiss-cpu
+          # always test against the latest version of the client
+          # TODO: this is not necessarily a good idea. we need to test against both published and latest
+          # to find out backwards compatibility issues.
+          uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
+          uv pip install -e .
           llama stack build --template ollama --image-type venv

       - name: Start Llama Stack server in background

@@ -47,7 +57,8 @@
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
-          LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &
+          source .venv/bin/activate
+          nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &

       - name: Wait for Llama Stack server to be ready
         if: matrix.client-type == 'http'

@@ -75,12 +86,6 @@
             exit 1
           fi

-      - name: Check Storage and Memory Available Before Tests
-        if: ${{ always() }}
-        run: |
-          free -h
-          df -h
-
       - name: Run Integration Tests
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"

@@ -90,24 +95,17 @@
           else
             stack_config="http://localhost:8321"
           fi
-          uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
+          uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
            -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
            --text-model="meta-llama/Llama-3.2-3B-Instruct" \
            --embedding-model=all-MiniLM-L6-v2

-      - name: Check Storage and Memory Available After Tests
-        if: ${{ always() }}
-        run: |
-          free -h
-          df -h
-
       - name: Write ollama logs to file
-        if: ${{ always() }}
         run: |
           sudo journalctl -u ollama.service > ollama.log

       - name: Upload all logs to artifacts
-        if: ${{ always() }}
+        if: always()
         uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
         with:
           name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}
@@ -29,7 +29,6 @@ jobs:
       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         env:
           SKIP: no-commit-to-branch
-          RUFF_OUTPUT_FORMAT: github

       - name: Verify if there are any diff files after pre-commit
         run: |
@@ -50,8 +50,21 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .

       - name: Print build dependencies
         run: |

@@ -66,6 +79,7 @@ jobs:
       - name: Print dependencies in the image
         if: matrix.image-type == 'venv'
         run: |
+          source test/bin/activate
           uv pip list

   build-single-provider:

@@ -74,8 +88,21 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .

       - name: Build a single provider
         run: |

@@ -87,8 +114,21 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .

       - name: Build a single provider
         run: |

@@ -112,8 +152,21 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .

       - name: Pin template to UBI9 base
         run: |
@@ -25,8 +25,15 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Set Up Environment and Install Dependencies
+        run: |
+          uv sync --extra dev --extra test
+          uv pip install -e .

       - name: Apply image type to config file
         run: |

@@ -52,6 +59,7 @@ jobs:
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
+          source ci-test/bin/activate
           uv run pip list
           nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
@@ -30,11 +30,17 @@ jobs:
         - "3.12"
         - "3.13"
     steps:
-      - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python ${{ matrix.python }}
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: ${{ matrix.python }}
+
+      - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: ${{ matrix.python }}
+          enable-cache: false

       - name: Run unit tests
         run: |
@@ -37,8 +37,16 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install dependencies
-        uses: ./.github/actions/setup-runner
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.11'
+
+      - name: Install the latest version of uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+
+      - name: Sync with uv
+        run: uv sync --extra docs

       - name: Build HTML
         run: |
.gitignore (vendored): 2 changes

@@ -6,7 +6,6 @@ dev_requirements.txt
 build
 .DS_Store
 llama_stack/configs/*
-.cursor/
 xcuserdata/
 *.hmap
 .DS_Store

@@ -24,4 +23,3 @@ venv/
 pytest-report.xml
 .coverage
 .python-version
-data
@@ -53,7 +53,7 @@ repos:
           - black==24.3.0

   - repo: https://github.com/astral-sh/uv-pre-commit
-    rev: 0.7.8
+    rev: 0.6.3
     hooks:
       - id: uv-lock
       - id: uv-export

@@ -61,7 +61,6 @@ repos:
           "--frozen",
           "--no-hashes",
           "--no-emit-project",
-          "--no-default-groups",
           "--output-file=requirements.txt"
         ]

@@ -89,17 +88,20 @@ repos:
       - id: distro-codegen
        name: Distribution Template Codegen
        additional_dependencies:
-          - uv==0.7.8
-        entry: uv run --group codegen ./scripts/distro_codegen.py
+          - uv==0.6.0
+        entry: uv run --extra codegen ./scripts/distro_codegen.py
        language: python
        pass_filenames: false
        require_serial: true
        files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$

+  - repo: local
+    hooks:
       - id: openapi-codegen
        name: API Spec Codegen
        additional_dependencies:
-          - uv==0.7.8
-        entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
+          - uv==0.6.2
+        entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
        language: python
        pass_filenames: false
        require_serial: true
@@ -5,21 +5,28 @@
 # Required
 version: 2

-# Build documentation in the "docs/" directory with Sphinx
-sphinx:
-  configuration: docs/source/conf.py
-
 # Set the OS, Python version and other tools you might need
 build:
   os: ubuntu-22.04
   tools:
     python: "3.12"
-  jobs:
-    pre_create_environment:
-      - asdf plugin add uv
-      - asdf install uv latest
-      - asdf global uv latest
-    create_environment:
-      - uv venv "${READTHEDOCS_VIRTUALENV_PATH}"
-    install:
-      - UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --group docs
+    # You can also specify other tool versions:
+    # nodejs: "19"
+    # rust: "1.64"
+    # golang: "1.19"
+
+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+  configuration: docs/source/conf.py
+
+# Optionally build your docs in additional formats such as PDF and ePub
+# formats:
+#    - pdf
+#    - epub
+
+# Optional but recommended, declare the Python requirements required
+# to build your documentation
+# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
+python:
+  install:
+    - requirements: docs/requirements.txt
@@ -480,3 +480,4 @@ Published on: 2024-11-20T22:18:00Z

+
 ---
@@ -167,11 +167,14 @@ If you have made changes to a provider's configuration in any form (introducing
 If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.

 ```bash
+cd docs
+uv sync --extra docs
+
 # This rebuilds the documentation pages.
-uv run --group docs make -C docs/ html
+uv run make html

 # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
-uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
+uv run sphinx-autobuild source build/html --write-all
 ```

 ### Update API Documentation

@@ -179,7 +182,7 @@ uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
 If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:

 ```bash
-uv run ./docs/openapi_generator/run_openapi_generator.sh
+uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
 ```

 The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
@@ -1,4 +1,5 @@
 include pyproject.toml
+include llama_stack/templates/dependencies.json
 include llama_stack/models/llama/llama3/tokenizer.model
 include llama_stack/models/llama/llama4/tokenizer.model
 include llama_stack/distribution/*.sh
README.md: 43 changes

@@ -107,29 +107,26 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on
 ### API Providers
 Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.

-| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** |
-|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:|
-| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | |
-| SambaNova | Hosted | | ✅ | | ✅ | | |
-| Cerebras | Hosted | | ✅ | | | | |
-| Fireworks | Hosted | ✅ | ✅ | ✅ | | | |
-| AWS Bedrock | Hosted | | ✅ | | ✅ | | |
-| Together | Hosted | ✅ | ✅ | | ✅ | | |
-| Groq | Hosted | | ✅ | | | | |
-| Ollama | Single Node | | ✅ | | | | |
-| TGI | Hosted and Single Node | | ✅ | | | | |
-| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | |
-| Chroma | Single Node | | | ✅ | | | |
-| PG Vector | Single Node | | | ✅ | | | |
-| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | |
-| vLLM | Hosted and Single Node | | ✅ | | | | |
-| OpenAI | Hosted | | ✅ | | | | |
-| Anthropic | Hosted | | ✅ | | | | |
-| Gemini | Hosted | | ✅ | | | | |
-| watsonx | Hosted | | ✅ | | | | |
-| HuggingFace | Single Node | | | | | | ✅ |
-| TorchTune | Single Node | | | | | | ✅ |
-| NVIDIA NEMO | Hosted | | | | | | ✅ |
+| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
+|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
+| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
+| SambaNova | Hosted | | ✅ | | | |
+| Cerebras | Hosted | | ✅ | | | |
+| Fireworks | Hosted | ✅ | ✅ | ✅ | | |
+| AWS Bedrock | Hosted | | ✅ | | ✅ | |
+| Together | Hosted | ✅ | ✅ | | ✅ | |
+| Groq | Hosted | | ✅ | | | |
+| Ollama | Single Node | | ✅ | | | |
+| TGI | Hosted and Single Node | | ✅ | | | |
+| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
+| Chroma | Single Node | | | ✅ | | |
+| PG Vector | Single Node | | | ✅ | | |
+| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
+| vLLM | Hosted and Single Node | | ✅ | | | |
+| OpenAI | Hosted | | ✅ | | | |
+| Anthropic | Hosted | | ✅ | | | |
+| Gemini | Hosted | | ✅ | | | |
+| watsonx | Hosted | | ✅ | | | |


 ### Distributions
docs/_static/llama-stack-spec.html (vendored): 1996 changes. File diff suppressed because it is too large.

docs/_static/llama-stack-spec.yaml (vendored): 1479 changes. File diff suppressed because it is too large.

File diff suppressed because it is too large.
@@ -759,7 +759,7 @@ class Generator:
         )

         return Operation(
-            tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)],
+            tags=[op.defining_class.__name__],
             summary=None,
             # summary=doc_string.short_description,
             description=description,

@@ -805,8 +805,6 @@ class Generator:
         operation_tags: List[Tag] = []
         for cls in endpoint_classes:
             doc_string = parse_type(cls)
-            if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__:
-                continue
             operation_tags.append(
                 Tag(
                     name=cls.__name__,
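The `API_NAMESPACE` attribute removed in the hunks above is read with a `getattr` fallback. A minimal, self-contained sketch (hypothetical class names, not taken from the repository) of how that fallback groups endpoint classes under one shared tag:

```python
# Hypothetical endpoint classes, for illustration only.
class Models:
    """Primary API class; its own name is used as the tag."""


class OpenAIModels:
    """Grouped under the same tag as Models via API_NAMESPACE."""
    API_NAMESPACE = "Models"


def tag_for(cls) -> str:
    # Fall back to the class name when API_NAMESPACE is absent,
    # mirroring getattr(op.defining_class, "API_NAMESPACE", ...).
    return getattr(cls, "API_NAMESPACE", cls.__name__)


print(tag_for(Models))        # "Models"
print(tag_for(OpenAIModels))  # "Models"
```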
@@ -3,10 +3,10 @@
 Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html).

 ## Render locally

-From the llama-stack root directory, run the following command to render the docs locally:
 ```bash
-uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
+pip install -r requirements.txt
+cd docs
+python -m sphinx_autobuild source _build
 ```
 You can open up the docs in your browser at http://localhost:8000
docs/requirements.txt (new file): 16 additions

@@ -0,0 +1,16 @@
+linkify
+myst-parser
+-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+sphinx==8.1.3
+sphinx-copybutton
+sphinx-design
+sphinx-pdj-theme
+sphinx-rtd-theme>=1.0.0
+sphinx-tabs
+sphinx_autobuild
+sphinx_rtd_dark_mode
+sphinxcontrib-mermaid
+sphinxcontrib-openapi
+sphinxcontrib-redoc
+sphinxcontrib-video
+tomli
@@ -57,31 +57,6 @@ chunks = [
 ]
 client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
 ```

-#### Using Precomputed Embeddings
-If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by
-including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you
-want to customize the ingestion process.
-```python
-chunks_with_embeddings = [
-    {
-        "content": "First chunk of text",
-        "mime_type": "text/plain",
-        "embedding": [0.1, 0.2, 0.3, ...],  # Your precomputed embedding vector
-        "metadata": {"document_id": "doc1", "section": "introduction"},
-    },
-    {
-        "content": "Second chunk of text",
-        "mime_type": "text/plain",
-        "embedding": [0.2, 0.3, 0.4, ...],  # Your precomputed embedding vector
-        "metadata": {"document_id": "doc1", "section": "methodology"},
-    },
-]
-client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
-```
-When providing precomputed embeddings, ensure the embedding dimension matches the embedding_dimension specified when
-registering the vector database.

 ### Retrieval
 You can query the vector database to retrieve documents based on their embeddings.
 ```python
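The removed note about matching `embedding_dimension` refers to the dimension declared when the vector database is registered. A minimal sketch, assuming the same `client` object and `vector_db_id` used earlier in that document (parameter names may differ between client library versions):

```python
# Register the vector database with an embedding dimension that matches
# the length of any precomputed embedding vectors you plan to insert.
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model="all-MiniLM-L6-v2",  # assumed model, 384-dimensional
    embedding_dimension=384,
)

# Every precomputed vector must then have exactly 384 elements.
assert all(len(chunk["embedding"]) == 384 for chunk in chunks_with_embeddings)
```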
@@ -22,11 +22,7 @@ from docutils import nodes
 # Read version from pyproject.toml
 with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
     pypi_url = "https://pypi.org/pypi/llama-stack/json"
-    headers = {
-        'User-Agent': 'pip/23.0.1 (python 3.11)',  # Mimic pip's user agent
-        'Accept': 'application/json'
-    }
-    version_tag = json.loads(requests.get(pypi_url, headers=headers).text)["info"]["version"]
+    version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
     print(f"{version_tag=}")

 # generate the full link including text and url here

@@ -57,6 +53,14 @@ myst_enable_extensions = ["colon_fence"]

 html_theme = "sphinx_rtd_theme"
 html_use_relative_paths = True

+# html_theme = "sphinx_pdj_theme"
+# html_theme_path = [sphinx_pdj_theme.get_html_theme_path()]
+
+# html_theme = "pytorch_sphinx_theme"
+# html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+

 templates_path = ["_templates"]
 exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
@@ -338,48 +338,6 @@ INFO: Application startup complete.
 INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
 INFO: 2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK
 ```
-### Listing Distributions
-Using the list command, you can view all existing Llama Stack distributions, including stacks built from templates, from scratch, or using custom configuration files.
-
-```
-llama stack list -h
-usage: llama stack list [-h]
-
-list the build stacks
-
-options:
-  -h, --help  show this help message and exit
-```
-
-Example Usage
-
-```
-llama stack list
-```
-
-### Removing a Distribution
-Use the remove command to delete a distribution you've previously built.
-
-```
-llama stack rm -h
-usage: llama stack rm [-h] [--all] [name]
-
-Remove the build stack
-
-positional arguments:
-  name        Name of the stack to delete (default: None)
-
-options:
-  -h, --help  show this help message and exit
-  --all, -a   Delete all stacks (use with caution) (default: False)
-```
-
-Example
-```
-llama stack rm llamastack-test
-```
-
-To keep your environment organized and avoid clutter, consider using `llama stack list` to review old or unused distributions and `llama stack rm <name>` to delete them when they're no longer needed.
-
 ### Troubleshooting
@@ -118,6 +118,11 @@ server:
   port: 8321  # Port to listen on (default: 8321)
   tls_certfile: "/path/to/cert.pem"  # Optional: Path to TLS certificate for HTTPS
   tls_keyfile: "/path/to/key.pem"  # Optional: Path to TLS key for HTTPS
+  auth:  # Optional: Authentication configuration
+    provider_type: "kubernetes"  # Type of auth provider
+    config:  # Provider-specific configuration
+      api_server_url: "https://kubernetes.default.svc"
+      ca_cert_path: "/path/to/ca.crt"  # Optional: Path to CA certificate
 ```

 ### Authentication Configuration
@@ -130,7 +135,7 @@ Authorization: Bearer <token>

 The server supports multiple authentication providers:

-#### OAuth 2.0/OpenID Connect Provider with Kubernetes
+#### Kubernetes Provider

 The Kubernetes cluster must be configured to use a service account for authentication.

@@ -141,67 +146,14 @@ kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --se
 kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
 ```

-Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests
-and that the correct RoleBinding is created to allow the service account to access the necessary
-resources. If that is not the case, you can create a RoleBinding for the service account to access
-the necessary resources:
-
-```yaml
-# allow-anonymous-openid.yaml
-apiVersion: rbac.authorization.k8s.io/v1
-kind: ClusterRole
-metadata:
-  name: allow-anonymous-openid
-rules:
-- nonResourceURLs: ["/openid/v1/jwks"]
-  verbs: ["get"]
----
-apiVersion: rbac.authorization.k8s.io/v1
-kind: ClusterRoleBinding
-metadata:
-  name: allow-anonymous-openid
-roleRef:
-  apiGroup: rbac.authorization.k8s.io
-  kind: ClusterRole
-  name: allow-anonymous-openid
-subjects:
-- kind: User
-  name: system:anonymous
-  apiGroup: rbac.authorization.k8s.io
-```
-
-And then apply the configuration:
-```bash
-kubectl apply -f allow-anonymous-openid.yaml
-```
-
-Validates tokens against the Kubernetes API server through the OIDC provider:
+Validates tokens against the Kubernetes API server:
 ```yaml
 server:
   auth:
-    provider_type: "oauth2_token"
+    provider_type: "kubernetes"
     config:
-      jwks:
-        uri: "https://kubernetes.default.svc"
-        key_recheck_period: 3600
-      tls_cafile: "/path/to/ca.crt"
-      issuer: "https://kubernetes.default.svc"
-      audience: "https://kubernetes.default.svc"
-```
-
-To find your cluster's audience, run:
-```bash
-kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
-```
-
-For the issuer, you can use the OIDC provider's URL:
-```bash
-kubectl get --raw /.well-known/openid-configuration| jq .issuer
-```
-
-For the tls_cafile, you can use the CA certificate of the OIDC provider:
-```bash
-kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
+      api_server_url: "https://kubernetes.default.svc"  # URL of the Kubernetes API server
+      ca_cert_path: "/path/to/ca.crt"  # Optional: Path to CA certificate
 ```

 The provider extracts user information from the JWT token:
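With either auth configuration above, every request to the server must carry the service-account token in the `Authorization` header. A minimal client sketch, assuming the token file created by the `kubectl create token` command in that document and a server on localhost port 8321; the `/v1/models` path is used only as an illustrative read-only endpoint:

```python
import requests

# Token written by: kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
with open("llama-stack-auth-token") as f:
    token = f.read().strip()

resp = requests.get(
    "http://localhost:8321/v1/models",
    headers={"Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
print(resp.json())
```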
@@ -256,80 +208,6 @@ And must respond with:

 If no access attributes are returned, the token is used as a namespace.

-### Quota Configuration
-
-The `quota` section allows you to enable server-side request throttling for both
-authenticated and anonymous clients. This is useful for preventing abuse, enforcing
-fairness across tenants, and controlling infrastructure costs without requiring
-client-side rate limiting or external proxies.
-
-Quotas are disabled by default. When enabled, each client is tracked using either:
-
-* Their authenticated `client_id` (derived from the Bearer token), or
-* Their IP address (fallback for anonymous requests)
-
-Quota state is stored in a SQLite-backed key-value store, and rate limits are applied
-within a configurable time window (currently only `day` is supported).
-
-#### Example
-
-```yaml
-server:
-  quota:
-    kvstore:
-      type: sqlite
-      db_path: ./quotas.db
-    anonymous_max_requests: 100
-    authenticated_max_requests: 1000
-    period: day
-```
-
-#### Configuration Options
-
-| Field | Description |
-| ---------------------------- | -------------------------------------------------------------------------- |
-| `kvstore` | Required. Backend storage config for tracking request counts. |
-| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. |
-| `kvstore.db_path` | File path to the SQLite database. |
-| `anonymous_max_requests` | Max requests per period for unauthenticated clients. |
-| `authenticated_max_requests` | Max requests per period for authenticated clients. |
-| `period` | Time window for quota enforcement. Only `"day"` is supported. |
-
-> Note: if `authenticated_max_requests` is set but no authentication provider is
-configured, the server will fall back to applying `anonymous_max_requests` to all
-clients.
-
-#### Example with Authentication Enabled
-
-```yaml
-server:
-  port: 8321
-  auth:
-    provider_type: custom
-    config:
-      endpoint: https://auth.example.com/validate
-  quota:
-    kvstore:
-      type: sqlite
-      db_path: ./quotas.db
-    anonymous_max_requests: 100
-    authenticated_max_requests: 1000
-    period: day
-```
-
-If a client exceeds their limit, the server responds with:
-
-```http
-HTTP/1.1 429 Too Many Requests
-Content-Type: application/json
-
-{
-  "error": {
-    "message": "Quota exceeded"
-  }
-}
-```
-
 ## Extending to handle Safety

 Configuring Safety can be a little involved so it is instructive to go through an example.
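For the quota behaviour described in the removed section, a client that receives `429 Too Many Requests` has to back off until its quota window resets. A minimal, illustrative retry sketch (not part of any Llama Stack client library; the endpoint path and timings are assumptions):

```python
import time

import requests


def get_with_backoff(url, token=None, retries=3):
    """Retry a GET a few times when the server answers 429 (quota exceeded)."""
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    delay = 1.0
    for attempt in range(retries + 1):
        resp = requests.get(url, headers=headers)
        if resp.status_code != 429 or attempt == retries:
            return resp
        # Quota exceeded: wait and try again with a growing delay.
        time.sleep(delay)
        delay *= 2


response = get_with_backoff("http://localhost:8321/v1/models")
print(response.status_code)
```

Because the documented quota window is a full day, a real client would usually surface the error to the caller instead of retrying immediately; the loop above only illustrates the status-code handling.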
@@ -172,7 +172,7 @@ spec:
       - name: llama-stack
         image: localhost/llama-stack-run-k8s:latest
         imagePullPolicy: IfNotPresent
-        command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/app/config.yaml"]
+        command: ["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]
         ports:
           - containerPort: 5000
         volumeMounts:
@@ -70,7 +70,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-watsonx \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env WATSONX_API_KEY=$WATSONX_API_KEY \
   --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
@@ -52,7 +52,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-cerebras \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
 ```
@@ -155,7 +155,7 @@ docker run \
   -v $HOME/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-dell \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env DEH_URL=$DEH_URL \
@@ -143,7 +143,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-nvidia \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env NVIDIA_API_KEY=$NVIDIA_API_KEY
 ```
@@ -19,7 +19,6 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::ollama` |
-| post_training | `inline::huggingface` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |

@@ -98,7 +97,7 @@ docker run \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-ollama \
-  --config /root/my-run.yaml \
+  --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env SAFETY_MODEL=$SAFETY_MODEL \
@@ -233,7 +233,7 @@ docker run \
 -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
 -v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \
 llamastack/distribution-remote-vllm \
---config /root/my-run.yaml \
+--yaml-config /root/my-run.yaml \
 --port $LLAMA_STACK_PORT \
 --env INFERENCE_MODEL=$INFERENCE_MODEL \
 --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1

@@ -255,7 +255,7 @@ docker run \
 -v ~/.llama:/root/.llama \
 -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
 llamastack/distribution-remote-vllm \
---config /root/my-run.yaml \
+--yaml-config /root/my-run.yaml \
 --port $LLAMA_STACK_PORT \
 --env INFERENCE_MODEL=$INFERENCE_MODEL \
 --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \
@@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 |-----|-------------|
 | agents | `inline::meta-reference` |
 | inference | `remote::sambanova`, `inline::sentence-transformers` |
-| safety | `remote::sambanova` |
+| safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
 | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@@ -48,44 +48,33 @@ The following models are available by default:

 ### Prerequisite: API Keys

-Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup).
+Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/).


 ## Running Llama Stack with SambaNova

 You can do this via Conda (build code) or Docker which has a pre-built image.


 ### Via Docker

+This method allows you to get started quickly without having to build the distribution code.

 ```bash
 LLAMA_STACK_PORT=8321
-llama stack build --template sambanova --image-type container
 docker run \
 -it \
+--pull always \
 -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
--v ~/.llama:/root/.llama \
-distribution-sambanova \
+llamastack/distribution-sambanova \
 --port $LLAMA_STACK_PORT \
 --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
 ```


-### Via Venv
-
-```bash
-llama stack build --template sambanova --image-type venv
-llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \
---port $LLAMA_STACK_PORT \
---env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
-```
-

 ### Via Conda

 ```bash
 llama stack build --template sambanova --image-type conda
-llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \
+llama stack run ./run.yaml \
 --port $LLAMA_STACK_PORT \
 --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
 ```
@@ -117,7 +117,7 @@ docker run \
 -v ~/.llama:/root/.llama \
 -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
 llamastack/distribution-tgi \
---config /root/my-run.yaml \
+--yaml-config /root/my-run.yaml \
 --port $LLAMA_STACK_PORT \
 --env INFERENCE_MODEL=$INFERENCE_MODEL \
 --env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \
@@ -30,18 +30,6 @@ Runs inference with an LLM.
 ## Post Training
 Fine-tunes a model.

-#### Post Training Providers
-The following providers are available for Post Training:
-
-```{toctree}
-:maxdepth: 1
-
-external
-post_training/huggingface
-post_training/torchtune
-post_training/nvidia_nemo
-```
-
 ## Safety
 Applies safety policies to the output at a Systems (not only model) level.
@ -1,122 +0,0 @@
|
||||||
---
|
|
||||||
orphan: true
|
|
||||||
---
|
|
||||||
# HuggingFace SFTTrainer
|
|
||||||
|
|
||||||
[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine tuning on a variety of models using many datasets
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
- Simple access through the post_training API
|
|
||||||
- Fully integrated with Llama Stack
|
|
||||||
- GPU support, CPU support, and MPS support (MacOS Metal Performance Shaders)
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
To use the HF SFTTrainer in your Llama Stack project, follow these steps:
|
|
||||||
|
|
||||||
1. Configure your Llama Stack project to use this provider.
|
|
||||||
2. Kick off a SFT job using the Llama Stack post_training API.
|
|
||||||
|
|
||||||
## Setup
|
|
||||||
|
|
||||||
You can access the HuggingFace trainer via the `ollama` distribution:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llama stack build --template ollama --image-type venv
|
|
||||||
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
## Run Training
|
|
||||||
|
|
||||||
You can access the provider and the `supervised_fine_tune` method via the post_training API:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
|
|
||||||
from llama_stack_client.types import (
|
|
||||||
post_training_supervised_fine_tune_params,
|
|
||||||
algorithm_config_param,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_http_client():
|
|
||||||
from llama_stack_client import LlamaStackClient
|
|
||||||
|
|
||||||
return LlamaStackClient(base_url="http://localhost:8321")
|
|
||||||
|
|
||||||
|
|
||||||
client = create_http_client()
|
|
||||||
|
|
||||||
# Example Dataset
|
|
||||||
client.datasets.register(
|
|
||||||
purpose="post-training/messages",
|
|
||||||
source={
|
|
||||||
"type": "uri",
|
|
||||||
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
|
||||||
},
|
|
||||||
dataset_id="simpleqa",
|
|
||||||
)
|
|
||||||
|
|
||||||
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
|
|
||||||
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
|
|
||||||
batch_size=32,
|
|
||||||
data_format="instruct",
|
|
||||||
dataset_id="simpleqa",
|
|
||||||
shuffle=True,
|
|
||||||
),
|
|
||||||
gradient_accumulation_steps=1,
|
|
||||||
max_steps_per_epoch=0,
|
|
||||||
max_validation_steps=1,
|
|
||||||
n_epochs=4,
|
|
||||||
)
|
|
||||||
|
|
||||||
algorithm_config = algorithm_config_param.LoraFinetuningConfig( # this config is also currently mandatory but should not be
|
|
||||||
alpha=1,
|
|
||||||
apply_lora_to_mlp=True,
|
|
||||||
apply_lora_to_output=False,
|
|
||||||
lora_attn_modules=["q_proj"],
|
|
||||||
rank=1,
|
|
||||||
type="LoRA",
|
|
||||||
)
|
|
||||||
|
|
||||||
job_uuid = f"test-job{uuid.uuid4()}"
|
|
||||||
|
|
||||||
# Example Model
|
|
||||||
training_model = "ibm-granite/granite-3.3-8b-instruct"
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
response = client.post_training.supervised_fine_tune(
|
|
||||||
job_uuid=job_uuid,
|
|
||||||
logger_config={},
|
|
||||||
model=training_model,
|
|
||||||
hyperparam_search_config={},
|
|
||||||
training_config=training_config,
|
|
||||||
algorithm_config=algorithm_config,
|
|
||||||
checkpoint_dir="output",
|
|
||||||
)
|
|
||||||
print("Job: ", job_uuid)
|
|
||||||
|
|
||||||
|
|
||||||
# Wait for the job to complete!
|
|
||||||
while True:
|
|
||||||
status = client.post_training.job.status(job_uuid=job_uuid)
|
|
||||||
if not status:
|
|
||||||
print("Job not found")
|
|
||||||
break
|
|
||||||
|
|
||||||
print(status)
|
|
||||||
if status.status == "completed":
|
|
||||||
break
|
|
||||||
|
|
||||||
print("Waiting for job to complete...")
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
print("Job completed in", end_time - start_time, "seconds!")
|
|
||||||
|
|
||||||
print("Artifacts:")
|
|
||||||
print(client.post_training.job.artifacts(job_uuid=job_uuid))
|
|
||||||
```
|
|
|
@ -1,163 +0,0 @@
|
||||||
---
|
|
||||||
orphan: true
|
|
||||||
---
|
|
||||||
# NVIDIA NEMO
|
|
||||||
|
|
||||||
[NVIDIA NEMO](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service.
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
- Enterprise-grade fine-tuning capabilities
|
|
||||||
- Support for LoRA and SFT fine-tuning
|
|
||||||
- Integration with NVIDIA's NeMo Customizer service
|
|
||||||
- Support for various NVIDIA-optimized models
|
|
||||||
- Efficient training with NVIDIA hardware acceleration
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
To use NVIDIA NEMO in your Llama Stack project, follow these steps:
|
|
||||||
|
|
||||||
1. Configure your Llama Stack project to use this provider.
|
|
||||||
2. Set up your NVIDIA API credentials.
|
|
||||||
3. Kick off a fine-tuning job using the Llama Stack post_training API.
|
|
||||||
|
|
||||||
## Setup
|
|
||||||
|
|
||||||
You'll need to set the following environment variables:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export NVIDIA_API_KEY="your-api-key"
|
|
||||||
export NVIDIA_DATASET_NAMESPACE="default"
|
|
||||||
export NVIDIA_CUSTOMIZER_URL="your-customizer-url"
|
|
||||||
export NVIDIA_PROJECT_ID="your-project-id"
|
|
||||||
export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Run Training
|
|
||||||
|
|
||||||
You can access the provider and the `supervised_fine_tune` method via the post_training API:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from llama_stack_client.types import (
|
|
||||||
post_training_supervised_fine_tune_params,
|
|
||||||
algorithm_config_param,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_http_client():
|
|
||||||
from llama_stack_client import LlamaStackClient
|
|
||||||
|
|
||||||
return LlamaStackClient(base_url="http://localhost:8321")
|
|
||||||
|
|
||||||
|
|
||||||
client = create_http_client()
|
|
||||||
|
|
||||||
# Example Dataset
|
|
||||||
client.datasets.register(
|
|
||||||
purpose="post-training/messages",
|
|
||||||
source={
|
|
||||||
"type": "uri",
|
|
||||||
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
|
||||||
},
|
|
||||||
dataset_id="simpleqa",
|
|
||||||
)
|
|
||||||
|
|
||||||
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
|
|
||||||
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
|
|
||||||
batch_size=8, # Default batch size for NEMO
|
|
||||||
data_format="instruct",
|
|
||||||
dataset_id="simpleqa",
|
|
||||||
shuffle=True,
|
|
||||||
),
|
|
||||||
n_epochs=50, # Default epochs for NEMO
|
|
||||||
optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
|
|
||||||
lr=0.0001, # Default learning rate
|
|
||||||
weight_decay=0.01, # NEMO-specific parameter
|
|
||||||
),
|
|
||||||
# NEMO-specific parameters
|
|
||||||
log_every_n_steps=None,
|
|
||||||
val_check_interval=0.25,
|
|
||||||
sequence_packing_enabled=False,
|
|
||||||
hidden_dropout=None,
|
|
||||||
attention_dropout=None,
|
|
||||||
ffn_dropout=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
|
|
||||||
alpha=16, # Default alpha for NEMO
|
|
||||||
type="LoRA",
|
|
||||||
)
|
|
||||||
|
|
||||||
job_uuid = f"test-job{uuid.uuid4()}"
|
|
||||||
|
|
||||||
# Example Model - must be a supported NEMO model
|
|
||||||
training_model = "meta/llama-3.1-8b-instruct"
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
response = client.post_training.supervised_fine_tune(
|
|
||||||
job_uuid=job_uuid,
|
|
||||||
logger_config={},
|
|
||||||
model=training_model,
|
|
||||||
hyperparam_search_config={},
|
|
||||||
training_config=training_config,
|
|
||||||
algorithm_config=algorithm_config,
|
|
||||||
checkpoint_dir="output",
|
|
||||||
)
|
|
||||||
print("Job: ", job_uuid)
|
|
||||||
|
|
||||||
# Wait for the job to complete!
|
|
||||||
while True:
|
|
||||||
status = client.post_training.job.status(job_uuid=job_uuid)
|
|
||||||
if not status:
|
|
||||||
print("Job not found")
|
|
||||||
break
|
|
||||||
|
|
||||||
print(status)
|
|
||||||
if status.status == "completed":
|
|
||||||
break
|
|
||||||
|
|
||||||
print("Waiting for job to complete...")
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
print("Job completed in", end_time - start_time, "seconds!")
|
|
||||||
|
|
||||||
print("Artifacts:")
|
|
||||||
print(client.post_training.job.artifacts(job_uuid=job_uuid))
|
|
||||||
```
|
|
||||||
|
|
||||||
## Supported Models
|
|
||||||
|
|
||||||
Currently supports the following models:
|
|
||||||
- meta/llama-3.1-8b-instruct
|
|
||||||
- meta/llama-3.2-1b-instruct
|
|
||||||
|
|
||||||
## Supported Parameters
|
|
||||||
|
|
||||||
### TrainingConfig
|
|
||||||
- n_epochs (default: 50)
|
|
||||||
- data_config
|
|
||||||
- optimizer_config
|
|
||||||
- log_every_n_steps
|
|
||||||
- val_check_interval (default: 0.25)
|
|
||||||
- sequence_packing_enabled (default: False)
|
|
||||||
- hidden_dropout (0.0-1.0)
|
|
||||||
- attention_dropout (0.0-1.0)
|
|
||||||
- ffn_dropout (0.0-1.0)
|
|
||||||
|
|
||||||
### DataConfig
|
|
||||||
- dataset_id
|
|
||||||
- batch_size (default: 8)
|
|
||||||
|
|
||||||
### OptimizerConfig
|
|
||||||
- lr (default: 0.0001)
|
|
||||||
- weight_decay (default: 0.01)
|
|
||||||
|
|
||||||
### LoRA Config
|
|
||||||
- alpha (default: 16)
|
|
||||||
- type (must be "LoRA")
|
|
||||||
|
|
||||||
Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning.
|
|
|
@ -1,125 +0,0 @@
|
||||||
---
|
|
||||||
orphan: true
|
|
||||||
---
|
|
||||||
# TorchTune
|
|
||||||
|
|
||||||
[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
- Simple access through the post_training API
|
|
||||||
- Fully integrated with Llama Stack
|
|
||||||
- GPU support and single device capabilities.
|
|
||||||
- Support for LoRA
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
To use TorchTune in your Llama Stack project, follow these steps:
|
|
||||||
|
|
||||||
1. Configure your Llama Stack project to use this provider.
|
|
||||||
2. Kick off a fine-tuning job using the Llama Stack post_training API.
|
|
||||||
|
|
||||||
## Setup
|
|
||||||
|
|
||||||
You can access the TorchTune trainer by writing your own yaml pointing to the provider:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
post_training:
|
|
||||||
- provider_id: torchtune
|
|
||||||
provider_type: inline::torchtune
|
|
||||||
config: {}
|
|
||||||
```
|
|
||||||
|
|
||||||
you can then build and run your own stack with this provider.
|
|
||||||
|
|
||||||
## Run Training
|
|
||||||
|
|
||||||
You can access the provider and the `supervised_fine_tune` method via the post_training API:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from llama_stack_client.types import (
|
|
||||||
post_training_supervised_fine_tune_params,
|
|
||||||
algorithm_config_param,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_http_client():
|
|
||||||
from llama_stack_client import LlamaStackClient
|
|
||||||
|
|
||||||
return LlamaStackClient(base_url="http://localhost:8321")
|
|
||||||
|
|
||||||
|
|
||||||
client = create_http_client()
|
|
||||||
|
|
||||||
# Example Dataset
|
|
||||||
client.datasets.register(
|
|
||||||
purpose="post-training/messages",
|
|
||||||
source={
|
|
||||||
"type": "uri",
|
|
||||||
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
|
||||||
},
|
|
||||||
dataset_id="simpleqa",
|
|
||||||
)
|
|
||||||
|
|
||||||
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
|
|
||||||
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
|
|
||||||
batch_size=32,
|
|
||||||
data_format="instruct",
|
|
||||||
dataset_id="simpleqa",
|
|
||||||
shuffle=True,
|
|
||||||
),
|
|
||||||
gradient_accumulation_steps=1,
|
|
||||||
max_steps_per_epoch=0,
|
|
||||||
max_validation_steps=1,
|
|
||||||
n_epochs=4,
|
|
||||||
)
|
|
||||||
|
|
||||||
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
|
|
||||||
alpha=1,
|
|
||||||
apply_lora_to_mlp=True,
|
|
||||||
apply_lora_to_output=False,
|
|
||||||
lora_attn_modules=["q_proj"],
|
|
||||||
rank=1,
|
|
||||||
type="LoRA",
|
|
||||||
)
|
|
||||||
|
|
||||||
job_uuid = f"test-job{uuid.uuid4()}"
|
|
||||||
|
|
||||||
# Example Model
|
|
||||||
training_model = "meta-llama/Llama-2-7b-hf"
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
response = client.post_training.supervised_fine_tune(
|
|
||||||
job_uuid=job_uuid,
|
|
||||||
logger_config={},
|
|
||||||
model=training_model,
|
|
||||||
hyperparam_search_config={},
|
|
||||||
training_config=training_config,
|
|
||||||
algorithm_config=algorithm_config,
|
|
||||||
checkpoint_dir="output",
|
|
||||||
)
|
|
||||||
print("Job: ", job_uuid)
|
|
||||||
|
|
||||||
# Wait for the job to complete!
|
|
||||||
while True:
|
|
||||||
status = client.post_training.job.status(job_uuid=job_uuid)
|
|
||||||
if not status:
|
|
||||||
print("Job not found")
|
|
||||||
break
|
|
||||||
|
|
||||||
print(status)
|
|
||||||
if status.status == "completed":
|
|
||||||
break
|
|
||||||
|
|
||||||
print("Waiting for job to complete...")
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
print("Job completed in", end_time - start_time, "seconds!")
|
|
||||||
|
|
||||||
print("Artifacts:")
|
|
||||||
print(client.post_training.job.artifacts(job_uuid=job_uuid))
|
|
||||||
```
|
|
|
@@ -66,25 +66,6 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
 2. Configure your Llama Stack project to use SQLite-Vec.
 3. Start storing and querying vectors.

-## Supported Search Modes
-
-The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
-
-When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
-`RAGQueryConfig`. For example:
-
-```python
-from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
-
-query_config = RAGQueryConfig(max_chunks=6, mode="vector")
-
-results = client.tool_runtime.rag_tool.query(
-    vector_db_ids=[vector_db_id],
-    content="what is torchtune",
-    query_config=query_config,
-)
-```
-
 ## Installation

 You can install SQLite-Vec using pip:
@@ -1,6 +0,0 @@
-#!/usr/bin/env bash
-
-export USE_COPY_NOT_MOUNT=true
-export LLAMA_STACK_DIR=.
-
-uvx --from . llama stack build --template kvant --image-type container --image-name kvant
@@ -1,17 +0,0 @@
-#!/usr/bin/env bash
-
-export LLAMA_STACK_PORT=8321
-# VLLM_API_TOKEN= env file
-# KEYCLOAK_CLIENT_SECRET= env file
-
-
-docker run -it \
--p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
--v $(pwd)/data:/root/.llama \
---mount type=bind,source="$(pwd)"/llama_stack/templates/kvant/run.yaml,target=/root/.llama/config.yaml,readonly \
---entrypoint python \
---env-file ./.env \
-distribution-kvant:dev \
--m llama_stack.distribution.server.server --config /root/.llama/config.yaml \
---port $LLAMA_STACK_PORT \
-
@@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 from pydantic import BaseModel, ConfigDict, Field

 from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
-from llama_stack.apis.common.responses import Order, PaginatedResponse
+from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.inference import (
     CompletionMessage,
     ResponseFormat,

@@ -31,8 +31,6 @@ from llama_stack.apis.tools import ToolDef
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

 from .openai_responses import (
-    ListOpenAIResponseInputItem,
-    ListOpenAIResponseObject,
     OpenAIResponseInput,
     OpenAIResponseInputTool,
     OpenAIResponseObject,

@@ -581,14 +579,14 @@ class Agents(Protocol):
     #
     # Both of these APIs are inherently stateful.

-    @webmethod(route="/openai/v1/responses/{response_id}", method="GET")
+    @webmethod(route="/openai/v1/responses/{id}", method="GET")
     async def get_openai_response(
         self,
-        response_id: str,
+        id: str,
     ) -> OpenAIResponseObject:
         """Retrieve an OpenAI response by its ID.

-        :param response_id: The ID of the OpenAI response to retrieve.
+        :param id: The ID of the OpenAI response to retrieve.
         :returns: An OpenAIResponseObject.
         """
         ...
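For orientation, the renamed route above can be exercised directly over HTTP. A minimal sketch follows; the base URL and the response id are assumptions, not values taken from this diff.

```python
# Hedged sketch: GET /openai/v1/responses/{id}, the route declared above.
# Assumes a Llama Stack server listening on localhost:8321 and an existing response id.
import httpx

response_id = "resp_123"  # hypothetical id
r = httpx.get(f"http://localhost:8321/openai/v1/responses/{response_id}")
r.raise_for_status()
print(r.json())  # an OpenAIResponseObject-shaped payload
```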
@@ -598,7 +596,6 @@ class Agents(Protocol):
         self,
         input: str | list[OpenAIResponseInput],
         model: str,
-        instructions: str | None = None,
         previous_response_id: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,

@@ -613,43 +610,3 @@ class Agents(Protocol):
         :returns: An OpenAIResponseObject.
         """
         ...
-
-    @webmethod(route="/openai/v1/responses", method="GET")
-    async def list_openai_responses(
-        self,
-        after: str | None = None,
-        limit: int | None = 50,
-        model: str | None = None,
-        order: Order | None = Order.desc,
-    ) -> ListOpenAIResponseObject:
-        """List all OpenAI responses.
-
-        :param after: The ID of the last response to return.
-        :param limit: The number of responses to return.
-        :param model: The model to filter responses by.
-        :param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
-        :returns: A ListOpenAIResponseObject.
-        """
-        ...
-
-    @webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET")
-    async def list_openai_response_input_items(
-        self,
-        response_id: str,
-        after: str | None = None,
-        before: str | None = None,
-        include: list[str] | None = None,
-        limit: int | None = 20,
-        order: Order | None = Order.desc,
-    ) -> ListOpenAIResponseInputItem:
-        """List input items for a given OpenAI response.
-
-        :param response_id: The ID of the response to retrieve input items for.
-        :param after: An item ID to list items after, used for pagination.
-        :param before: An item ID to list items before, used for pagination.
-        :param include: Additional fields to include in the response.
-        :param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
-        :param order: The order to return the input items in. Default is desc.
-        :returns: An ListOpenAIResponseInputItem.
-        """
-        ...
@@ -10,9 +10,6 @@ from pydantic import BaseModel, Field

 from llama_stack.schema_utils import json_schema_type, register_schema

-# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
-# take their YAML and generate this file automatically. Their YAML is available.
-

 @json_schema_type
 class OpenAIResponseError(BaseModel):
@@ -82,45 +79,16 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):

 @json_schema_type
 class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
+    arguments: str
     call_id: str
     name: str
-    arguments: str
     type: Literal["function_call"] = "function_call"
-    id: str | None = None
-    status: str | None = None
-
-
-@json_schema_type
-class OpenAIResponseOutputMessageMCPCall(BaseModel):
-    id: str
-    type: Literal["mcp_call"] = "mcp_call"
-    arguments: str
-    name: str
-    server_label: str
-    error: str | None = None
-    output: str | None = None
-
-
-class MCPListToolsTool(BaseModel):
-    input_schema: dict[str, Any]
-    name: str
-    description: str | None = None
-
-
-@json_schema_type
-class OpenAIResponseOutputMessageMCPListTools(BaseModel):
-    id: str
-    type: Literal["mcp_list_tools"] = "mcp_list_tools"
-    server_label: str
-    tools: list[MCPListToolsTool]
+    id: str
+    status: str


 OpenAIResponseOutput = Annotated[
-    OpenAIResponseMessage
-    | OpenAIResponseOutputMessageWebSearchToolCall
-    | OpenAIResponseOutputMessageFunctionToolCall
-    | OpenAIResponseOutputMessageMCPCall
-    | OpenAIResponseOutputMessageMCPListTools,
+    OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
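To illustrate the narrowed output union above, here is a small sketch that validates a single output item by its `type` discriminator; the module path and the payload values are assumptions for illustration.

```python
# Hedged sketch: parse one output item against the OpenAIResponseOutput union above.
# Assumes pydantic v2; the import path and payload are illustrative.
from pydantic import TypeAdapter

from llama_stack.apis.agents.openai_responses import OpenAIResponseOutput

payload = {
    "type": "function_call",
    "arguments": '{"city": "Paris"}',
    "call_id": "call-1",
    "name": "get_weather",
    "id": "fc-1",
    "status": "completed",
}

item = TypeAdapter(OpenAIResponseOutput).validate_python(payload)
print(type(item).__name__)  # OpenAIResponseOutputMessageFunctionToolCall
```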
@@ -149,16 +117,6 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
     type: Literal["response.created"] = "response.created"


-@json_schema_type
-class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
-    content_index: int
-    delta: str
-    item_id: str
-    output_index: int
-    sequence_number: int
-    type: Literal["response.output_text.delta"] = "response.output_text.delta"
-
-
 @json_schema_type
 class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
     response: OpenAIResponseObject

@@ -166,9 +124,7 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):


 OpenAIResponseObjectStream = Annotated[
-    OpenAIResponseObjectStreamResponseCreated
-    | OpenAIResponseObjectStreamResponseOutputTextDelta
-    | OpenAIResponseObjectStreamResponseCompleted,
+    OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@@ -230,50 +186,13 @@ class OpenAIResponseInputToolFileSearch(BaseModel):
     # TODO: add filters

-
-class ApprovalFilter(BaseModel):
-    always: list[str] | None = None
-    never: list[str] | None = None
-
-
-class AllowedToolsFilter(BaseModel):
-    tool_names: list[str] | None = None
-
-
-@json_schema_type
-class OpenAIResponseInputToolMCP(BaseModel):
-    type: Literal["mcp"] = "mcp"
-    server_label: str
-    server_url: str
-    headers: dict[str, Any] | None = None
-
-    require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
-    allowed_tools: list[str] | AllowedToolsFilter | None = None
-
-
 OpenAIResponseInputTool = Annotated[
-    OpenAIResponseInputToolWebSearch
-    | OpenAIResponseInputToolFileSearch
-    | OpenAIResponseInputToolFunction
-    | OpenAIResponseInputToolMCP,
+    OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")


-class ListOpenAIResponseInputItem(BaseModel):
+class OpenAIResponseInputItemList(BaseModel):
     data: list[OpenAIResponseInput]
     object: Literal["list"] = "list"

-
-@json_schema_type
-class OpenAIResponseObjectWithInput(OpenAIResponseObject):
-    input: list[OpenAIResponseInput]
-
-
-@json_schema_type
-class ListOpenAIResponseObject(BaseModel):
-    data: list[OpenAIResponseObjectWithInput]
-    has_more: bool
-    first_id: str
-    last_id: str
-    object: Literal["list"] = "list"
llama_stack/apis/common/deployment_types.py (new file, 30 lines)

@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any
+
+from pydantic import BaseModel
+
+from llama_stack.apis.common.content_types import URL
+from llama_stack.schema_utils import json_schema_type
+
+
+@json_schema_type
+class RestAPIMethod(Enum):
+    GET = "GET"
+    POST = "POST"
+    PUT = "PUT"
+    DELETE = "DELETE"
+
+
+@json_schema_type
+class RestAPIExecutionConfig(BaseModel):
+    url: URL
+    method: RestAPIMethod
+    params: dict[str, Any] | None = None
+    headers: dict[str, Any] | None = None
+    body: dict[str, Any] | None = None
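A brief sketch of instantiating the restored config model follows; the endpoint, header, and body values are illustrative assumptions.

```python
# Hedged sketch: build a RestAPIExecutionConfig as defined in the new file above.
# The URL, header, and body values are invented for illustration.
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig, RestAPIMethod

config = RestAPIExecutionConfig(
    url=URL(uri="https://example.com/api/v1/predict"),
    method=RestAPIMethod.POST,
    headers={"Authorization": "Bearer <token>"},
    body={"prompt": "hello"},
)
print(config.method.value)  # "POST"
```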
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
 from typing import Any

 from pydantic import BaseModel

@@ -12,11 +11,6 @@ from pydantic import BaseModel
 from llama_stack.schema_utils import json_schema_type


-class Order(Enum):
-    asc = "asc"
-    desc = "desc"
-
-
 @json_schema_type
 class PaginatedResponse(BaseModel):
     """A generic paginated response that follows a simple format.
@@ -19,7 +19,6 @@ from pydantic import BaseModel, Field, field_validator
 from typing_extensions import TypedDict

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
-from llama_stack.apis.common.responses import Order
 from llama_stack.apis.models import Model
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (
@@ -783,48 +782,6 @@ class OpenAICompletion(BaseModel):
     object: Literal["text_completion"] = "text_completion"


-@json_schema_type
-class OpenAIEmbeddingData(BaseModel):
-    """A single embedding data object from an OpenAI-compatible embeddings response.
-
-    :param object: The object type, which will be "embedding"
-    :param embedding: The embedding vector as a list of floats (when encoding_format="float") or as a base64-encoded string (when encoding_format="base64")
-    :param index: The index of the embedding in the input list
-    """
-
-    object: Literal["embedding"] = "embedding"
-    embedding: list[float] | str
-    index: int
-
-
-@json_schema_type
-class OpenAIEmbeddingUsage(BaseModel):
-    """Usage information for an OpenAI-compatible embeddings response.
-
-    :param prompt_tokens: The number of tokens in the input
-    :param total_tokens: The total number of tokens used
-    """
-
-    prompt_tokens: int
-    total_tokens: int
-
-
-@json_schema_type
-class OpenAIEmbeddingsResponse(BaseModel):
-    """Response from an OpenAI-compatible embeddings request.
-
-    :param object: The object type, which will be "list"
-    :param data: List of embedding data objects
-    :param model: The model that was used to generate the embeddings
-    :param usage: Usage information
-    """
-
-    object: Literal["list"] = "list"
-    data: list[OpenAIEmbeddingData]
-    model: str
-    usage: OpenAIEmbeddingUsage
-
-
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...
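For reference, the embedding models removed above describe an OpenAI-compatible embeddings response; a minimal sketch of that shape, with invented vector values, is shown below.

```python
# Hedged sketch: the response shape described by the removed OpenAIEmbeddingData /
# OpenAIEmbeddingUsage / OpenAIEmbeddingsResponse models. All values are invented.
response = {
    "object": "list",
    "data": [
        {"object": "embedding", "embedding": [0.12, -0.03, 0.98], "index": 0},
    ],
    "model": "all-MiniLM-L6-v2",  # hypothetical embedding model id
    "usage": {"prompt_tokens": 3, "total_tokens": 3},
}
print(len(response["data"][0]["embedding"]))  # 3
```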
@@ -863,27 +820,15 @@ class BatchChatCompletionResponse(BaseModel):
     batch: list[ChatCompletionResponse]

-
-class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
-    input_messages: list[OpenAIMessageParam]
-
-
-@json_schema_type
-class ListOpenAIChatCompletionResponse(BaseModel):
-    data: list[OpenAICompletionWithInputMessages]
-    has_more: bool
-    first_id: str
-    last_id: str
-    object: Literal["list"] = "list"
-

 @runtime_checkable
 @trace_protocol
-class InferenceProvider(Protocol):
-    """
-    This protocol defines the interface that should be implemented by all inference providers.
-    """
-
-    API_NAMESPACE: str = "Inference"
+class Inference(Protocol):
+    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
+
+    This API provides the raw interface to the underlying models. Two kinds of models are supported:
+    - LLM models: these models generate "raw" and "chat" (conversational) completions.
+    - Embedding models: these models generate embeddings to be used for semantic search.
+    """

     model_store: ModelStore | None = None
@@ -1117,59 +1062,3 @@ class InferenceProvider(Protocol):
         :returns: An OpenAIChatCompletion.
         """
         ...
-
-    @webmethod(route="/openai/v1/embeddings", method="POST")
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        """Generate OpenAI-compatible embeddings for the given input using the specified model.
-
-        :param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
-        :param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
-        :param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
-        :param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
-        :param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
-        :returns: An OpenAIEmbeddingsResponse containing the embeddings.
-        """
-        ...
-
-
-class Inference(InferenceProvider):
-    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
-
-    This API provides the raw interface to the underlying models. Two kinds of models are supported:
-    - LLM models: these models generate "raw" and "chat" (conversational) completions.
-    - Embedding models: these models generate embeddings to be used for semantic search.
-    """
-
-    @webmethod(route="/openai/v1/chat/completions", method="GET")
-    async def list_chat_completions(
-        self,
-        after: str | None = None,
-        limit: int | None = 20,
-        model: str | None = None,
-        order: Order | None = Order.desc,
-    ) -> ListOpenAIChatCompletionResponse:
-        """List all chat completions.
-
-        :param after: The ID of the last chat completion to return.
-        :param limit: The maximum number of chat completions to return.
-        :param model: The model to filter by.
-        :param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
-        :returns: A ListOpenAIChatCompletionResponse.
-        """
-        raise NotImplementedError("List chat completions is not implemented")
-
-    @webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
-        """Describe a chat completion by its ID.
-
-        :param completion_id: ID of the chat completion.
-        :returns: A OpenAICompletionWithInputMessages.
-        """
-        raise NotImplementedError("Get chat completion is not implemented")
@@ -76,7 +76,6 @@ class RAGQueryConfig(BaseModel):
     :param chunk_template: Template for formatting each retrieved chunk in the context.
         Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
         Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
-    :param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector".
     """

     # This config defines how a query is generated using the messages

@@ -85,7 +84,6 @@ class RAGQueryConfig(BaseModel):
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
     chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
-    mode: str | None = None

     @field_validator("chunk_template")
     def validate_chunk_template(cls, v: str) -> str:
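A short sketch of building a query config from the fields kept in this hunk; the import path follows the provider docs earlier in this diff, and the argument values are illustrative.

```python
# Hedged sketch: construct a RAGQueryConfig using the fields shown above.
# The overrides simply restate the defaults for illustration.
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig

query_config = RAGQueryConfig(
    max_tokens_in_context=4096,
    max_chunks=5,
    chunk_template="Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
)
print(query_config.max_chunks)  # 5
```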
@@ -27,10 +27,18 @@ class ToolParameter(BaseModel):
     default: Any | None = None


+@json_schema_type
+class ToolHost(Enum):
+    distribution = "distribution"
+    client = "client"
+    model_context_protocol = "model_context_protocol"
+
+
 @json_schema_type
 class Tool(Resource):
     type: Literal[ResourceType.tool] = ResourceType.tool
     toolgroup_id: str
+    tool_host: ToolHost
     description: str
     parameters: list[ToolParameter]
     metadata: dict[str, Any] | None = None

@@ -68,8 +76,8 @@ class ToolInvocationResult(BaseModel):


 class ToolStore(Protocol):
-    async def get_tool(self, tool_name: str) -> Tool: ...
-    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
+    def get_tool(self, tool_name: str) -> Tool: ...
+    def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...


 class ListToolGroupsResponse(BaseModel):
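A tiny sketch of the `ToolHost` enum restored above, which also backs the new `tool_host` field on `Tool`; the import path is an assumption based on the module being patched.

```python
# Hedged sketch: round-trip a ToolHost value as defined in the hunk above.
# The import path is assumed, not taken from this diff.
from llama_stack.apis.tools import ToolHost

host = ToolHost("model_context_protocol")
print(host is ToolHost.model_context_protocol)  # True
```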
@@ -19,16 +19,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod


 class Chunk(BaseModel):
-    """
-    A chunk of content that can be inserted into a vector database.
-    :param content: The content of the chunk, which can be interleaved text, images, or other types.
-    :param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
-    :param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
-    """
-
     content: InterleavedContent
     metadata: dict[str, Any] = Field(default_factory=dict)
-    embedding: list[float] | None = None


 @json_schema_type

@@ -58,10 +50,7 @@ class VectorIO(Protocol):
         """Insert chunks into a vector database.

         :param vector_db_id: The identifier of the vector database to insert the chunks into.
-        :param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
-            `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
-            If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
-            If `embedding` is not provided, it will be computed later.
+        :param chunks: The chunks to insert.
         :param ttl_seconds: The time to live of the chunks.
         """
         ...
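A brief sketch of building chunks per the `Chunk` model above and handing them to `insert_chunks`; the vector DB id, the contents, and the import path are assumptions.

```python
# Hedged sketch: chunks shaped like the Chunk model above, passed to an
# implementation of VectorIO.insert_chunks (vector_db_id, chunks, ttl_seconds).
from llama_stack.apis.vector_io import Chunk

chunks = [
    Chunk(content="TorchTune is a PyTorch fine-tuning library.", metadata={"document_id": "doc-1"}),
    Chunk(content="Llama Stack exposes a vector_io API.", metadata={"document_id": "doc-2"}),
]

# With an initialized provider or client implementing the VectorIO protocol:
# await vector_io.insert_chunks(vector_db_id="my-db", chunks=chunks, ttl_seconds=3600)
```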
@@ -9,7 +9,6 @@ import asyncio
 import json
 import os
 import shutil
-import sys
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from functools import partial

@@ -378,15 +377,14 @@ def _meta_download(
     downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
     asyncio.run(downloader.download_all(tasks))

-    cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
+    cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
     cprint(
         f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
-        file=sys.stderr,
+        "white",
     )
     cprint(
         f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
-        color="yellow",
-        file=sys.stderr,
+        "yellow",
     )

@@ -12,7 +12,6 @@ import shutil
 import sys
 import textwrap
 from functools import lru_cache
-from importlib.abc import Traversable
 from pathlib import Path

 import yaml

@@ -79,7 +78,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         cprint(
             f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
             color="red",
-            file=sys.stderr,
         )
         sys.exit(1)
     build_config = available_templates[args.template]

@@ -89,7 +87,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         cprint(
             f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
             color="red",
-            file=sys.stderr,
         )
         sys.exit(1)
     elif args.providers:

@@ -99,7 +96,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
                 color="red",
-                file=sys.stderr,
             )
             sys.exit(1)
         api, provider = api_provider.split("=")

@@ -108,7 +104,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 f"{api} is not a valid API.",
                 color="red",
-                file=sys.stderr,
             )
             sys.exit(1)
         if provider in providers_for_api:

@@ -117,7 +112,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 f"{provider} is not a valid provider for the {api} API.",
                 color="red",
-                file=sys.stderr,
             )
             sys.exit(1)
     distribution_spec = DistributionSpec(

@@ -128,7 +122,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 f"Please specify a image-type (container | conda | venv) for {args.template}",
                 color="red",
-                file=sys.stderr,
             )
             sys.exit(1)

@@ -157,14 +150,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 cprint(
                     f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
                     color="yellow",
-                    file=sys.stderr,
                 )
                 image_name = f"llamastack-{name}"
             else:
                 cprint(
                     f"Using conda environment {image_name}",
                     color="green",
-                    file=sys.stderr,
                 )
         else:
             image_name = f"llamastack-{name}"

@@ -177,10 +168,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 """,
             ),
             color="green",
-            file=sys.stderr,
         )

-        cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
+        print("Tip: use <TAB> to see options for the providers.\n")

         providers = dict()
         for api, providers_for_api in get_provider_registry().items():

@@ -216,13 +206,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 contents = yaml.safe_load(f)
                 contents = replace_env_vars(contents)
                 build_config = BuildConfig(**contents)
-                if args.image_type:
-                    build_config.image_type = args.image_type
         except Exception as e:
             cprint(
                 f"Could not parse config file {args.config}: {e}",
                 color="red",
-                file=sys.stderr,
             )
             sys.exit(1)

@@ -249,27 +236,25 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         cprint(
             f"Error building stack: {exc}",
             color="red",
-            file=sys.stderr,
         )
-        cprint("Stack trace:", color="red", file=sys.stderr)
+        cprint("Stack trace:", color="red")
         traceback.print_exc()
         sys.exit(1)

     if run_config is None:
         cprint(
             "Run config path is empty",
             color="red",
-            file=sys.stderr,
         )
         sys.exit(1)

     if args.run:
+        run_config = Path(run_config)
         config_dict = yaml.safe_load(run_config.read_text())
         config = parse_and_maybe_upgrade_config(config_dict)
-        if config.external_providers_dir and not config.external_providers_dir.exists():
-            config.external_providers_dir.mkdir(exist_ok=True)
+        if not os.path.exists(str(config.external_providers_dir)):
+            os.makedirs(str(config.external_providers_dir), exist_ok=True)
         run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
-        run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
+        run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
         run_command(run_args)

@@ -277,7 +262,7 @@ def _generate_run_config(
     build_config: BuildConfig,
     build_dir: Path,
     image_name: str,
-) -> Path:
+) -> str:
     """
     Generate a run.yaml template file for user to edit from a build.yaml file
     """

@@ -317,7 +302,6 @@ def _generate_run_config(
             cprint(
                 f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
                 color="yellow",
-                file=sys.stderr,
             )
             # Set config_type to None to avoid UnboundLocalError
             config_type = None

@@ -345,7 +329,10 @@ def _generate_run_config(
     # For non-container builds, the run.yaml is generated at the very end of the build process so it
     # makes sense to display this message
     if build_config.image_type != LlamaStackImageType.CONTAINER.value:
-        cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
+        cprint(
+            f"You can now run your stack with `llama stack run {run_config_file}`",
+            color="green",
+        )
     return run_config_file
@ -354,7 +341,7 @@ def _run_stack_build_command_from_build_config(
|
||||||
image_name: str | None = None,
|
image_name: str | None = None,
|
||||||
template_name: str | None = None,
|
template_name: str | None = None,
|
||||||
config_path: str | None = None,
|
config_path: str | None = None,
|
||||||
) -> Path | Traversable:
|
) -> str:
|
||||||
image_name = image_name or build_config.image_name
|
image_name = image_name or build_config.image_name
|
||||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||||
if template_name:
|
if template_name:
|
||||||
|
@ -383,7 +370,7 @@ def _run_stack_build_command_from_build_config(
|
||||||
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
|
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
|
||||||
# Only do this if we're building a container image and we're not using a template
|
# Only do this if we're building a container image and we're not using a template
|
||||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
|
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
|
||||||
cprint("Generating run.yaml file", color="yellow", file=sys.stderr)
|
cprint("Generating run.yaml file", color="green")
|
||||||
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
||||||
|
|
||||||
with open(build_file_path, "w") as f:
|
with open(build_file_path, "w") as f:
|
||||||
|
@ -407,13 +394,11 @@ def _run_stack_build_command_from_build_config(
|
||||||
run_config_file = build_dir / f"{template_name}-run.yaml"
|
run_config_file = build_dir / f"{template_name}-run.yaml"
|
||||||
shutil.copy(path, run_config_file)
|
shutil.copy(path, run_config_file)
|
||||||
|
|
||||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
cprint("Build Successful!", color="green")
|
||||||
cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr)
|
cprint("You can find the newly-built template here: " + colored(template_path, "light_blue"))
|
||||||
cprint(
|
cprint(
|
||||||
"You can run the new Llama Stack distro via: "
|
"You can run the new Llama Stack distro via: "
|
||||||
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"),
|
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue")
|
||||||
color="green",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
return template_path
|
return template_path
|
||||||
else:
|
else:
|
||||||
|
|
|
@@ -49,7 +49,7 @@ class StackBuild(Subcommand):
             type=str,
             help="Image Type to use for the build. If not specified, will use the image type from the template config.",
             choices=[e.value for e in ImageType],
-            default=None,  # no default so we can detect if a user specified --image-type and override image_type in the config
+            default=ImageType.CONDA.value,
         )

         self.parser.add_argument(

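For context, the `default=None` being dropped above is the usual argparse idiom for detecting whether a flag was passed explicitly; a minimal, self-contained sketch of that pattern (the names below are illustrative, not the real CLI):

import argparse

parser = argparse.ArgumentParser()
# default=None lets us tell an explicit --image-type apart from "not given".
parser.add_argument("--image-type", choices=["conda", "container", "venv"], default=None)
args = parser.parse_args(["--image-type", "venv"])

image_type = "conda"  # pretend this came from the build config file
if args.image_type is not None:
    # Only override the config value when the user actually passed the flag.
    image_type = args.image_type
print(image_type)  # venv

With a concrete default such as ImageType.CONDA.value, the CLI can no longer distinguish "user asked for conda" from "user said nothing", which is why the config-override branch is also removed in the hunk further up.
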
@@ -1,56 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import argparse
-from pathlib import Path
-
-from llama_stack.cli.subcommand import Subcommand
-from llama_stack.cli.table import print_table
-
-
-class StackListBuilds(Subcommand):
-    """List built stacks in .llama/distributions directory"""
-
-    def __init__(self, subparsers: argparse._SubParsersAction):
-        super().__init__()
-        self.parser = subparsers.add_parser(
-            "list",
-            prog="llama stack list",
-            description="list the build stacks",
-            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-        )
-        self._add_arguments()
-        self.parser.set_defaults(func=self._list_stack_command)
-
-    def _get_distribution_dirs(self) -> dict[str, Path]:
-        """Return a dictionary of distribution names and their paths"""
-        distributions = {}
-        dist_dir = Path.home() / ".llama" / "distributions"
-
-        if dist_dir.exists():
-            for stack_dir in dist_dir.iterdir():
-                if stack_dir.is_dir():
-                    distributions[stack_dir.name] = stack_dir
-        return distributions
-
-    def _list_stack_command(self, args: argparse.Namespace) -> None:
-        distributions = self._get_distribution_dirs()
-
-        if not distributions:
-            print("No stacks found in ~/.llama/distributions")
-            return
-
-        headers = ["Stack Name", "Path"]
-        headers.extend(["Build Config", "Run Config"])
-        rows = []
-        for name, path in distributions.items():
-            row = [name, str(path)]
-            # Check for build and run config files
-            build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
-            run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
-            row.extend([build_config, run_config])
-            rows.append(row)
-        print_table(rows, headers, separate_rows=True)

@@ -1,115 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import argparse
-import shutil
-import sys
-from pathlib import Path
-
-from termcolor import cprint
-
-from llama_stack.cli.subcommand import Subcommand
-from llama_stack.cli.table import print_table
-
-
-class StackRemove(Subcommand):
-    """Remove the build stack"""
-
-    def __init__(self, subparsers: argparse._SubParsersAction):
-        super().__init__()
-        self.parser = subparsers.add_parser(
-            "rm",
-            prog="llama stack rm",
-            description="Remove the build stack",
-            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-        )
-        self._add_arguments()
-        self.parser.set_defaults(func=self._remove_stack_build_command)
-
-    def _add_arguments(self) -> None:
-        self.parser.add_argument(
-            "name",
-            type=str,
-            nargs="?",
-            help="Name of the stack to delete",
-        )
-        self.parser.add_argument(
-            "--all",
-            "-a",
-            action="store_true",
-            help="Delete all stacks (use with caution)",
-        )
-
-    def _get_distribution_dirs(self) -> dict[str, Path]:
-        """Return a dictionary of distribution names and their paths"""
-        distributions = {}
-        dist_dir = Path.home() / ".llama" / "distributions"
-
-        if dist_dir.exists():
-            for stack_dir in dist_dir.iterdir():
-                if stack_dir.is_dir():
-                    distributions[stack_dir.name] = stack_dir
-        return distributions
-
-    def _list_stacks(self) -> None:
-        """Display available stacks in a table"""
-        distributions = self._get_distribution_dirs()
-        if not distributions:
-            cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr)
-            sys.exit(1)
-
-        headers = ["Stack Name", "Path"]
-        rows = [[name, str(path)] for name, path in distributions.items()]
-        print_table(rows, headers, separate_rows=True)
-
-    def _remove_stack_build_command(self, args: argparse.Namespace) -> None:
-        distributions = self._get_distribution_dirs()
-
-        if args.all:
-            confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
-            if confirm != "yes-i-really-want":
-                cprint("Deletion cancelled.", color="green", file=sys.stderr)
-                return
-
-            for name, path in distributions.items():
-                try:
-                    shutil.rmtree(path)
-                    cprint(f"Deleted stack: {name}", color="green", file=sys.stderr)
-                except Exception as e:
-                    cprint(
-                        f"Failed to delete stack {name}: {e}",
-                        color="red",
-                        file=sys.stderr,
-                    )
-                    sys.exit(1)
-
-        if not args.name:
-            self._list_stacks()
-            if not args.name:
-                return
-
-        if args.name not in distributions:
-            self._list_stacks()
-            cprint(
-                f"Stack not found: {args.name}",
-                color="red",
-                file=sys.stderr,
-            )
-            sys.exit(1)
-
-        stack_path = distributions[args.name]
-
-        confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
-        if confirm != "y":
-            cprint("Deletion cancelled.", color="green", file=sys.stderr)
-            return
-
-        try:
-            shutil.rmtree(stack_path)
-            cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr)
-        except Exception as e:
-            cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr)
-            sys.exit(1)

@@ -6,7 +6,6 @@

 import argparse
 import os
-import subprocess
 from pathlib import Path

 from llama_stack.cli.stack.utils import ImageType
@@ -61,11 +60,6 @@ class StackRun(Subcommand):
             help="Image Type used during the build. This can be either conda or container or venv.",
             choices=[e.value for e in ImageType],
         )
-        self.parser.add_argument(
-            "--enable-ui",
-            action="store_true",
-            help="Start the UI server",
-        )

         # If neither image type nor image name is provided, but at the same time
         # the current environment has conda breadcrumbs, then assume what the user
@@ -89,8 +83,6 @@ class StackRun(Subcommand):
        from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
        from llama_stack.distribution.utils.exec import formulate_run_args, run_command

-        if args.enable_ui:
-            self._start_ui_development_server(args.port)
        image_type, image_name = self._get_image_type_and_name(args)

        # Check if config is required based on image type
@@ -178,44 +170,3 @@ class StackRun(Subcommand):
                 run_args.extend(["--env", f"{key}={value}"])

         run_command(run_args)
-
-    def _start_ui_development_server(self, stack_server_port: int):
-        logger.info("Attempting to start UI development server...")
-        # Check if npm is available
-        npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
-        if npm_check.returncode != 0:
-            logger.warning(
-                f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
-            )
-            return
-
-        ui_dir = REPO_ROOT / "llama_stack" / "ui"
-        logs_dir = Path("~/.llama/ui/logs").expanduser()
-        try:
-            # Create logs directory if it doesn't exist
-            logs_dir.mkdir(parents=True, exist_ok=True)
-
-            ui_stdout_log_path = logs_dir / "stdout.log"
-            ui_stderr_log_path = logs_dir / "stderr.log"
-
-            # Open log files in append mode
-            stdout_log_file = open(ui_stdout_log_path, "a")
-            stderr_log_file = open(ui_stderr_log_path, "a")
-
-            process = subprocess.Popen(
-                ["npm", "run", "dev"],
-                cwd=str(ui_dir),
-                stdout=stdout_log_file,
-                stderr=stderr_log_file,
-                env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
-            )
-            logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
-            logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
-            logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
-
-        except FileNotFoundError:
-            logger.error(
-                "Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
-            )
-        except Exception as e:
-            logger.error(f"Failed to start UI development server in {ui_dir}: {e}")

@@ -7,14 +7,12 @@
 import argparse
 from importlib.metadata import version

-from llama_stack.cli.stack.list_stacks import StackListBuilds
 from llama_stack.cli.stack.utils import print_subcommand_description
 from llama_stack.cli.subcommand import Subcommand

 from .build import StackBuild
 from .list_apis import StackListApis
 from .list_providers import StackListProviders
-from .remove import StackRemove
 from .run import StackRun


@@ -43,6 +41,5 @@ class StackParser(Subcommand):
         StackListApis.create(subparsers)
         StackListProviders.create(subparsers)
         StackRun.create(subparsers)
-        StackRemove.create(subparsers)
-        StackListBuilds.create(subparsers)
         print_subcommand_description(self.parser, subparsers)

@@ -6,7 +6,6 @@

 import importlib.resources
 import logging
-import sys
 from pathlib import Path

 from pydantic import BaseModel
@@ -44,20 +43,8 @@ def get_provider_dependencies(
     # Extract providers based on config type
     if isinstance(config, DistributionTemplate):
         providers = config.providers
-
-        # TODO: This is a hack to get the dependencies for internal APIs into build
-        # We should have a better way to do this by formalizing the concept of "internal" APIs
-        # and providers, with a way to specify dependencies for them.
-        run_configs = config.run_configs
-        additional_pip_packages: list[str] = []
-        if run_configs:
-            for run_config in run_configs.values():
-                run_config_ = run_config.run_config(name="", providers={}, container_image=None)
-                if run_config_.inference_store:
-                    additional_pip_packages.extend(run_config_.inference_store.pip_packages)
     elif isinstance(config, BuildConfig):
         providers = config.distribution_spec.providers
-        additional_pip_packages = config.additional_pip_packages
     deps = []
     registry = get_provider_registry(config)
     for api_str, provider_or_providers in providers.items():
@@ -85,9 +72,6 @@ def get_provider_dependencies(
         else:
             normal_deps.append(package)

-    if additional_pip_packages:
-        normal_deps.extend(additional_pip_packages)
-
     return list(set(normal_deps)), list(set(special_deps))


@@ -96,11 +80,10 @@ def print_pip_install_help(config: BuildConfig):

     cprint(
         f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
-        color="yellow",
-        file=sys.stderr,
+        "yellow",
     )
     for special_dep in special_deps:
-        cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
+        cprint(f"uv pip install {special_dep}", "yellow")
     print()

@@ -6,7 +6,6 @@

 import inspect
 import json
-import sys
 from collections.abc import AsyncIterator
 from enum import Enum
 from typing import Any, Union, get_args, get_origin
@@ -97,13 +96,13 @@ def create_api_client_class(protocol) -> type:
                        try:
                            data = json.loads(data)
                            if "error" in data:
-                                cprint(data, color="red", file=sys.stderr)
+                                cprint(data, "red")
                                continue

                            yield parse_obj_as(return_type, data)
                        except Exception as e:
-                            cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
-                            cprint(data, color="red", file=sys.stderr)
+                            print(f"Error with parsing or validation: {e}")
+                            print(data)

        def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
            webmethod, sig = self.routes[method_name]

@@ -25,8 +25,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.datatypes import Api, ProviderSpec
-from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig

 LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
 LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@@ -221,38 +220,21 @@ class LoggingConfig(BaseModel):
 class AuthProviderType(str, Enum):
     """Supported authentication provider types."""

-    OAUTH2_TOKEN = "oauth2_token"
+    KUBERNETES = "kubernetes"
     CUSTOM = "custom"


 class AuthenticationConfig(BaseModel):
     provider_type: AuthProviderType = Field(
         ...,
-        description="Type of authentication provider",
+        description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
     )
-    config: dict[str, Any] = Field(
+    config: dict[str, str] = Field(
         ...,
         description="Provider-specific configuration",
     )


-class AuthenticationRequiredError(Exception):
-    pass
-
-
-class QuotaPeriod(str, Enum):
-    DAY = "day"
-
-
-class QuotaConfig(BaseModel):
-    kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
-    anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
-    authenticated_max_requests: int = Field(
-        default=1000, description="Max requests for authenticated clients per period"
-    )
-    period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
-
-
 class ServerConfig(BaseModel):
     port: int = Field(
         default=8321,
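As a rough illustration of the `+` side of the hunk above, an auth block would be constructed roughly like this. This is a hedged sketch: the import path is assumed to match where these models live upstream, and the kubernetes config keys are made up for the example.

from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType

# Hypothetical values; the real kubernetes provider may expect different keys.
auth = AuthenticationConfig(
    provider_type=AuthProviderType.KUBERNETES,
    config={"api_server_url": "https://kubernetes.default.svc"},
)
print(auth.model_dump())

Note that on the `+` side the config field is dict[str, str], so every provider setting has to be a plain string; the `-` side's dict[str, Any] allowed nested structures.
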
@@ -280,10 +262,6 @@ class ServerConfig(BaseModel):
         default=None,
         description="The host the server should listen on",
     )
-    quota: QuotaConfig | None = Field(
-        default=None,
-        description="Per client quota request configuration",
-    )


 class StackRunConfig(BaseModel):
@@ -319,13 +297,6 @@ Configuration for the persistence store used by the distribution registry. If no
 a default SQLite store will be used.""",
     )

-    inference_store: SqlStoreConfig | None = Field(
-        default=None,
-        description="""
-Configuration for the persistence store used by the inference API. If not specified,
-a default SQLite store will be used.""",
-    )
-
     # registry of "resources" in the distribution
     models: list[ModelInput] = Field(default_factory=list)
     shields: list[ShieldInput] = Field(default_factory=list)
@@ -369,21 +340,8 @@ class BuildConfig(BaseModel):
         default=None,
         description="Name of the distribution to build",
     )
-    external_providers_dir: Path | None = Field(
+    external_providers_dir: str | None = Field(
         default=None,
         description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
         "pip_packages MUST contain the provider package name.",
     )
-    additional_pip_packages: list[str] = Field(
-        default_factory=list,
-        description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
-    )
-
-    @field_validator("external_providers_dir")
-    @classmethod
-    def validate_external_providers_dir(cls, v):
-        if v is None:
-            return None
-        if isinstance(v, str):
-            return Path(v)
-        return v
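The validator deleted above is a standard pydantic v2 pattern for accepting a plain string (e.g. from build YAML) while keeping a Path-typed field. A small self-contained sketch of the same idea, with a made-up model name:

from pathlib import Path

from pydantic import BaseModel, field_validator


class ExampleBuildConfig(BaseModel):
    external_providers_dir: Path | None = None

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):
        # Accept plain strings as well as Path values.
        if isinstance(v, str):
            return Path(v)
        return v


cfg = ExampleBuildConfig(external_providers_dir="/tmp/providers")
print(type(cfg.external_providers_dir))  # a pathlib.Path, not a str

Dropping the field back to str | None pushes that conversion onto every caller, which is why the build command above now wraps the value in os.path / os.makedirs calls instead of using Path methods directly.
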
@@ -16,7 +16,7 @@ from llama_stack.apis.inspect import (
     VersionInfo,
 )
 from llama_stack.distribution.datatypes import StackRunConfig
-from llama_stack.distribution.server.routes import get_all_api_routes
+from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.providers.datatypes import HealthStatus


@@ -31,7 +31,7 @@ async def get_provider_impl(config, deps):


 class DistributionInspectImpl(Inspect):
-    def __init__(self, config: DistributionInspectConfig, deps):
+    def __init__(self, config, deps):
         self.config = config
         self.deps = deps

@@ -39,31 +39,17 @@ class DistributionInspectImpl(Inspect):
         pass

     async def list_routes(self) -> ListRoutesResponse:
-        run_config: StackRunConfig = self.config.run_config
+        run_config = self.config.run_config

         ret = []
-        all_endpoints = get_all_api_routes()
+        all_endpoints = get_all_api_endpoints()
         for api, endpoints in all_endpoints.items():
-            # Always include provider and inspect APIs, filter others based on run config
-            if api.value in ["providers", "inspect"]:
-                ret.extend(
-                    [
-                        RouteInfo(
-                            route=e.path,
-                            method=next(iter([m for m in e.methods if m != "HEAD"])),
-                            provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
-                        )
-                        for e in endpoints
-                    ]
-                )
-            else:
            providers = run_config.providers.get(api.value, [])
-            if providers:  # Only process if there are providers for this API
            ret.extend(
                [
                    RouteInfo(
-                        route=e.path,
-                        method=next(iter([m for m in e.methods if m != "HEAD"])),
+                        route=e.route,
+                        method=e.method,
                        provider_types=[p.provider_type for p in providers],
                    )
                    for e in endpoints

@@ -9,7 +9,6 @@ import inspect
 import json
 import logging
 import os
-import sys
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
@@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import ProviderRegistry
-from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
+from llama_stack.distribution.server.endpoints import (
+    find_matching_endpoint,
+    initialize_endpoint_impls,
+)
 from llama_stack.distribution.stack import (
     construct_stack,
     get_stack_run_config_from_template,
@@ -205,14 +207,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

     async def initialize(self) -> bool:
         try:
-            self.route_impls = None
+            self.endpoint_impls = None
             self.impls = await construct_stack(self.config, self.custom_provider_registry)
         except ModuleNotFoundError as _e:
-            cprint(_e.msg, color="red", file=sys.stderr)
+            cprint(_e.msg, "red")
             cprint(
                 "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
-                color="yellow",
-                file=sys.stderr,
+                "yellow",
             )
             if self.config_path_or_template_name.endswith(".yaml"):
                 # Convert Provider objects to their types
@@ -225,7 +226,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                     distribution_spec=DistributionSpec(
                         providers=provider_types,
                     ),
-                    external_providers_dir=self.config.external_providers_dir,
                 )
                 print_pip_install_help(build_config)
             else:
@@ -233,12 +233,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 cprint(
                     f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
                     "yellow",
-                    file=sys.stderr,
-                )
-                cprint(
-                    "Please check your internet connection and try again.",
-                    "red",
-                    file=sys.stderr,
                 )
             raise _e

@@ -251,7 +245,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             safe_config = redact_sensitive_fields(self.config.model_dump())
             console.print(yaml.dump(safe_config, indent=2))

-        self.route_impls = initialize_route_impls(self.impls)
+        self.endpoint_impls = initialize_endpoint_impls(self.impls)
         return True

     async def request(
@@ -262,14 +256,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         stream=False,
         stream_cls=None,
     ):
-        if not self.route_impls:
+        if not self.endpoint_impls:
             raise ValueError("Client not initialized")

         # Create headers with provider data if available
-        headers = options.headers or {}
+        headers = {}
         if self.provider_data:
-            keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
-            if all(key not in headers for key in keys):
                headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)

         # Use context manager for provider data
@@ -293,14 +285,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         cast_to: Any,
         options: Any,
     ):
-        if self.route_impls is None:
-            raise ValueError("Client not initialized")
-
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}

-        matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
+        matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
         body |= path_params
         body = self._convert_body(path, options.method, body)
         await start_trace(route, {"__location__": "library_client"})
@@ -342,13 +331,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         options: Any,
         stream_cls: Any,
     ):
-        if self.route_impls is None:
-            raise ValueError("Client not initialized")
-
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}
-        func, path_params, route = find_matching_route(options.method, path, self.route_impls)
+        func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
         body |= path_params

         body = self._convert_body(path, options.method, body)
@@ -400,10 +386,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not body:
             return {}

-        if self.route_impls is None:
-            raise ValueError("Client not initialized")
-
-        func, _, _ = find_matching_route(method, path, self.route_impls)
+        func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
         sig = inspect.signature(func)

         # Strip NOT_GIVENs to use the defaults in signature

@@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.files import Files
-from llama_stack.apis.inference import Inference, InferenceProvider
+from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
@@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
     RemoteProviderSpec,
     ScoringFunctionsProtocolPrivate,
     ShieldsProtocolPrivate,
-    ToolGroupsProtocolPrivate,
+    ToolsProtocolPrivate,
     VectorDBsProtocolPrivate,
 )

@@ -83,17 +83,10 @@ def api_protocol_map() -> dict[Api, Any]:
     }


-def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
-    return {
-        **api_protocol_map(),
-        Api.inference: InferenceProvider,
-    }
-
-
 def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
-        Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
+        Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
         Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
@@ -140,7 +133,7 @@ async def resolve_impls(

     sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

-    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
+    return await instantiate_providers(sorted_providers, router_apis, dist_registry)


 def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@@ -243,10 +236,7 @@ def sort_providers_by_deps(


 async def instantiate_providers(
-    sorted_providers: list[tuple[str, ProviderWithSpec]],
-    router_apis: set[Api],
-    dist_registry: DistributionRegistry,
-    run_config: StackRunConfig,
+    sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
 ) -> dict:
     """Instantiates providers asynchronously while managing dependencies."""
     impls: dict[Api, Any] = {}
@@ -261,7 +251,7 @@ async def instantiate_providers(
         if isinstance(provider.spec, RoutingTableProviderSpec):
             inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

-        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
+        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)

         if api_str.startswith("inner-"):
             inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@@ -311,8 +301,10 @@ async def instantiate_provider(
     deps: dict[Api, Any],
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
-    run_config: StackRunConfig,
 ):
+    protocols = api_protocol_map()
+    additional_protocols = additional_protocols_map()

     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
@@ -331,7 +323,7 @@ async def instantiate_provider(
         method = "get_auto_router_impl"

         config = None
-        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
+        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
     elif isinstance(provider_spec, RoutingTableProviderSpec):
         method = "get_routing_table_impl"

@@ -350,8 +342,6 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

-    protocols = api_protocol_map_for_compliance_check()
-    additional_protocols = additional_protocols_map()
     # TODO: check compliance for special tool groups
     # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])

@ -7,10 +7,18 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import RoutedProtocol
|
from llama_stack.distribution.datatypes import RoutedProtocol
|
||||||
from llama_stack.distribution.stack import StackRunConfig
|
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
|
||||||
|
from .routing_tables import (
|
||||||
|
BenchmarksRoutingTable,
|
||||||
|
DatasetsRoutingTable,
|
||||||
|
ModelsRoutingTable,
|
||||||
|
ScoringFunctionsRoutingTable,
|
||||||
|
ShieldsRoutingTable,
|
||||||
|
ToolGroupsRoutingTable,
|
||||||
|
VectorDBsRoutingTable,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_routing_table_impl(
|
async def get_routing_table_impl(
|
||||||
|
@ -19,14 +27,6 @@ async def get_routing_table_impl(
|
||||||
_deps,
|
_deps,
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
from ..routing_tables.benchmarks import BenchmarksRoutingTable
|
|
||||||
from ..routing_tables.datasets import DatasetsRoutingTable
|
|
||||||
from ..routing_tables.models import ModelsRoutingTable
|
|
||||||
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
|
||||||
from ..routing_tables.shields import ShieldsRoutingTable
|
|
||||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
|
||||||
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
|
|
||||||
|
|
||||||
api_to_tables = {
|
api_to_tables = {
|
||||||
"vector_dbs": VectorDBsRoutingTable,
|
"vector_dbs": VectorDBsRoutingTable,
|
||||||
"models": ModelsRoutingTable,
|
"models": ModelsRoutingTable,
|
||||||
|
@ -45,15 +45,16 @@ async def get_routing_table_impl(
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
||||||
async def get_auto_router_impl(
|
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
|
||||||
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
|
from .routers import (
|
||||||
) -> Any:
|
DatasetIORouter,
|
||||||
from .datasets import DatasetIORouter
|
EvalRouter,
|
||||||
from .eval_scoring import EvalRouter, ScoringRouter
|
InferenceRouter,
|
||||||
from .inference import InferenceRouter
|
SafetyRouter,
|
||||||
from .safety import SafetyRouter
|
ScoringRouter,
|
||||||
from .tool_runtime import ToolRuntimeRouter
|
ToolRuntimeRouter,
|
||||||
from .vector_io import VectorIORouter
|
VectorIORouter,
|
||||||
|
)
|
||||||
|
|
||||||
api_to_routers = {
|
api_to_routers = {
|
||||||
"vector_io": VectorIORouter,
|
"vector_io": VectorIORouter,
|
||||||
|
@ -75,12 +76,6 @@ async def get_auto_router_impl(
|
||||||
if dep_api in deps:
|
if dep_api in deps:
|
||||||
api_to_dep_impl[dep_name] = deps[dep_api]
|
api_to_dep_impl[dep_name] = deps[dep_api]
|
||||||
|
|
||||||
# TODO: move pass configs to routers instead
|
|
||||||
if api == Api.inference and run_config.inference_store:
|
|
||||||
inference_store = InferenceStore(run_config.inference_store)
|
|
||||||
await inference_store.initialize()
|
|
||||||
api_to_dep_impl["store"] = inference_store
|
|
||||||
|
|
||||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -1,71 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
|
||||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
|
||||||
from llama_stack.log import get_logger
|
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetIORouter(DatasetIO):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
routing_table: RoutingTable,
|
|
||||||
) -> None:
|
|
||||||
logger.debug("Initializing DatasetIORouter")
|
|
||||||
self.routing_table = routing_table
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
logger.debug("DatasetIORouter.initialize")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
logger.debug("DatasetIORouter.shutdown")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def register_dataset(
|
|
||||||
self,
|
|
||||||
purpose: DatasetPurpose,
|
|
||||||
source: DataSource,
|
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
dataset_id: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
logger.debug(
|
|
||||||
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
|
|
||||||
)
|
|
||||||
await self.routing_table.register_dataset(
|
|
||||||
purpose=purpose,
|
|
||||||
source=source,
|
|
||||||
metadata=metadata,
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def iterrows(
|
|
||||||
self,
|
|
||||||
dataset_id: str,
|
|
||||||
start_index: int | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
) -> PaginatedResponse:
|
|
||||||
logger.debug(
|
|
||||||
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
|
||||||
)
|
|
||||||
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
start_index=start_index,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
|
||||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
|
||||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
rows=rows,
|
|
||||||
)
|
|
|
@ -1,148 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
|
|
||||||
from llama_stack.apis.scoring import (
|
|
||||||
ScoreBatchResponse,
|
|
||||||
ScoreResponse,
|
|
||||||
Scoring,
|
|
||||||
ScoringFnParams,
|
|
||||||
)
|
|
||||||
from llama_stack.log import get_logger
|
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
|
||||||
|
|
||||||
|
|
||||||
class ScoringRouter(Scoring):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
routing_table: RoutingTable,
|
|
||||||
) -> None:
|
|
||||||
logger.debug("Initializing ScoringRouter")
|
|
||||||
self.routing_table = routing_table
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
logger.debug("ScoringRouter.initialize")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
logger.debug("ScoringRouter.shutdown")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def score_batch(
|
|
||||||
self,
|
|
||||||
dataset_id: str,
|
|
||||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
|
||||||
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
        res = {}
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                dataset_id=dataset_id,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None] = None,
    ) -> ScoreResponse:
        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                input_rows=input_rows,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        return ScoreResponse(results=res)


class EvalRouter(Eval):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing EvalRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("EvalRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("EvalRouter.shutdown")
        pass

    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
            benchmark_id=benchmark_id,
            benchmark_config=benchmark_config,
        )

    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: list[dict[str, Any]],
        scoring_functions: list[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
            benchmark_id=benchmark_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            benchmark_config=benchmark_config,
        )

    async def job_status(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> Job:
        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

    async def job_cancel(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> None:
        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
            benchmark_id,
            job_id,
        )

    async def job_result(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
            benchmark_id,
            job_id,
        )
@@ -14,9 +14,14 @@ from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
 from pydantic import Field, TypeAdapter
 
 from llama_stack.apis.common.content_types import (
+    URL,
     InterleavedContent,
     InterleavedContentItem,
 )
+from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import DatasetPurpose, DataSource
+from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -27,11 +32,8 @@ from llama_stack.apis.inference import (
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
-    ListOpenAIChatCompletionResponse,
     LogProbConfig,
     Message,
-    OpenAICompletionWithInputMessages,
-    Order,
     ResponseFormat,
     SamplingParams,
     StopReason,
@@ -45,23 +47,93 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.apis.safety import RunShieldResponse, Safety
+from llama_stack.apis.scoring import (
+    ScoreBatchResponse,
+    ScoreResponse,
+    Scoring,
+    ScoringFnParams,
+)
+from llama_stack.apis.shields import Shield
 from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
+from llama_stack.apis.tools import (
+    ListToolDefsResponse,
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
+    ToolRuntime,
+)
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
-from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
 from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
 logger = get_logger(name=__name__, category="core")
 
 
+class VectorIORouter(VectorIO):
+    """Routes to an provider based on the vector db identifier"""
+
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing VectorIORouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("VectorIORouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("VectorIORouter.shutdown")
+        pass
+
+    async def register_vector_db(
+        self,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: int | None = 384,
+        provider_id: str | None = None,
+        provider_vector_db_id: str | None = None,
+    ) -> None:
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+        await self.routing_table.register_vector_db(
+            vector_db_id,
+            embedding_model,
+            embedding_dimension,
+            provider_id,
+            provider_vector_db_id,
+        )
+
+    async def insert_chunks(
+        self,
+        vector_db_id: str,
+        chunks: list[Chunk],
+        ttl_seconds: int | None = None,
+    ) -> None:
+        logger.debug(
+            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
+        )
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
+
+    async def query_chunks(
+        self,
+        vector_db_id: str,
+        query: InterleavedContent,
+        params: dict[str, Any] | None = None,
+    ) -> QueryChunksResponse:
+        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
+
+
 class InferenceRouter(Inference):
     """Routes to an provider based on the model"""
 
@@ -69,12 +141,10 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
         telemetry: Telemetry | None = None,
-        store: InferenceStore | None = None,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
         self.telemetry = telemetry
-        self.store = store
         if self.telemetry:
             self.tokenizer = Tokenizer.get_instance()
             self.formatter = ChatFormat(self.tokenizer)
@@ -537,59 +607,9 @@ class InferenceRouter(Inference):
 
         provider = self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
-            response_stream = await provider.openai_chat_completion(**params)
-            if self.store:
-                return stream_and_store_openai_completion(response_stream, model, self.store, messages)
-            return response_stream
+            return await provider.openai_chat_completion(**params)
         else:
-            response = await self._nonstream_openai_chat_completion(provider, params)
-            if self.store:
-                await self.store.store_chat_completion(response, messages)
-            return response
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        logger.debug(
-            f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
-        )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
-        if model_obj.model_type != ModelType.embedding:
-            raise ValueError(f"Model '{model}' is not an embedding model")
-
-        params = dict(
-            model=model_obj.identifier,
-            input=input,
-            encoding_format=encoding_format,
-            dimensions=dimensions,
-            user=user,
-        )
-
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.openai_embeddings(**params)
-
-    async def list_chat_completions(
-        self,
-        after: str | None = None,
-        limit: int | None = 20,
-        model: str | None = None,
-        order: Order | None = Order.desc,
-    ) -> ListOpenAIChatCompletionResponse:
-        if self.store:
-            return await self.store.list_chat_completions(after, limit, model, order)
-        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
-
-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
-        if self.store:
-            return await self.store.get_chat_completion(completion_id)
-        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
+            return await self._nonstream_openai_chat_completion(provider, params)
 
     async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
         response = await provider.openai_chat_completion(**params)
@@ -622,3 +642,295 @@ class InferenceRouter(Inference):
                     status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
                 )
         return health_statuses
+
+
+class SafetyRouter(Safety):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing SafetyRouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("SafetyRouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("SafetyRouter.shutdown")
+        pass
+
+    async def register_shield(
+        self,
+        shield_id: str,
+        provider_shield_id: str | None = None,
+        provider_id: str | None = None,
+        params: dict[str, Any] | None = None,
+    ) -> Shield:
+        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
+        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+
+    async def run_shield(
+        self,
+        shield_id: str,
+        messages: list[Message],
+        params: dict[str, Any] = None,
+    ) -> RunShieldResponse:
+        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
+        return await self.routing_table.get_provider_impl(shield_id).run_shield(
+            shield_id=shield_id,
+            messages=messages,
+            params=params,
+        )
+
+
+class DatasetIORouter(DatasetIO):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing DatasetIORouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("DatasetIORouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("DatasetIORouter.shutdown")
+        pass
+
+    async def register_dataset(
+        self,
+        purpose: DatasetPurpose,
+        source: DataSource,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
+    ) -> None:
+        logger.debug(
+            f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
+        )
+        await self.routing_table.register_dataset(
+            purpose=purpose,
+            source=source,
+            metadata=metadata,
+            dataset_id=dataset_id,
+        )
+
+    async def iterrows(
+        self,
+        dataset_id: str,
+        start_index: int | None = None,
+        limit: int | None = None,
+    ) -> PaginatedResponse:
+        logger.debug(
+            f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
+        )
+        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
+            dataset_id=dataset_id,
+            start_index=start_index,
+            limit=limit,
+        )
+
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
+        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
+        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
+            dataset_id=dataset_id,
+            rows=rows,
+        )
+
+
+class ScoringRouter(Scoring):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing ScoringRouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("ScoringRouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("ScoringRouter.shutdown")
+        pass
+
+    async def score_batch(
+        self,
+        dataset_id: str,
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
+        save_results_dataset: bool = False,
+    ) -> ScoreBatchResponse:
+        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
+        res = {}
+        for fn_identifier in scoring_functions.keys():
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+                dataset_id=dataset_id,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
+            )
+            res.update(score_response.results)
+
+        if save_results_dataset:
+            raise NotImplementedError("Save results dataset not implemented yet")
+
+        return ScoreBatchResponse(
+            results=res,
+        )
+
+    async def score(
+        self,
+        input_rows: list[dict[str, Any]],
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
+    ) -> ScoreResponse:
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        res = {}
+        # look up and map each scoring function to its provider impl
+        for fn_identifier in scoring_functions.keys():
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+                input_rows=input_rows,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
+            )
+            res.update(score_response.results)
+
+        return ScoreResponse(results=res)
+
+
+class EvalRouter(Eval):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing EvalRouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("EvalRouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("EvalRouter.shutdown")
+        pass
+
+    async def run_eval(
+        self,
+        benchmark_id: str,
+        benchmark_config: BenchmarkConfig,
+    ) -> Job:
+        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
+        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
+            benchmark_id=benchmark_id,
+            benchmark_config=benchmark_config,
+        )
+
+    async def evaluate_rows(
+        self,
+        benchmark_id: str,
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
+        benchmark_config: BenchmarkConfig,
+    ) -> EvaluateResponse:
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
+            benchmark_id=benchmark_id,
+            input_rows=input_rows,
+            scoring_functions=scoring_functions,
+            benchmark_config=benchmark_config,
+        )
+
+    async def job_status(
+        self,
+        benchmark_id: str,
+        job_id: str,
+    ) -> Job:
+        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+
+    async def job_cancel(
+        self,
+        benchmark_id: str,
+        job_id: str,
+    ) -> None:
+        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
+        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
+            benchmark_id,
+            job_id,
+        )
+
+    async def job_result(
+        self,
+        benchmark_id: str,
+        job_id: str,
+    ) -> EvaluateResponse:
+        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
+        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
+            benchmark_id,
+            job_id,
+        )
+
+
+class ToolRuntimeRouter(ToolRuntime):
+    class RagToolImpl(RAGToolRuntime):
+        def __init__(
+            self,
+            routing_table: RoutingTable,
+        ) -> None:
+            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
+            self.routing_table = routing_table
+
+        async def query(
+            self,
+            content: InterleavedContent,
+            vector_db_ids: list[str],
+            query_config: RAGQueryConfig | None = None,
+        ) -> RAGQueryResult:
+            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
+            return await self.routing_table.get_provider_impl("knowledge_search").query(
+                content, vector_db_ids, query_config
+            )
+
+        async def insert(
+            self,
+            documents: list[RAGDocument],
+            vector_db_id: str,
+            chunk_size_in_tokens: int = 512,
+        ) -> None:
+            logger.debug(
+                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
+            )
+            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
+                documents, vector_db_id, chunk_size_in_tokens
+            )
+
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing ToolRuntimeRouter")
+        self.routing_table = routing_table
+
+        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
+        self.rag_tool = self.RagToolImpl(routing_table)
+        for method in ("query", "insert"):
+            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
+
+    async def initialize(self) -> None:
+        logger.debug("ToolRuntimeRouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("ToolRuntimeRouter.shutdown")
+        pass
+
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
+        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
+        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
+            tool_name=tool_name,
+            kwargs=kwargs,
+        )
+
+    async def list_runtime_tools(
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
+    ) -> ListToolDefsResponse:
+        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
634 llama_stack/distribution/routers/routing_tables.py Normal file
@@ -0,0 +1,634 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import time
import uuid
from typing import Any

from pydantic import TypeAdapter

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import (
    Dataset,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
    ListToolGroupsResponse,
    ListToolsResponse,
    Tool,
    ToolGroup,
    ToolGroups,
    ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
    AccessAttributes,
    BenchmarkWithACL,
    DatasetWithACL,
    ModelWithACL,
    RoutableObject,
    RoutableObjectWithProvider,
    RoutedProtocol,
    ScoringFnWithACL,
    ShieldWithACL,
    ToolGroupWithACL,
    ToolWithACL,
    VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable

logger = logging.getLogger(__name__)


def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api


# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    api = get_impl_api(p)

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    if api == Api.inference:
        return await p.register_model(obj)
    elif api == Api.safety:
        return await p.register_shield(obj)
    elif api == Api.vector_io:
        return await p.register_vector_db(obj)
    elif api == Api.datasetio:
        return await p.register_dataset(obj)
    elif api == Api.scoring:
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_benchmark(obj)
    elif api == Api.tool_runtime:
        return await p.register_tool(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")


async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    api = get_impl_api(p)
    if api == Api.vector_io:
        return await p.unregister_vector_db(obj.identifier)
    elif api == Api.inference:
        return await p.unregister_model(obj.identifier)
    elif api == Api.datasetio:
        return await p.unregister_dataset(obj.identifier)
    elif api == Api.tool_runtime:
        return await p.unregister_tool(obj.identifier)
    else:
        raise ValueError(f"Unregister not supported for {api}")


Registry = dict[str, list[RoutableObjectWithProvider]]


class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        impls_by_provider_id: dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
    ) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.dist_registry = dist_registry

    async def initialize(self) -> None:
        async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
            for obj in objs:
                if cls is None:
                    obj.provider_id = provider_id
                else:
                    # Create a copy of the model data and explicitly set provider_id
                    model_data = obj.model_dump()
                    model_data["provider_id"] = provider_id
                    obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            if api == Api.inference:
                p.model_store = self
            elif api == Api.safety:
                p.shield_store = self
            elif api == Api.vector_io:
                p.vector_db_store = self
            elif api == Api.datasetio:
                p.dataset_store = self
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)
            elif api == Api.eval:
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        def apiname_object():
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorDBsRoutingTable):
                return ("VectorIO", "vector_db")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, BenchmarksRoutingTable):
                return ("Eval", "benchmark")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("Tools", "tool")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # Check if user has permission to access this object
        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        # If object supports access control but no attributes set, use creator's attributes
        if not obj.access_attributes:
            creator_attributes = get_auth_attributes()
            if creator_attributes:
                obj.access_attributes = AccessAttributes(**creator_attributes)
                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj

        else:
            await self.dist_registry.register(obj)
            return obj

    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
        objs = await self.dist_registry.get_all()
        filtered_objs = [obj for obj in objs if obj.type == type]

        # Apply attribute-based access control filtering
        if filtered_objs:
            filtered_objs = [
                obj
                for obj in filtered_objs
                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
            ]

        return filtered_objs


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> ListShieldsResponse:
        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

    async def get_shield(self, identifier: str) -> Shield:
        shield = await self.get_object_by_identifier("shield", identifier)
        if shield is None:
            raise ValueError(f"Shield '{identifier}' not found")
        return shield

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = ShieldWithACL(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield


class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
        }
        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        existing_vector_db = await self.get_vector_db(vector_db_id)
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

    async def get_dataset(self, dataset_id: str) -> Dataset:
        dataset = await self.get_object_by_identifier("dataset", dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset '{dataset_id}' not found")
        return dataset

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        if isinstance(source, dict):
            if source["type"] == "uri":
                source = URIDataSource.parse_obj(source)
            elif source["type"] == "rows":
                source = RowsDataSource.parse_obj(source)

        if not dataset_id:
            dataset_id = f"dataset-{str(uuid.uuid4())}"

        provider_dataset_id = dataset_id

        # infer provider from source
        if metadata:
            if metadata.get("provider_id"):
                provider_id = metadata.get("provider_id")  # pass through from nvidia datasetio
        elif source.type == DatasetType.rows.value:
            provider_id = "localfs"
        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
            else:
                provider_id = "localfs"
        else:
            raise ValueError(f"Unknown data source type: {source.type}")

        if metadata is None:
            metadata = {}

        dataset = DatasetWithACL(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            purpose=purpose,
            source=source,
            metadata=metadata,
        )

        await self.register_object(dataset)
        return dataset

    async def unregister_dataset(self, dataset_id: str) -> None:
        dataset = await self.get_dataset(dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset {dataset_id} not found")
        await self.unregister_object(dataset)


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithACL(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)


class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
        if benchmark is None:
            raise ValueError(f"Benchmark '{benchmark_id}' not found")
        return benchmark

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        metadata: dict[str, Any] | None = None,
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = BenchmarkWithACL(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        tools = await self.get_all_with_type("tool")
        if toolgroup_id:
            tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
        return ListToolsResponse(data=tools)

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group '{toolgroup_id}' not found")
        return tool_group

    async def get_tool(self, tool_name: str) -> Tool:
        return await self.get_object_by_identifier("tool", tool_name)

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        tools = []
        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

        for tool_def in tool_defs.data:
            tools.append(
                ToolWithACL(
                    identifier=tool_def.name,
                    toolgroup_id=toolgroup_id,
                    description=tool_def.description or "",
                    parameters=tool_def.parameters or [],
                    provider_id=provider_id,
                    provider_resource_id=tool_def.name,
                    metadata=tool_def.metadata,
                    tool_host=tool_host,
                )
            )
        for tool in tools:
            existing_tool = await self.get_tool(tool.identifier)
            # Compare existing and new object if one exists
            if existing_tool:
                existing_dict = existing_tool.model_dump()
                new_dict = tool.model_dump()

                if existing_dict != new_dict:
                    raise ValueError(
                        f"Object {tool.identifier} already exists in registry. Please use a different identifier."
                    )
            await self.register_object(tool)

        await self.dist_registry.register(
            ToolGroupWithACL(
                identifier=toolgroup_id,
                provider_id=provider_id,
                provider_resource_id=toolgroup_id,
                mcp_endpoint=mcp_endpoint,
                args=args,
            )
        )

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
        tools = await self.list_tools(toolgroup_id)
        for tool in getattr(tools, "data", []):
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        pass
@@ -1,57 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.inference import (
    Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class SafetyRouter(Safety):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("SafetyRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("SafetyRouter.shutdown")
        pass

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any] = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        return await self.routing_table.get_provider_impl(shield_id).run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )
@@ -1,92 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,
)
from llama_stack.apis.tools import (
    ListToolsResponse,
    RAGDocument,
    RAGQueryConfig,
    RAGQueryResult,
    RAGToolRuntime,
    ToolRuntime,
)
from llama_stack.log import get_logger

from ..routing_tables.toolgroups import ToolGroupsRoutingTable

logger = get_logger(name=__name__, category="core")


class ToolRuntimeRouter(ToolRuntime):
    class RagToolImpl(RAGToolRuntime):
        def __init__(
            self,
            routing_table: ToolGroupsRoutingTable,
        ) -> None:
            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
            self.routing_table = routing_table

        async def query(
            self,
            content: InterleavedContent,
            vector_db_ids: list[str],
            query_config: RAGQueryConfig | None = None,
        ) -> RAGQueryResult:
            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
            return await self.routing_table.get_provider_impl("knowledge_search").query(
                content, vector_db_ids, query_config
            )

        async def insert(
            self,
            documents: list[RAGDocument],
            vector_db_id: str,
            chunk_size_in_tokens: int = 512,
        ) -> None:
            logger.debug(
                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                documents, vector_db_id, chunk_size_in_tokens
            )

    def __init__(
        self,
        routing_table: ToolGroupsRoutingTable,
    ) -> None:
        logger.debug("Initializing ToolRuntimeRouter")
        self.routing_table = routing_table

        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
        self.rag_tool = self.RagToolImpl(routing_table)
        for method in ("query", "insert"):
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

    async def initialize(self) -> None:
        logger.debug("ToolRuntimeRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolsResponse:
        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.list_tools(tool_group_id)
@@ -1,72 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class VectorIORouter(VectorIO):
    """Routes to a provider based on the vector db identifier"""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("VectorIORouter.shutdown")
        pass

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> None:
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
        )
        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
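A hedged usage sketch of the router above: the identifiers, the embedding model name, and the `Chunk` payload are illustrative assumptions, not values taken from this diff.

```python
# Illustrative only; assumes a stack where "all-MiniLM-L6-v2" is already registered
# as an embedding model and a vector_io provider is configured.
from llama_stack.apis.vector_io import Chunk


async def demo(router: "VectorIORouter") -> None:
    await router.register_vector_db(
        vector_db_id="my-documents",
        embedding_model="all-MiniLM-L6-v2",
    )
    await router.insert_chunks(
        vector_db_id="my-documents",
        chunks=[Chunk(content="hello world", metadata={"document_id": "doc-1"})],
    )
    response = await router.query_chunks("my-documents", "greeting")
    print(response.chunks)
```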
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,58 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import (
    BenchmarkWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
        if benchmark is None:
            raise ValueError(f"Benchmark '{benchmark_id}' not found")
        return benchmark

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        metadata: dict[str, Any] | None = None,
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = BenchmarkWithACL(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)
@@ -1,218 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
    AccessAttributes,
    RoutableObject,
    RoutableObjectWithProvider,
    RoutedProtocol,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable

logger = get_logger(name=__name__, category="core")


def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api


# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    api = get_impl_api(p)

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    if api == Api.inference:
        return await p.register_model(obj)
    elif api == Api.safety:
        return await p.register_shield(obj)
    elif api == Api.vector_io:
        return await p.register_vector_db(obj)
    elif api == Api.datasetio:
        return await p.register_dataset(obj)
    elif api == Api.scoring:
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_benchmark(obj)
    elif api == Api.tool_runtime:
        return await p.register_toolgroup(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")


async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    api = get_impl_api(p)
    if api == Api.vector_io:
        return await p.unregister_vector_db(obj.identifier)
    elif api == Api.inference:
        return await p.unregister_model(obj.identifier)
    elif api == Api.datasetio:
        return await p.unregister_dataset(obj.identifier)
    elif api == Api.tool_runtime:
        return await p.unregister_toolgroup(obj.identifier)
    else:
        raise ValueError(f"Unregister not supported for {api}")


Registry = dict[str, list[RoutableObjectWithProvider]]


class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        impls_by_provider_id: dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
    ) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.dist_registry = dist_registry

    async def initialize(self) -> None:
        async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
            for obj in objs:
                if cls is None:
                    obj.provider_id = provider_id
                else:
                    # Create a copy of the model data and explicitly set provider_id
                    model_data = obj.model_dump()
                    model_data["provider_id"] = provider_id
                    obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            if api == Api.inference:
                p.model_store = self
            elif api == Api.safety:
                p.shield_store = self
            elif api == Api.vector_io:
                p.vector_db_store = self
            elif api == Api.datasetio:
                p.dataset_store = self
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)
            elif api == Api.eval:
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        from .benchmarks import BenchmarksRoutingTable
        from .datasets import DatasetsRoutingTable
        from .models import ModelsRoutingTable
        from .scoring_functions import ScoringFunctionsRoutingTable
        from .shields import ShieldsRoutingTable
        from .toolgroups import ToolGroupsRoutingTable
        from .vector_dbs import VectorDBsRoutingTable

        def apiname_object():
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorDBsRoutingTable):
                return ("VectorIO", "vector_db")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, BenchmarksRoutingTable):
                return ("Eval", "benchmark")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("ToolGroups", "tool_group")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # Check if user has permission to access this object
        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        # If object supports access control but no attributes set, use creator's attributes
        if not obj.access_attributes:
            creator_attributes = get_auth_attributes()
            if creator_attributes:
                obj.access_attributes = AccessAttributes(**creator_attributes)
                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj
        else:
            await self.dist_registry.register(obj)
            return obj

    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
        objs = await self.dist_registry.get_all()
        filtered_objs = [obj for obj in objs if obj.type == type]

        # Apply attribute-based access control filtering
        if filtered_objs:
            filtered_objs = [
                obj
                for obj in filtered_objs
                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
            ]

        return filtered_objs
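To make the access-control calls above concrete, here is a simplified re-statement of the rule that `check_access` is used for: an object is visible when it carries no access attributes, or when every attribute category on the object overlaps the caller's attributes. This is a hedged sketch of the behaviour, not the canonical implementation.

```python
# Hedged sketch of the attribute-based check used by get_object_by_identifier and
# get_all_with_type; the real logic lives in llama_stack.distribution.access_control.
def has_access(obj_attrs: dict[str, list[str]] | None, user_attrs: dict[str, list[str]] | None) -> bool:
    if not obj_attrs:
        return True  # objects without access_attributes are visible to everyone
    if not user_attrs:
        return False
    return all(
        any(value in user_attrs.get(category, []) for value in values)
        for category, values in obj_attrs.items()
    )


assert has_access(None, {"teams": ["ml-team"]})
assert has_access({"teams": ["ml-team"]}, {"teams": ["ml-team", "infra"]})
assert not has_access({"teams": ["ml-team"]}, {"teams": ["infra"]})
```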
@@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import Any

from llama_stack.apis.datasets import (
    Dataset,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import (
    DatasetWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

    async def get_dataset(self, dataset_id: str) -> Dataset:
        dataset = await self.get_object_by_identifier("dataset", dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset '{dataset_id}' not found")
        return dataset

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        if isinstance(source, dict):
            if source["type"] == "uri":
                source = URIDataSource.parse_obj(source)
            elif source["type"] == "rows":
                source = RowsDataSource.parse_obj(source)

        if not dataset_id:
            dataset_id = f"dataset-{str(uuid.uuid4())}"

        provider_dataset_id = dataset_id

        # infer provider from source
        if metadata:
            if metadata.get("provider_id"):
                provider_id = metadata.get("provider_id")  # pass through from nvidia datasetio
        elif source.type == DatasetType.rows.value:
            provider_id = "localfs"
        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
            else:
                provider_id = "localfs"
        else:
            raise ValueError(f"Unknown data source type: {source.type}")

        if metadata is None:
            metadata = {}

        dataset = DatasetWithACL(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            purpose=purpose,
            source=source,
            metadata=metadata,
        )

        await self.register_object(dataset)
        return dataset

    async def unregister_dataset(self, dataset_id: str) -> None:
        dataset = await self.get_dataset(dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset {dataset_id} not found")
        await self.unregister_object(dataset)
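Under the provider-inference branches above, a huggingface URI lands on the `huggingface` provider while inline rows land on `localfs`. A sketch follows; the purpose value and the URI are assumptions for illustration only.

```python
# Illustrative calls against a DatasetsRoutingTable instance ("table" is assumed).
async def demo(table: "DatasetsRoutingTable") -> None:
    ds1 = await table.register_dataset(
        purpose=DatasetPurpose.eval_messages_answer,
        source=URIDataSource(uri="huggingface://datasets/llamastack/simpleqa?split=train"),
    )
    # ds1.provider_id == "huggingface"  (inferred from the URI)

    ds2 = await table.register_dataset(
        purpose=DatasetPurpose.eval_messages_answer,
        source=RowsDataSource(rows=[{"messages": [], "answer": "42"}]),
    )
    # ds2.provider_id == "localfs"  (inline rows default to the local provider)
```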
@@ -1,82 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from typing import Any

from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
    ModelWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)
@@ -1,62 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.distribution.datatypes import (
    ScoringFnWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithACL(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)
@@ -1,57 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import (
    ShieldWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> ListShieldsResponse:
        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

    async def get_shield(self, identifier: str) -> Shield:
        shield = await self.get_object_by_identifier("shield", identifier)
        if shield is None:
            raise ValueError(f"Shield '{identifier}' not found")
        return shield

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = ShieldWithACL(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield
@@ -1,132 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
    # handle the funny case like "builtin::rag/knowledge_search"
    parts = toolgroup_name_with_maybe_tool_name.split("/")
    if len(parts) == 2:
        return parts[0]
    else:
        return None


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    toolgroups_to_tools: dict[str, list[Tool]] = {}
    tool_to_toolgroup: dict[str, str] = {}

    # overridden
    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
        # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?

        toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
        if toolgroup_id:
            routing_key = toolgroup_id

        if routing_key in self.tool_to_toolgroup:
            routing_key = self.tool_to_toolgroup[routing_key]
        return super().get_provider_impl(routing_key, provider_id)

    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        if toolgroup_id:
            if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
                toolgroup_id = group_id
            toolgroups = [await self.get_tool_group(toolgroup_id)]
        else:
            toolgroups = await self.get_all_with_type("tool_group")

        all_tools = []
        for toolgroup in toolgroups:
            if toolgroup.identifier not in self.toolgroups_to_tools:
                await self._index_tools(toolgroup)
            all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])

        return ListToolsResponse(data=all_tools)

    async def _index_tools(self, toolgroup: ToolGroup):
        provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
        tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)

        # TODO: kill this Tool vs ToolDef distinction
        tooldefs = tooldefs_response.data
        tools = []
        for t in tooldefs:
            tools.append(
                Tool(
                    identifier=t.name,
                    toolgroup_id=toolgroup.identifier,
                    description=t.description or "",
                    parameters=t.parameters or [],
                    metadata=t.metadata,
                    provider_id=toolgroup.provider_id,
                )
            )

        self.toolgroups_to_tools[toolgroup.identifier] = tools
        for tool in tools:
            self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group '{toolgroup_id}' not found")
        return tool_group

    async def get_tool(self, tool_name: str) -> Tool:
        if tool_name in self.tool_to_toolgroup:
            toolgroup_id = self.tool_to_toolgroup[tool_name]
            tools = self.toolgroups_to_tools[toolgroup_id]
            for tool in tools:
                if tool.identifier == tool_name:
                    return tool
        raise ValueError(f"Tool '{tool_name}' not found")

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        toolgroup = ToolGroupWithACL(
            identifier=toolgroup_id,
            provider_id=provider_id,
            provider_resource_id=toolgroup_id,
            mcp_endpoint=mcp_endpoint,
            args=args,
        )
        await self.register_object(toolgroup)

        # ideally, indexing of the tools should not be necessary because anyone using
        # the tools should first list the tools and then use them. but there are assumptions
        # baked in some of the code and tests right now.
        if not toolgroup.mcp_endpoint:
            await self._index_tools(toolgroup)
        return toolgroup

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        pass
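The name-pair parsing above is what lets a routing key like `builtin::rag/knowledge_search` be resolved back to its toolgroup; a quick illustration:

```python
assert parse_toolgroup_from_toolgroup_name_pair("builtin::rag/knowledge_search") == "builtin::rag"
assert parse_toolgroup_from_toolgroup_name_pair("builtin::websearch") is None  # plain toolgroup id, no "/tool" suffix
```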
@@ -1,74 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import TypeAdapter

from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
    VectorDBWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
        }
        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        existing_vector_db = await self.get_vector_db(vector_db_id)
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)
@@ -8,8 +8,7 @@ import json

 import httpx

-from llama_stack.distribution.datatypes import AuthenticationConfig
-from llama_stack.distribution.server.auth_providers import create_auth_provider
+from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
 from llama_stack.log import get_logger

 logger = get_logger(name=__name__, category="auth")

@@ -78,7 +77,7 @@ class AuthenticationMiddleware:
     access resources that don't have access_attributes defined.
     """

-    def __init__(self, app, auth_config: AuthenticationConfig):
+    def __init__(self, app, auth_config: AuthProviderConfig):
         self.app = app
         self.auth_provider = create_auth_provider(auth_config)

@@ -94,7 +93,7 @@ class AuthenticationMiddleware:

         # Validate token and get access attributes
         try:
-            validation_result = await self.auth_provider.validate_token(token, scope)
+            access_attributes = await self.auth_provider.validate_token(token, scope)
         except httpx.TimeoutException:
             logger.exception("Authentication request timed out")
             return await self._send_auth_error(send, "Authentication service timeout")

@@ -106,24 +105,17 @@ class AuthenticationMiddleware:
             return await self._send_auth_error(send, "Authentication service error")

         # Store attributes in request scope for access control
-        if validation_result.access_attributes:
-            user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
+        if access_attributes:
+            user_attributes = access_attributes.model_dump(exclude_none=True)
         else:
             logger.warning("No access attributes, setting namespace to token by default")
             user_attributes = {
-                "roles": [token],
+                "namespaces": [token],
             }

-        # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
-        # can identify the requester and enforce per-client rate limits.
-        scope["authenticated_client_id"] = token
-
         # Store attributes in request scope
         scope["user_attributes"] = user_attributes
-        scope["principal"] = validation_result.principal
-        logger.debug(
-            f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
-        )
+        logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")

         return await self.app(scope, receive, send)
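For context, the middleware builds its provider from the server's auth configuration. Below is a hedged sketch of that call path on the right-hand (`AuthProviderConfig`) side of this diff; the endpoint URL is a placeholder and the whole snippet is illustrative rather than taken from the change itself.

```python
# Sketch only: the endpoint URL is a placeholder.
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, AuthProviderType, create_auth_provider


async def demo(token: str):
    provider = create_auth_provider(
        AuthProviderConfig(
            provider_type=AuthProviderType.CUSTOM,
            config={"endpoint": "https://auth.example.com/validate"},
        )
    )
    # On this side of the diff validate_token() returns AccessAttributes (or None);
    # the left-hand side returns a TokenValidationResult instead.
    return await provider.validate_token(token)
```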
@@ -4,29 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import ssl
-import time
+import json
 from abc import ABC, abstractmethod
-from asyncio import Lock
-from pathlib import Path
+from enum import Enum
 from urllib.parse import parse_qs

 import httpx
-from jose import jwt
-from pydantic import BaseModel, Field, field_validator, model_validator
-from typing_extensions import Self
+from pydantic import BaseModel, Field

-from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
+from llama_stack.distribution.datatypes import AccessAttributes
 from llama_stack.log import get_logger

 logger = get_logger(name=__name__, category="auth")


-class TokenValidationResult(BaseModel):
-    principal: str | None = Field(
-        default=None,
-        description="The principal (username or persistent identifier) of the authenticated user",
-    )
+class AuthResponse(BaseModel):
+    """The format of the authentication response from the auth endpoint."""
+
     access_attributes: AccessAttributes | None = Field(
         default=None,
         description="""

@@ -49,10 +43,6 @@ class TokenValidationResult(BaseModel):
         """,
     )

-
-class AuthResponse(TokenValidationResult):
-    """The format of the authentication response from the auth endpoint."""
-
     message: str | None = Field(
         default=None, description="Optional message providing additional context about the authentication result."
     )

@@ -74,11 +64,25 @@ class AuthRequest(BaseModel):
     request: AuthRequestContext = Field(description="Context information about the request being authenticated")


+class AuthProviderType(str, Enum):
+    """Supported authentication provider types."""
+
+    KUBERNETES = "kubernetes"
+    CUSTOM = "custom"
+
+
+class AuthProviderConfig(BaseModel):
+    """Base configuration for authentication providers."""
+
+    provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
+    config: dict[str, str] = Field(..., description="Provider-specific configuration")
+
+
 class AuthProvider(ABC):
     """Abstract base class for authentication providers."""

     @abstractmethod
-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
         """Validate a token and return access attributes."""
         pass
@@ -88,219 +92,88 @@ class AuthProvider(ABC):

-def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
-    attributes = AccessAttributes()
-    for claim_key, attribute_key in mapping.items():
-        if claim_key not in claims or not hasattr(attributes, attribute_key):
-            continue
-        claim = claims[claim_key]
-        if isinstance(claim, list):
-            values = claim
-        else:
-            values = claim.split()
-
-        current = getattr(attributes, attribute_key)
-        if current:
-            current.extend(values)
-        else:
-            setattr(attributes, attribute_key, values)
-    return attributes
-
-
-class OAuth2JWKSConfig(BaseModel):
-    # The JWKS URI for collecting public keys
-    uri: str
-    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
-
-
-class OAuth2IntrospectionConfig(BaseModel):
-    url: str
-    client_id: str
-    client_secret: str
-    send_secret_in_body: bool = False
-
-
-class OAuth2TokenAuthProviderConfig(BaseModel):
-    audience: str = "llama-stack"
-    verify_tls: bool = True
-    tls_cafile: Path | None = None
-    issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
-    claims_mapping: dict[str, str] = Field(
-        default_factory=lambda: {
-            "sub": "roles",
-            "username": "roles",
-            "groups": "teams",
-            "team": "teams",
-            "project": "projects",
-            "tenant": "namespaces",
-            "namespace": "namespaces",
-        },
-    )
-    jwks: OAuth2JWKSConfig | None
-    introspection: OAuth2IntrospectionConfig | None = None
-
-    @classmethod
-    @field_validator("claims_mapping")
-    def validate_claims_mapping(cls, v):
-        for key, value in v.items():
-            if not value:
-                raise ValueError(f"claims_mapping value cannot be empty: {key}")
-            if value not in AccessAttributes.model_fields:
-                raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
-        return v
-
-    @model_validator(mode="after")
-    def validate_mode(self) -> Self:
-        if not self.jwks and not self.introspection:
-            raise ValueError("One of jwks or introspection must be configured")
-        if self.jwks and self.introspection:
-            raise ValueError("At present only one of jwks or introspection should be configured")
-        return self
-
-
-class OAuth2TokenAuthProvider(AuthProvider):
-    """
-    JWT token authentication provider that validates a JWT token and extracts access attributes.
-
-    This should be the standard authentication provider for most use cases.
-    """
-
-    def __init__(self, config: OAuth2TokenAuthProviderConfig):
-        self.config = config
-        self._jwks_at: float = 0.0
-        self._jwks: dict[str, str] = {}
-        self._jwks_lock = Lock()
-
-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
-        if self.config.jwks:
-            return await self.validate_jwt_token(token, scope)
-        if self.config.introspection:
-            return await self.introspect_token(token, scope)
-        raise ValueError("One of jwks or introspection must be configured")
-
-    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
-        """Validate a token using the JWT token."""
-        await self._refresh_jwks()
-
-        try:
-            header = jwt.get_unverified_header(token)
-            kid = header["kid"]
-            if kid not in self._jwks:
-                raise ValueError(f"Unknown key ID: {kid}")
-            key_data = self._jwks[kid]
-            algorithm = header.get("alg", "RS256")
-            claims = jwt.decode(
-                token,
-                key_data,
-                algorithms=[algorithm],
-                audience=self.config.audience,
-                issuer=self.config.issuer,
-            )
-        except Exception as exc:
-            raise ValueError(f"Invalid JWT token: {token}") from exc
-
-        # There are other standard claims, the most relevant of which is `scope`.
-        # We should incorporate these into the access attributes.
-        principal = claims["sub"]
-        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
-        return TokenValidationResult(
-            principal=principal,
-            access_attributes=access_attributes,
-        )
-
-    async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
-        """Validate a token using token introspection as defined by RFC 7662."""
-        form = {
-            "token": token,
-        }
-        if self.config.introspection is None:
-            raise ValueError("Introspection is not configured")
-
-        if self.config.introspection.send_secret_in_body:
-            form["client_id"] = self.config.introspection.client_id
-            form["client_secret"] = self.config.introspection.client_secret
-            auth = None
-        else:
-            auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
-        ssl_ctxt = None
-        if self.config.tls_cafile:
-            ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
-        try:
-            async with httpx.AsyncClient(verify=ssl_ctxt) as client:
-                response = await client.post(
-                    self.config.introspection.url,
-                    data=form,
-                    auth=auth,
-                    timeout=10.0,  # Add a reasonable timeout
-                )
-                if response.status_code != 200:
-                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
-                    raise ValueError(f"Token introspection failed: {response.status_code}")
-
-                fields = response.json()
-                if not fields["active"]:
-                    raise ValueError("Token not active")
-                principal = fields["sub"] or fields["username"]
-                access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
-                return TokenValidationResult(
-                    principal=principal,
-                    access_attributes=access_attributes,
-                )
-        except httpx.TimeoutException:
-            logger.exception("Token introspection request timed out")
-            raise
-        except ValueError:
-            # Re-raise ValueError exceptions to preserve their message
-            raise
-        except Exception as e:
-            logger.exception("Error during token introspection")
-            raise ValueError("Token introspection error") from e
-
-    async def close(self):
-        pass
-
-    async def _refresh_jwks(self) -> None:
-        """
-        Refresh the JWKS cache.
-
-        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
-        If the cache is expired, we refresh the JWKS from the JWKS URI.
-
-        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
-        * It doesn't have user authentication flows
-        * It doesn't have refresh tokens
-        """
-        async with self._jwks_lock:
-            if self.config.jwks is None:
-                raise ValueError("JWKS is not configured")
-            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
-                verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
-                async with httpx.AsyncClient(verify=verify) as client:
-                    res = await client.get(self.config.jwks.uri, timeout=5)
-                    res.raise_for_status()
-                    jwks_data = res.json()["keys"]
-                    updated = {}
-                    for k in jwks_data:
-                        kid = k["kid"]
-                        # Store the entire key object as it may be needed for different algorithms
-                        updated[kid] = k
-                    self._jwks = updated
-                    self._jwks_at = time.time()
-
-
-class CustomAuthProviderConfig(BaseModel):
-    endpoint: str
+class KubernetesAuthProvider(AuthProvider):
+    """Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
+
+    def __init__(self, config: dict[str, str]):
+        self.api_server_url = config["api_server_url"]
+        self.ca_cert_path = config.get("ca_cert_path")
+        self._client = None
+
+    async def _get_client(self):
+        """Get or create a Kubernetes client."""
+        if self._client is None:
+            # kubernetes-client has not async support, see:
+            # https://github.com/kubernetes-client/python/issues/323
+            from kubernetes import client
+            from kubernetes.client import ApiClient
+
+            # Configure the client
+            configuration = client.Configuration()
+            configuration.host = self.api_server_url
+            if self.ca_cert_path:
+                configuration.ssl_ca_cert = self.ca_cert_path
+            configuration.verify_ssl = bool(self.ca_cert_path)
+
+            # Create API client
+            self._client = ApiClient(configuration)
+        return self._client
+
+    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
+        """Validate a Kubernetes token and return access attributes."""
+        try:
+            client = await self._get_client()
+
+            # Set the token in the client
+            client.set_default_header("Authorization", f"Bearer {token}")
+
+            # Make a request to validate the token
+            # We use the /api endpoint which requires authentication
+            from kubernetes.client import CoreV1Api
+
+            api = CoreV1Api(client)
+            api.get_api_resources(_request_timeout=3.0)  # Set timeout for this specific request
+
+            # If we get here, the token is valid
+            # Extract user info from the token claims
+            import base64
+
+            # Decode the token (without verification since we've already validated it)
+            token_parts = token.split(".")
+            payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
+
+            # Extract user information from the token
+            username = payload.get("sub", "")
+            groups = payload.get("groups", [])
+
+            return AccessAttributes(
+                roles=[username],  # Use username as a role
+                teams=groups,  # Use Kubernetes groups as teams
+            )
+        except Exception as e:
+            logger.exception("Failed to validate Kubernetes token")
+            raise ValueError("Invalid or expired token") from e
+
+    async def close(self):
+        """Close the HTTP client."""
+        if self._client:
+            self._client.close()
+            self._client = None


 class CustomAuthProvider(AuthProvider):
     """Custom authentication provider that uses an external endpoint."""

-    def __init__(self, config: CustomAuthProviderConfig):
-        self.config = config
+    def __init__(self, config: dict[str, str]):
+        self.endpoint = config["endpoint"]
         self._client = None

-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
         """Validate a token using the custom authentication endpoint."""
+        if not self.endpoint:
+            raise ValueError("Authentication endpoint not configured")
+
         if scope is None:
             scope = {}
@ -329,7 +202,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
self.config.endpoint,
|
self.endpoint,
|
||||||
json=auth_request.model_dump(),
|
json=auth_request.model_dump(),
|
||||||
timeout=10.0, # Add a reasonable timeout
|
timeout=10.0, # Add a reasonable timeout
|
||||||
)
|
)
|
||||||
|
@ -341,7 +214,19 @@ class CustomAuthProvider(AuthProvider):
|
||||||
try:
|
try:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
auth_response = AuthResponse(**response_data)
|
auth_response = AuthResponse(**response_data)
|
||||||
return auth_response
|
|
||||||
|
# Store attributes in request scope for access control
|
||||||
|
if auth_response.access_attributes:
|
||||||
|
return auth_response.access_attributes
|
||||||
|
else:
|
||||||
|
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||||
|
user_attributes = {
|
||||||
|
"namespaces": [token],
|
||||||
|
}
|
||||||
|
|
||||||
|
scope["user_attributes"] = user_attributes
|
||||||
|
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||||
|
return auth_response.access_attributes
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error parsing authentication response")
|
logger.exception("Error parsing authentication response")
|
||||||
raise ValueError("Invalid authentication response format") from e
|
raise ValueError("Invalid authentication response format") from e
|
||||||
|
@ -363,14 +248,14 @@ class CustomAuthProvider(AuthProvider):
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
|
|
||||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
|
||||||
"""Factory function to create the appropriate auth provider."""
|
"""Factory function to create the appropriate auth provider."""
|
||||||
provider_type = config.provider_type.lower()
|
provider_type = config.provider_type.lower()
|
||||||
|
|
||||||
if provider_type == "custom":
|
if provider_type == "kubernetes":
|
||||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
return KubernetesAuthProvider(config.config)
|
||||||
elif provider_type == "oauth2_token":
|
elif provider_type == "custom":
|
||||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
return CustomAuthProvider(config.config)
|
||||||
else:
|
else:
|
||||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||||
|
|
|
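On both sides of this hunk, CustomAuthProvider POSTs the request's bearer token to an external HTTP endpoint and parses the JSON reply as an AuthResponse whose optional access_attributes drive access control. A minimal sketch of what such an external verifier could look like, assuming a simple request payload and made-up attribute names (neither the request schema nor these values appear in this diff):

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()


class AuthRequestBody(BaseModel):
    # Assumed payload shape; the provider serializes its own AuthRequest model here.
    token: str | None = None


@app.post("/validate")
async def validate(body: AuthRequestBody):
    # Accept a single hard-coded token purely for illustration.
    if body.token != "demo-token":
        raise HTTPException(status_code=401, detail="invalid token")
    # Mirror the AuthResponse shape the provider parses: access_attributes is optional.
    return {"access_attributes": {"roles": ["admin"], "namespaces": ["demo"]}}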
@@ -6,23 +6,20 @@

 import inspect
 import re
-from collections.abc import Callable
-from typing import Any

-from aiohttp import hdrs
-from starlette.routing import Route
+from pydantic import BaseModel

 from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
 from llama_stack.distribution.resolver import api_protocol_map
 from llama_stack.providers.datatypes import Api

-EndpointFunc = Callable[..., Any]
-PathParams = dict[str, str]
-RouteInfo = tuple[EndpointFunc, str]
-PathImpl = dict[str, RouteInfo]
-RouteImpls = dict[str, PathImpl]
-RouteMatch = tuple[EndpointFunc, PathParams, str]
+
+class ApiEndpoint(BaseModel):
+    route: str
+    method: str
+    name: str
+    descriptive_name: str | None = None


 def toolgroup_protocol_map():

@@ -31,13 +28,13 @@ def toolgroup_protocol_map():
     }


-def get_all_api_routes() -> dict[Api, list[Route]]:
+def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
     apis = {}

     protocols = api_protocol_map()
     toolgroup_protocols = toolgroup_protocol_map()
     for api, protocol in protocols.items():
-        routes = []
+        endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

         # HACK ALERT

@@ -54,28 +51,26 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
             if not hasattr(method, "__webmethod__"):
                 continue

-            # The __webmethod__ attribute is dynamically added by the @webmethod decorator
-            # mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
-            webmethod = method.__webmethod__  # type: ignore[attr-defined]
-            path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
-            if webmethod.method == hdrs.METH_GET:
-                http_method = hdrs.METH_GET
-            elif webmethod.method == hdrs.METH_DELETE:
-                http_method = hdrs.METH_DELETE
+            webmethod = method.__webmethod__
+            route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
+            if webmethod.method == "GET":
+                method = "get"
+            elif webmethod.method == "DELETE":
+                method = "delete"
             else:
-                http_method = hdrs.METH_POST
-            routes.append(
-                Route(path=path, methods=[http_method], name=name, endpoint=None)
-            )  # setting endpoint to None since don't use a Router object
+                method = "post"
+            endpoints.append(
+                ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
+            )

-        apis[api] = routes
+        apis[api] = endpoints

     return apis


-def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
-    routes = get_all_api_routes()
-    route_impls: RouteImpls = {}
+def initialize_endpoint_impls(impls):
+    endpoints = get_all_api_endpoints()
+    endpoint_impls = {}

     def _convert_path_to_regex(path: str) -> str:
         # Convert {param} to named capture groups

@@ -88,34 +83,29 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:

         return f"^{pattern}$"

-    for api, api_routes in routes.items():
+    for api, api_endpoints in endpoints.items():
         if api not in impls:
             continue
-        for route in api_routes:
+        for endpoint in api_endpoints:
             impl = impls[api]
-            func = getattr(impl, route.name)
-            # Get the first (and typically only) method from the set, filtering out HEAD
-            available_methods = [m for m in route.methods if m != "HEAD"]
-            if not available_methods:
-                continue  # Skip if only HEAD method is available
-            method = available_methods[0].lower()
-            if method not in route_impls:
-                route_impls[method] = {}
-            route_impls[method][_convert_path_to_regex(route.path)] = (
+            func = getattr(impl, endpoint.name)
+            if endpoint.method not in endpoint_impls:
+                endpoint_impls[endpoint.method] = {}
+            endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
                 func,
-                route.path,
+                endpoint.descriptive_name or endpoint.route,
             )

-    return route_impls
+    return endpoint_impls


-def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
+def find_matching_endpoint(method, path, endpoint_impls):
     """Find the matching endpoint implementation for a given method and path.

     Args:
         method: HTTP method (GET, POST, etc.)
         path: URL path to match against
-        route_impls: A dictionary of endpoint implementations
+        endpoint_impls: A dictionary of endpoint implementations

     Returns:
         A tuple of (endpoint_function, path_params, descriptive_name)

@@ -123,7 +113,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
     Raises:
         ValueError: If no matching endpoint is found
     """
-    impls = route_impls.get(method.lower())
+    impls = endpoint_impls.get(method.lower())
     if not impls:
         raise ValueError(f"No endpoint found for {path}")
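The body of _convert_path_to_regex is only partially visible in this hunk; the idea is to turn a templated route such as /v1/models/{model_id} into a regex with named capture groups so path parameters can be recovered at dispatch time. A rough sketch of that idea (the exact implementation is not shown here):

import re

def convert_path_to_regex(path: str) -> str:
    # Replace each {param} segment with a named capture group that stops at the next "/"
    pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
    return f"^{pattern}$"

m = re.match(convert_path_to_regex("/v1/models/{model_id}"), "/v1/models/llama-3")
print(m.groupdict())  # {'model_id': 'llama-3'}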
@@ -1,110 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import json
-import time
-from datetime import datetime, timedelta, timezone
-
-from starlette.types import ASGIApp, Receive, Scope, Send
-
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore.api import KVStore
-from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
-from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
-
-logger = get_logger(name=__name__, category="quota")
-
-
-class QuotaMiddleware:
-    """
-    ASGI middleware that enforces separate quotas for authenticated and anonymous clients
-    within a configurable time window.
-
-    - For authenticated requests, it reads the client ID from the
-      `Authorization: Bearer <client_id>` header.
-    - For anonymous requests, it falls back to the IP address of the client.
-    Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
-    once a client exceeds its quota.
-    """
-
-    def __init__(
-        self,
-        app: ASGIApp,
-        kv_config: KVStoreConfig,
-        anonymous_max_requests: int,
-        authenticated_max_requests: int,
-        window_seconds: int = 86400,
-    ):
-        self.app = app
-        self.kv_config = kv_config
-        self.kv: KVStore | None = None
-        self.anonymous_max_requests = anonymous_max_requests
-        self.authenticated_max_requests = authenticated_max_requests
-        self.window_seconds = window_seconds
-
-        if isinstance(self.kv_config, SqliteKVStoreConfig):
-            logger.warning(
-                "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
-                f"window_seconds={self.window_seconds}"
-            )
-
-    async def _get_kv(self) -> KVStore:
-        if self.kv is None:
-            self.kv = await kvstore_impl(self.kv_config)
-        return self.kv
-
-    async def __call__(self, scope: Scope, receive: Receive, send: Send):
-        if scope["type"] == "http":
-            # pick key & limit based on auth
-            auth_id = scope.get("authenticated_client_id")
-            if auth_id:
-                key_id = auth_id
-                limit = self.authenticated_max_requests
-            else:
-                # fallback to IP
-                client = scope.get("client")
-                key_id = client[0] if client else "anonymous"
-                limit = self.anonymous_max_requests
-
-            current_window = int(time.time() // self.window_seconds)
-            key = f"quota:{key_id}:{current_window}"
-
-            try:
-                kv = await self._get_kv()
-                prev = await kv.get(key) or "0"
-                count = int(prev) + 1
-
-                if int(prev) == 0:
-                    # Set with expiration datetime when it is the first request in the window.
-                    expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
-                    await kv.set(key, str(count), expiration=expiration)
-                else:
-                    await kv.set(key, str(count))
-            except Exception:
-                logger.exception("Failed to access KV store for quota")
-                return await self._send_error(send, 500, "Quota service error")
-
-            if count > limit:
-                logger.warning(
-                    "Quota exceeded for client %s: %d/%d",
-                    key_id,
-                    count,
-                    limit,
-                )
-                return await self._send_error(send, 429, "Quota exceeded")
-
-        return await self.app(scope, receive, send)
-
-    async def _send_error(self, send: Send, status: int, message: str):
-        await send(
-            {
-                "type": "http.response.start",
-                "status": status,
-                "headers": [[b"content-type", b"application/json"]],
-            }
-        )
-        body = json.dumps({"error": {"message": message}}).encode()
-        await send({"type": "http.response.body", "body": body})
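The deleted QuotaMiddleware implements a simple fixed-window counter: the window index is derived from the current time, the counter key combines the client id with that window, and the first write in a window carries an expiration so the key ages out. A tiny standalone sketch of the same bookkeeping, using an in-memory dict instead of the KV store (limits and the sample client id are illustrative):

import time

WINDOW_SECONDS = 86400  # one day, as in the removed middleware
LIMIT = 100
counters: dict[str, int] = {}

def allow_request(key_id: str) -> bool:
    window = int(time.time() // WINDOW_SECONDS)
    key = f"quota:{key_id}:{window}"
    counters[key] = counters.get(key, 0) + 1
    return counters[key] <= LIMIT

print(allow_request("192.0.2.1"))  # True until the per-window limit is exceeded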
@@ -6,7 +6,6 @@

 import argparse
 import asyncio
-import functools
 import inspect
 import json
 import os

@@ -14,7 +13,6 @@ import ssl
 import sys
 import traceback
 import warnings
-from collections.abc import Callable
 from contextlib import asynccontextmanager
 from importlib.metadata import version as parse_version
 from pathlib import Path

@@ -22,26 +20,23 @@ from typing import Annotated, Any

 import rich.pretty
 import yaml
-from aiohttp import hdrs
 from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
-from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError

-from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
+from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import InvalidProviderError
-from llama_stack.distribution.server.routes import (
-    find_matching_route,
-    get_all_api_routes,
-    initialize_route_impls,
+from llama_stack.distribution.server.endpoints import (
+    find_matching_endpoint,
+    initialize_endpoint_impls,
 )
 from llama_stack.distribution.stack import (
     construct_stack,

@@ -64,7 +59,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )

 from .auth import AuthenticationMiddleware
-from .quota import QuotaMiddleware
+from .endpoints import get_all_api_endpoints

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

@@ -125,8 +120,6 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
         return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
         return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
-    elif isinstance(exc, AuthenticationRequiredError):
-        return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
     else:
         return HTTPException(
             status_code=500,

@@ -212,9 +205,8 @@ async def log_request_pre_validation(request: Request):
         logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")


-def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
-    @functools.wraps(func)
-    async def route_handler(request: Request, **kwargs):
+def create_dynamic_typed_route(func: Any, method: str, route: str):
+    async def endpoint(request: Request, **kwargs):
         # Get auth attributes from the request scope
         user_attributes = request.scope.get("user_attributes", {})

@@ -254,9 +246,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
         for param in new_params[1:]
     ]

-    route_handler.__signature__ = sig.replace(parameters=new_params)
+    endpoint.__signature__ = sig.replace(parameters=new_params)

-    return route_handler
+    return endpoint


 class TracingMiddleware:

@@ -278,28 +270,17 @@ class TracingMiddleware:
             logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
             return await self.app(scope, receive, send)

-        if not hasattr(self, "route_impls"):
-            self.route_impls = initialize_route_impls(self.impls)
+        if not hasattr(self, "endpoint_impls"):
+            self.endpoint_impls = initialize_endpoint_impls(self.impls)

         try:
-            _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
+            _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
         except ValueError:
             # If no matching endpoint is found, pass through to FastAPI
-            logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
+            logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
             return await self.app(scope, receive, send)

-        trace_attributes = {"__location__": "server", "raw_path": path}
-
-        # Extract W3C trace context headers and store as trace attributes
-        headers = dict(scope.get("headers", []))
-        traceparent = headers.get(b"traceparent", b"").decode()
-        if traceparent:
-            trace_attributes["traceparent"] = traceparent
-        tracestate = headers.get(b"tracestate", b"").decode()
-        if tracestate:
-            trace_attributes["tracestate"] = tracestate
-
-        trace_context = await start_trace(trace_path, trace_attributes)
+        trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})

         async def send_with_trace_id(message):
             if message["type"] == "http.response.start":

@@ -389,6 +370,14 @@ def main(args: argparse.Namespace | None = None):
     if args is None:
         args = parser.parse_args()

+    # Check for deprecated argument usage
+    if "--yaml-config" in sys.argv:
+        warnings.warn(
+            "The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     log_line = ""
     if args.config:
         # if the user provided a config file, use it, even if template was specified

@@ -402,7 +391,7 @@ def main(args: argparse.Namespace | None = None):
             raise ValueError(f"Template {args.template} does not exist")
         log_line = f"Using template {args.template} config file: {config_file}"
     else:
-        raise ValueError("Either --config or --template must be provided")
+        raise ValueError("Either --yaml-config or --template must be provided")

     logger_config = None
     with open(config_file) as fp:

@@ -442,46 +431,6 @@ def main(args: argparse.Namespace | None = None):
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
         app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
-    else:
-        if config.server.quota:
-            quota = config.server.quota
-            logger.warning(
-                "Configured authenticated_max_requests (%d) but no auth is enabled; "
-                "falling back to anonymous_max_requests (%d) for all the requests",
-                quota.authenticated_max_requests,
-                quota.anonymous_max_requests,
-            )
-
-    if config.server.quota:
-        logger.info("Enabling quota middleware for authenticated and anonymous clients")
-
-        quota = config.server.quota
-        anonymous_max_requests = quota.anonymous_max_requests
-        # if auth is disabled, use the anonymous max requests
-        authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
-
-        kv_config = quota.kvstore
-        window_map = {"day": 86400}
-        window_seconds = window_map[quota.period.value]
-
-        app.add_middleware(
-            QuotaMiddleware,
-            kv_config=kv_config,
-            anonymous_max_requests=anonymous_max_requests,
-            authenticated_max_requests=authenticated_max_requests,
-            window_seconds=window_seconds,
-        )
-
-    # --- CORS middleware for local development ---
-    # TODO: move to reverse proxy
-    ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322)
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=[f"http://localhost:{ui_port}"],
-        allow_credentials=True,
-        allow_methods=["*"],
-        allow_headers=["*"],
-    )

     try:
         impls = asyncio.run(construct_stack(config))

@@ -494,7 +443,7 @@ def main(args: argparse.Namespace | None = None):
     else:
         setup_logger(TelemetryAdapter(TelemetryConfig(), {}))

-    all_routes = get_all_api_routes()
+    all_endpoints = get_all_api_endpoints()

     if config.apis:
         apis_to_serve = set(config.apis)

@@ -512,29 +461,24 @@ def main(args: argparse.Namespace | None = None):
     for api_str in apis_to_serve:
         api = Api(api_str)

-        routes = all_routes[api]
+        endpoints = all_endpoints[api]
         impl = impls[api]

-        for route in routes:
-            if not hasattr(impl, route.name):
+        for endpoint in endpoints:
+            if not hasattr(impl, endpoint.name):
                 # ideally this should be a typing violation already
-                raise ValueError(f"Could not find method {route.name} on {impl}!")
+                raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")

-            impl_method = getattr(impl, route.name)
-            # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
-            available_methods = [m for m in route.methods if m != "HEAD"]
-            if not available_methods:
-                raise ValueError(f"No methods found for {route.name} on {impl}")
-            method = available_methods[0]
-            logger.debug(f"{method} {route.path}")
+            impl_method = getattr(impl, endpoint.name)
+            logger.debug(f"{endpoint.method.upper()} {endpoint.route}")

             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
-                getattr(app, method.lower())(route.path, response_model=None)(
+                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                     create_dynamic_typed_route(
                         impl_method,
-                        method.lower(),
-                        route.path,
+                        endpoint.method,
+                        endpoint.route,
                     )
                 )
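Both sides of the registration loop above attach handlers to the FastAPI app at runtime: the HTTP verb is looked up on the app object, applied to the path, and the wrapped implementation method becomes the handler. A minimal sketch of that registration pattern, detached from the stack's own types (the route path, verb, and handler here are placeholders):

from fastapi import FastAPI

app = FastAPI()

async def list_models():
    # Stand-in for an implementation method resolved via getattr(impl, name) in the real code
    return {"data": ["llama-3"]}

method = "get"         # derived from the webmethod metadata in the real code
route = "/v1/models"   # path built from the API version and the webmethod route

# Equivalent to applying @app.get("/v1/models", response_model=None) at runtime
getattr(app, method)(route, response_model=None)(list_models)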
@@ -54,7 +54,7 @@ other_args=""
 # Process remaining arguments
 while [[ $# -gt 0 ]]; do
     case "$1" in
-        --config)
+        --config|--yaml-config)
             if [[ -n "$2" ]]; then
                 yaml_config="$2"
                 shift 2

@@ -121,7 +121,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
     set -x

     if [ -n "$yaml_config" ]; then
-        yaml_config_arg="--config $yaml_config"
+        yaml_config_arg="--yaml-config $yaml_config"
     else
         yaml_config_arg=""
     fi

@@ -181,9 +181,9 @@ elif [[ "$env_type" == "container" ]]; then

     # Add yaml config if provided, otherwise use default
     if [ -n "$yaml_config" ]; then
-        cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml"
+        cmd="$cmd -v $yaml_config:/app/run.yaml --yaml-config /app/run.yaml"
     else
-        cmd="$cmd --config /app/run.yaml"
+        cmd="$cmd --yaml-config /app/run.yaml"
     fi

     # Add any other args
@@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v9"
+KEY_VERSION = "v8"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
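KEY_FORMAT combines the registry prefix, the schema version, and the object's type and identifier, so bumping KEY_VERSION effectively starts a fresh key namespace in the KV store. A small illustration of the resulting keys (the type and identifier values are made up):

REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v8"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

print(KEY_FORMAT.format(type="model", identifier="llama-3"))
# distributions:registry:v8::model:llama-3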
@@ -5,8 +5,7 @@ FROM python:3.12-slim
 WORKDIR /app
 COPY . /app/
 RUN /usr/local/bin/python -m pip install --upgrade pip && \
-    /usr/local/bin/pip3 install -r requirements.txt && \
-    /usr/local/bin/pip3 install -r llama_stack/distribution/ui/requirements.txt
+    /usr/local/bin/pip3 install -r requirements.txt
 EXPOSE 8501

-ENTRYPOINT ["streamlit", "run", "llama_stack/distribution/ui/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
@@ -48,6 +48,3 @@ uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
 | TOGETHER_API_KEY | API key for Together provider | (empty string) |
 | SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
 | OPENAI_API_KEY | API key for OpenAI provider | (empty string) |
-| KEYCLOAK_URL | URL for keycloak authentication | (empty string) |
-| KEYCLOAK_REALM | Keycloak realm | default |
-| KEYCLOAK_CLIENT_ID | Client ID for keycloak auth | (empty string) |
@@ -50,42 +50,6 @@ def main():
     )
     pg.run()

-def main2():
-    from dataclasses import asdict
-    st.subheader(f"Welcome {keycloak.user_info['preferred_username']}!")
-    st.write(f"Here is your user information:")
-    st.write(asdict(keycloak))
-
-def get_access_token() -> str|None:
-    return st.session_state.get('access_token')
-

 if __name__ == "__main__":
-
-    from streamlit_keycloak import login
-    import os
-
-    keycloak_url = os.environ.get("KEYCLOAK_URL")
-    keycloak_realm = os.environ.get("KEYCLOAK_REALM", "default")
-    keycloak_client_id = os.environ.get("KEYCLOAK_CLIENT_ID")
-
-    if keycloak_url and keycloak_client_id:
-        keycloak = login(
-            url=keycloak_url,
-            realm=keycloak_realm,
-            client_id=keycloak_client_id,
-            custom_labels={
-                "labelButton": "Sign in to kvant",
-                "labelLogin": "Please sign in to your kvant account.",
-                "errorNoPopup": "Unable to open the authentication popup. Allow popups and refresh the page to proceed.",
-                "errorPopupClosed": "Authentication popup was closed manually.",
-                "errorFatal": "Unable to connect to Keycloak using the current configuration."
-            },
-            auto_refresh=True,
-        )
-
-        if keycloak.authenticated:
-            st.session_state['access_token'] = keycloak.access_token
-            main()
-            # TBD - add other authentications
-    else:
-        main()
+    main()
@@ -7,13 +7,11 @@
 import os

 from llama_stack_client import LlamaStackClient
-from llama_stack.distribution.ui.app import get_access_token


 class LlamaStackApi:
     def __init__(self):
         self.client = LlamaStackClient(
-            api_key=get_access_token(),
             base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
             provider_data={
                 "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),

@@ -30,3 +28,5 @@ class LlamaStackApi:
         scoring_params = {fn_id: None for fn_id in scoring_function_ids}
         return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)


+llama_stack_api = LlamaStackApi()