Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-20 11:47:00 +00:00)

Commit f5cb965f0f — Merge branch 'main' into nvidia-e2e-notebook

226 changed files with 16519 additions and 8666 deletions
.github/PULL_REQUEST_TEMPLATE.md (10 lines changed)

@@ -1,10 +1,8 @@
 # What does this PR do?

-[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]
+<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->

-[//]: # (If resolving an issue, uncomment and update the line below)
+<!-- If resolving an issue, uncomment and update the line below -->
-[//]: # (Closes #[issue-number])
+<!-- Closes #[issue-number] -->

 ## Test Plan

-[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]
+<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->

-[//]: # (## Documentation)
.github/actions/setup-runner/action.yml (new file, 22 lines)

@@ -0,0 +1,22 @@
+name: Setup runner
+description: Prepare a runner for the tests (install uv, python, project dependencies, etc.)
+runs:
+  using: "composite"
+  steps:
+    - name: Install uv
+      uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
+      with:
+        python-version: "3.10"
+        activate-environment: true
+        version: 0.7.6
+
+    - name: Install dependencies
+      shell: bash
+      run: |
+        uv sync --all-groups
+        uv pip install ollama faiss-cpu
+        # always test against the latest version of the client
+        # TODO: this is not necessarily a good idea. we need to test against both published and latest
+        # to find out backwards compatibility issues.
+        uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
+        uv pip install -e .
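Workflows elsewhere in this commit consume the composite action as a single step right after checkout, for example (a representative excerpt; the same pattern appears in the workflow diffs below):

    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Install dependencies
        uses: ./.github/actions/setup-runner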
.github/workflows/integration-auth-tests.yml (51 lines changed)

@@ -23,23 +23,18 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        auth-provider: [kubernetes]
+        auth-provider: [oauth2_token]
       fail-fast: false # we want to run all tests regardless of failure

     steps:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install uv
-        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
-        with:
-          python-version: "3.10"
-          activate-environment: true
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

-      - name: Set Up Environment and Install Dependencies
+      - name: Build Llama Stack
         run: |
-          uv sync --extra dev --extra test
-          uv pip install -e .
           llama stack build --template ollama --image-type venv

       - name: Install minikube

@@ -47,29 +42,53 @@
         uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19

       - name: Start minikube
-        if: ${{ matrix.auth-provider == 'kubernetes' }}
+        if: ${{ matrix.auth-provider == 'oauth2_token' }}
         run: |
           minikube start
           kubectl get pods -A

       - name: Configure Kube Auth
-        if: ${{ matrix.auth-provider == 'kubernetes' }}
+        if: ${{ matrix.auth-provider == 'oauth2_token' }}
         run: |
           kubectl create namespace llama-stack
           kubectl create serviceaccount llama-stack-auth -n llama-stack
           kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
           kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
+          cat <<EOF | kubectl apply -f -
+          apiVersion: rbac.authorization.k8s.io/v1
+          kind: ClusterRole
+          metadata:
+            name: allow-anonymous-openid
+          rules:
+          - nonResourceURLs: ["/openid/v1/jwks"]
+            verbs: ["get"]
+          ---
+          apiVersion: rbac.authorization.k8s.io/v1
+          kind: ClusterRoleBinding
+          metadata:
+            name: allow-anonymous-openid
+          roleRef:
+            apiGroup: rbac.authorization.k8s.io
+            kind: ClusterRole
+            name: allow-anonymous-openid
+          subjects:
+          - kind: User
+            name: system:anonymous
+            apiGroup: rbac.authorization.k8s.io
+          EOF

       - name: Set Kubernetes Config
-        if: ${{ matrix.auth-provider == 'kubernetes' }}
+        if: ${{ matrix.auth-provider == 'oauth2_token' }}
         run: |
-          echo "KUBERNETES_API_SERVER_URL=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.server}')" >> $GITHUB_ENV
+          echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV
           echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
+          echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
+          echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV

       - name: Set Kube Auth Config and run server
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
-        if: ${{ matrix.auth-provider == 'kubernetes' }}
+        if: ${{ matrix.auth-provider == 'oauth2_token' }}
         run: |
           run_dir=$(mktemp -d)
           cat <<'EOF' > $run_dir/run.yaml

@@ -81,10 +100,10 @@
             port: 8321
           EOF
           yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
-          yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml
+          yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
+          yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
           cat $run_dir/run.yaml

-          source .venv/bin/activate
           nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &

       - name: Wait for Llama Stack server to be ready
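After the yq edits above, the server.auth section of the generated run.yaml ends up roughly as follows (a sketch; the GitHub Actions expressions are shown unexpanded):

    server:
      port: 8321
      auth:
        provider_type: oauth2_token
        config:
          tls_cafile: ${{ env.KUBERNETES_CA_CERT_PATH }}
          issuer: ${{ env.KUBERNETES_ISSUER }}
          audience: ${{ env.KUBERNETES_AUDIENCE }}
          jwks:
            uri: ${{ env.KUBERNETES_API_SERVER_URL }}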
.github/workflows/integration-tests.yml (22 lines changed)

@@ -24,7 +24,7 @@ jobs:
       matrix:
         # Listing tests manually since some of them currently fail
         # TODO: generate matrix list from tests/integration when fixed
-        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
+        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
         client-type: [library, http]
       fail-fast: false # we want to run all tests regardless of failure

@@ -32,24 +32,14 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install uv
-        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
-        with:
-          python-version: "3.10"
-          activate-environment: true
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Setup ollama
         uses: ./.github/actions/setup-ollama

-      - name: Set Up Environment and Install Dependencies
+      - name: Build Llama Stack
         run: |
-          uv sync --extra dev --extra test
-          uv pip install ollama faiss-cpu
-          # always test against the latest version of the client
-          # TODO: this is not necessarily a good idea. we need to test against both published and latest
-          # to find out backwards compatibility issues.
-          uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
-          uv pip install -e .
           llama stack build --template ollama --image-type venv

       - name: Start Llama Stack server in background

@@ -57,7 +47,6 @@ jobs:
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
-          source .venv/bin/activate
           LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &

       - name: Wait for Llama Stack server to be ready

@@ -85,6 +74,7 @@ jobs:
             echo "Ollama health check failed"
             exit 1
           fi

       - name: Check Storage and Memory Available Before Tests
         if: ${{ always() }}
         run: |

@@ -100,7 +90,7 @@ jobs:
           else
             stack_config="http://localhost:8321"
           fi
-          uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
+          uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
            -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
            --text-model="meta-llama/Llama-3.2-3B-Instruct" \
            --embedding-model=all-MiniLM-L6-v2
.github/workflows/pre-commit.yml (1 line changed)

@@ -29,6 +29,7 @@ jobs:
       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         env:
           SKIP: no-commit-to-branch
+          RUFF_OUTPUT_FORMAT: github

       - name: Verify if there are any diff files after pre-commit
         run: |
.github/workflows/providers-build.yml (69 lines changed)

@@ -50,21 +50,8 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Set up Python
-        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
-        with:
-          python-version: '3.10'
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
-        with:
-          python-version: "3.10"
-
-      - name: Install LlamaStack
-        run: |
-          uv venv
-          source .venv/bin/activate
-          uv pip install -e .
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Print build dependencies
         run: |

@@ -79,7 +66,6 @@ jobs:
       - name: Print dependencies in the image
         if: matrix.image-type == 'venv'
         run: |
-          source test/bin/activate
           uv pip list

   build-single-provider:

@@ -88,21 +74,8 @@ jobs:
@@ -114,21 +87,8 @@ jobs:
@@ -152,21 +112,8 @@ jobs:
Each of these three hunks applies the same replacement as above: after the checkout step, the Set up Python, Install uv, and Install LlamaStack steps are dropped in favor of a single "Install dependencies" step that uses ./.github/actions/setup-runner, ahead of the "Build a single provider" steps in the first two hunks and the "Pin template to UBI9 base" step in the last.
.github/workflows/test-external-providers.yml (12 lines changed)

@@ -25,15 +25,8 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install uv
-        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
-        with:
-          python-version: "3.10"
-
-      - name: Set Up Environment and Install Dependencies
-        run: |
-          uv sync --extra dev --extra test
-          uv pip install -e .
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Apply image type to config file
         run: |

@@ -59,7 +52,6 @@ jobs:
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
-          source ci-test/bin/activate
           uv run pip list
           nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
.github/workflows/unit-tests.yml (14 lines changed)

@@ -30,17 +30,11 @@ jobs:
         - "3.12"
         - "3.13"
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - name: Checkout repository
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Set up Python ${{ matrix.python }}
-        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
-        with:
-          python-version: ${{ matrix.python }}
-
-      - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
-        with:
-          python-version: ${{ matrix.python }}
-          enable-cache: false
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Run unit tests
         run: |
.github/workflows/update-readthedocs.yml (12 lines changed)

@@ -37,16 +37,8 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Set up Python
-        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
-        with:
-          python-version: '3.11'
-
-      - name: Install the latest version of uv
-        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
-
-      - name: Sync with uv
-        run: uv sync --extra docs
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Build HTML
         run: |
.gitignore (1 line changed)

@@ -6,6 +6,7 @@ dev_requirements.txt
 build
 .DS_Store
 llama_stack/configs/*
+.cursor/
 xcuserdata/
 *.hmap
 .DS_Store
.pre-commit-config.yaml

@@ -53,7 +53,7 @@ repos:
       - black==24.3.0

   - repo: https://github.com/astral-sh/uv-pre-commit
-    rev: 0.6.3
+    rev: 0.7.8
     hooks:
       - id: uv-lock
       - id: uv-export

@@ -61,6 +61,7 @@ repos:
           "--frozen",
           "--no-hashes",
           "--no-emit-project",
+          "--no-default-groups",
           "--output-file=requirements.txt"
         ]

@@ -88,20 +89,17 @@ repos:
       - id: distro-codegen
         name: Distribution Template Codegen
         additional_dependencies:
-          - uv==0.6.0
-        entry: uv run --extra codegen ./scripts/distro_codegen.py
+          - uv==0.7.8
+        entry: uv run --group codegen ./scripts/distro_codegen.py
         language: python
         pass_filenames: false
         require_serial: true
         files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$

-  - repo: local
-    hooks:
       - id: openapi-codegen
         name: API Spec Codegen
         additional_dependencies:
-          - uv==0.6.2
-        entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
+          - uv==0.7.8
+        entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
         language: python
         pass_filenames: false
         require_serial: true
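With these bumps applied, the uv-pre-commit block reads roughly as follows (assembled from the hunks above; only the lines shown in the diff are included):

    - repo: https://github.com/astral-sh/uv-pre-commit
      rev: 0.7.8
      hooks:
        - id: uv-lock
        - id: uv-export
          args: [
            "--frozen",
            "--no-hashes",
            "--no-emit-project",
            "--no-default-groups",
            "--output-file=requirements.txt"
          ]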
CHANGELOG.md (21 lines changed)

@@ -1,5 +1,26 @@
 # Changelog

+# v0.2.7
+Published on: 2025-05-16T20:38:10Z
+
+## Highlights
+
+This is a small update. But a couple highlights:
+
+* feat: function tools in OpenAI Responses by @bbrowning in https://github.com/meta-llama/llama-stack/pull/2094, getting closer to ready. Streaming is the next missing piece.
+* feat: Adding support for customizing chunk context in RAG insertion and querying by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2134
+* feat: scaffolding for Llama Stack UI by @ehhuang in https://github.com/meta-llama/llama-stack/pull/2149, more to come in the coming releases.
+
+---
+
+# v0.2.6
+Published on: 2025-05-12T18:06:52Z
+
+---
+
 # v0.2.5
 Published on: 2025-05-04T20:16:49Z
CONTRIBUTING.md

@@ -167,14 +167,11 @@ If you have made changes to a provider's configuration in any form (introducing
 If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.

 ```bash
-cd docs
-uv sync --extra docs
-
 # This rebuilds the documentation pages.
-uv run make html
+uv run --with ".[docs]" make -C docs/ html

 # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
-uv run sphinx-autobuild source build/html --write-all
+uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
 ```

 ### Update API Documentation
MANIFEST.in

@@ -1,5 +1,4 @@
 include pyproject.toml
-include llama_stack/templates/dependencies.json
 include llama_stack/models/llama/llama3/tokenizer.model
 include llama_stack/models/llama/llama4/tokenizer.model
 include llama_stack/distribution/*.sh
README.md

@@ -110,7 +110,7 @@ Here is a list of the various API providers and available distributions that can
 | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
 |:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
 | Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
-| SambaNova | Hosted | | ✅ | | | |
+| SambaNova | Hosted | | ✅ | | ✅ | |
 | Cerebras | Hosted | | ✅ | | | |
 | Fireworks | Hosted | ✅ | ✅ | ✅ | | |
 | AWS Bedrock | Hosted | | ✅ | | ✅ | |
docs/_static/llama-stack-spec.html (584 lines changed)

@@ -518,6 +518,74 @@
             }
         },
         "/v1/openai/v1/responses": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "A ListOpenAIResponseObject.",
+                        "content": {
+                            "application/json": {
+                                "schema": { "$ref": "#/components/schemas/ListOpenAIResponseObject" }
+                            }
+                        }
+                    },
+                    "400": { "$ref": "#/components/responses/BadRequest400" },
+                    "429": { "$ref": "#/components/responses/TooManyRequests429" },
+                    "500": { "$ref": "#/components/responses/InternalServerError500" },
+                    "default": { "$ref": "#/components/responses/DefaultError" }
+                },
+                "tags": ["Agents"],
+                "description": "List all OpenAI responses.",
+                "parameters": [
+                    { "name": "after", "in": "query", "description": "The ID of the last response to return.", "required": false, "schema": { "type": "string" } },
+                    { "name": "limit", "in": "query", "description": "The number of responses to return.", "required": false, "schema": { "type": "integer" } },
+                    { "name": "model", "in": "query", "description": "The model to filter responses by.", "required": false, "schema": { "type": "string" } },
+                    { "name": "order", "in": "query", "description": "The order to sort responses by when sorted by created_at ('asc' or 'desc').", "required": false, "schema": { "$ref": "#/components/schemas/Order" } }
+                ]
+            },
             "post": {
                 "responses": {
                     "200": {

@@ -1395,7 +1463,7 @@
                 ]
             }
         },
-        "/v1/openai/v1/responses/{id}": {
+        "/v1/openai/v1/responses/{response_id}": {
             "get": {
                 "responses": {
                     "200": {

@@ -1427,7 +1495,7 @@
                 "description": "Retrieve an OpenAI response by its ID.",
                 "parameters": [
                     {
-                        "name": "id",
+                        "name": "response_id",
                         "in": "path",
                         "description": "The ID of the OpenAI response to retrieve.",
                         "required": true,

@@ -2926,6 +2994,97 @@
                 }
             }
         },
+        "/v1/openai/v1/responses/{response_id}/input_items": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "An ListOpenAIResponseInputItem.",
+                        "content": {
+                            "application/json": {
+                                "schema": { "$ref": "#/components/schemas/ListOpenAIResponseInputItem" }
+                            }
+                        }
+                    },
+                    "400": { "$ref": "#/components/responses/BadRequest400" },
+                    "429": { "$ref": "#/components/responses/TooManyRequests429" },
+                    "500": { "$ref": "#/components/responses/InternalServerError500" },
+                    "default": { "$ref": "#/components/responses/DefaultError" }
+                },
+                "tags": ["Agents"],
+                "description": "List input items for a given OpenAI response.",
+                "parameters": [
+                    { "name": "response_id", "in": "path", "description": "The ID of the response to retrieve input items for.", "required": true, "schema": { "type": "string" } },
+                    { "name": "after", "in": "query", "description": "An item ID to list items after, used for pagination.", "required": false, "schema": { "type": "string" } },
+                    { "name": "before", "in": "query", "description": "An item ID to list items before, used for pagination.", "required": false, "schema": { "type": "string" } },
+                    { "name": "include", "in": "query", "description": "Additional fields to include in the response.", "required": false, "schema": { "type": "array", "items": { "type": "string" } } },
+                    { "name": "limit", "in": "query", "description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.", "required": false, "schema": { "type": "integer" } },
+                    { "name": "order", "in": "query", "description": "The order to return the input items in. Default is desc.", "required": false, "schema": { "$ref": "#/components/schemas/Order" } }
+                ]
+            }
+        },
         "/v1/providers": {
            "get": {
                "responses": {
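For orientation, a hypothetical call to each of the two new routes could look like this (every ID and value below is illustrative, not taken from the spec):

    # List responses, newest first, filtered to one model
    list_responses:
      method: GET
      path: /v1/openai/v1/responses
      query: { limit: 20, order: desc, model: meta-llama/Llama-3.2-3B-Instruct }

    # Page through the input items of one response (resp_abc123 and msg_001 are made-up IDs)
    list_input_items:
      method: GET
      path: /v1/openai/v1/responses/resp_abc123/input_items
      query: { after: msg_001, limit: 20, order: desc }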
@@ -6742,6 +6901,9 @@
                 },
                 {
                     "$ref": "#/components/schemas/OpenAIResponseInputToolFunction"
+                },
+                {
+                    "$ref": "#/components/schemas/OpenAIResponseInputToolMCP"
                 }
             ],
             "discriminator": {

@@ -6749,7 +6911,8 @@
                 "mapping": {
                     "web_search": "#/components/schemas/OpenAIResponseInputToolWebSearch",
                     "file_search": "#/components/schemas/OpenAIResponseInputToolFileSearch",
-                    "function": "#/components/schemas/OpenAIResponseInputToolFunction"
+                    "function": "#/components/schemas/OpenAIResponseInputToolFunction",
+                    "mcp": "#/components/schemas/OpenAIResponseInputToolMCP"
                 }
             }
         },

@@ -6839,6 +7002,110 @@
             ],
             "title": "OpenAIResponseInputToolFunction"
         },
+        "OpenAIResponseInputToolMCP": {
+            "type": "object",
+            "properties": {
+                "type": { "type": "string", "const": "mcp", "default": "mcp" },
+                "server_label": { "type": "string" },
+                "server_url": { "type": "string" },
+                "headers": {
+                    "type": "object",
+                    "additionalProperties": {
+                        "oneOf": [
+                            { "type": "null" },
+                            { "type": "boolean" },
+                            { "type": "number" },
+                            { "type": "string" },
+                            { "type": "array" },
+                            { "type": "object" }
+                        ]
+                    }
+                },
+                "require_approval": {
+                    "oneOf": [
+                        { "type": "string", "const": "always" },
+                        { "type": "string", "const": "never" },
+                        {
+                            "type": "object",
+                            "properties": {
+                                "always": { "type": "array", "items": { "type": "string" } },
+                                "never": { "type": "array", "items": { "type": "string" } }
+                            },
+                            "additionalProperties": false,
+                            "title": "ApprovalFilter"
+                        }
+                    ],
+                    "default": "never"
+                },
+                "allowed_tools": {
+                    "oneOf": [
+                        { "type": "array", "items": { "type": "string" } },
+                        {
+                            "type": "object",
+                            "properties": {
+                                "tool_names": { "type": "array", "items": { "type": "string" } }
+                            },
+                            "additionalProperties": false,
+                            "title": "AllowedToolsFilter"
+                        }
+                    ]
+                }
+            },
+            "additionalProperties": false,
+            "required": ["type", "server_label", "server_url", "require_approval"],
+            "title": "OpenAIResponseInputToolMCP"
+        },
         "OpenAIResponseInputToolWebSearch": {
             "type": "object",
             "properties": {
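The new OpenAIResponseInputToolMCP schema lets a response attach an MCP server as a tool source. A minimal sketch of such a tool entry (the label, URL, header, and tool name are made up):

    type: mcp
    server_label: docs-search                 # hypothetical label
    server_url: https://mcp.example.com/sse   # hypothetical MCP server URL
    require_approval: never                   # or "always", or an ApprovalFilter object
    allowed_tools:                            # optional; may also be an AllowedToolsFilter object
      - search_docs                           # hypothetical tool name
    headers:
      Authorization: Bearer example-token     # hypothetical header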
@@ -6951,15 +7218,15 @@
         "OpenAIResponseOutputMessageFunctionToolCall": {
             "type": "object",
             "properties": {
-                "arguments": {
-                    "type": "string"
-                },
                 "call_id": {
                     "type": "string"
                 },
                 "name": {
                     "type": "string"
                 },
+                "arguments": {
+                    "type": "string"
+                },
                 "type": {
                     "type": "string",
                     "const": "function_call",

@@ -6974,12 +7241,10 @@
             },
             "additionalProperties": false,
             "required": [
-                "arguments",
                 "call_id",
                 "name",
-                "type",
-                "id",
-                "status"
+                "arguments",
+                "type"
             ],
             "title": "OpenAIResponseOutputMessageFunctionToolCall"
         },

@@ -7027,6 +7292,9 @@
                     "type": "string",
                     "description": "The underlying LLM used for completions."
                 },
+                "instructions": {
+                    "type": "string"
+                },
                 "previous_response_id": {
                     "type": "string",
                     "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."

@@ -7142,6 +7410,12 @@
                 },
                 {
                     "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
+                },
+                {
+                    "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall"
+                },
+                {
+                    "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
                 }
             ],
             "discriminator": {

@@ -7149,15 +7423,126 @@
                 "mapping": {
                     "message": "#/components/schemas/OpenAIResponseMessage",
                     "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
-                    "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
+                    "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
+                    "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
+                    "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
                 }
             }
         },
+        "OpenAIResponseOutputMessageMCPCall": {
+            "type": "object",
+            "properties": {
+                "id": { "type": "string" },
+                "type": { "type": "string", "const": "mcp_call", "default": "mcp_call" },
+                "arguments": { "type": "string" },
+                "name": { "type": "string" },
+                "server_label": { "type": "string" },
+                "error": { "type": "string" },
+                "output": { "type": "string" }
+            },
+            "additionalProperties": false,
+            "required": ["id", "type", "arguments", "name", "server_label"],
+            "title": "OpenAIResponseOutputMessageMCPCall"
+        },
+        "OpenAIResponseOutputMessageMCPListTools": {
+            "type": "object",
+            "properties": {
+                "id": { "type": "string" },
+                "type": { "type": "string", "const": "mcp_list_tools", "default": "mcp_list_tools" },
+                "server_label": { "type": "string" },
+                "tools": {
+                    "type": "array",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "input_schema": {
+                                "type": "object",
+                                "additionalProperties": {
+                                    "oneOf": [
+                                        { "type": "null" },
+                                        { "type": "boolean" },
+                                        { "type": "number" },
+                                        { "type": "string" },
+                                        { "type": "array" },
+                                        { "type": "object" }
+                                    ]
+                                }
+                            },
+                            "name": { "type": "string" },
+                            "description": { "type": "string" }
+                        },
+                        "additionalProperties": false,
+                        "required": ["input_schema", "name"],
+                        "title": "MCPListToolsTool"
+                    }
+                }
+            },
+            "additionalProperties": false,
+            "required": ["id", "type", "server_label", "tools"],
+            "title": "OpenAIResponseOutputMessageMCPListTools"
+        },
         "OpenAIResponseObjectStream": {
             "oneOf": [
                 {
                     "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated"
                 },
+                {
+                    "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta"
+                },
                 {
                     "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
                 }

@@ -7166,6 +7551,7 @@
                 "propertyName": "type",
                 "mapping": {
                     "response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated",
+                    "response.output_text.delta": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta",
                     "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
                 }
             }

@@ -7208,6 +7594,41 @@
             ],
             "title": "OpenAIResponseObjectStreamResponseCreated"
         },
+        "OpenAIResponseObjectStreamResponseOutputTextDelta": {
+            "type": "object",
+            "properties": {
+                "content_index": { "type": "integer" },
+                "delta": { "type": "string" },
+                "item_id": { "type": "string" },
+                "output_index": { "type": "integer" },
+                "sequence_number": { "type": "integer" },
+                "type": { "type": "string", "const": "response.output_text.delta", "default": "response.output_text.delta" }
+            },
+            "additionalProperties": false,
+            "required": ["content_index", "delta", "item_id", "output_index", "sequence_number", "type"],
+            "title": "OpenAIResponseObjectStreamResponseOutputTextDelta"
+        },
         "CreateUploadSessionRequest": {
             "type": "object",
             "properties": {
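The new OpenAIResponseObjectStreamResponseOutputTextDelta event carries incremental output text while a response streams. One such event would look roughly like this (values are illustrative):

    type: response.output_text.delta
    item_id: msg_abc123      # hypothetical output item ID
    output_index: 0
    content_index: 0
    delta: "Hello"
    sequence_number: 3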
@@ -9173,9 +9594,6 @@
                 "toolgroup_id": {
                     "type": "string"
                 },
-                "tool_host": {
-                    "$ref": "#/components/schemas/ToolHost"
-                },
                 "description": {
                     "type": "string"
                 },

@@ -9217,21 +9635,11 @@
                 "provider_id",
                 "type",
                 "toolgroup_id",
-                "tool_host",
                 "description",
                 "parameters"
             ],
             "title": "Tool"
         },
-        "ToolHost": {
-            "type": "string",
-            "enum": [
-                "distribution",
-                "client",
-                "model_context_protocol"
-            ],
-            "title": "ToolHost"
-        },
         "ToolGroup": {
             "type": "object",
             "properties": {

@@ -10068,6 +10476,130 @@
             ],
             "title": "ListModelsResponse"
         },
+        "ListOpenAIResponseInputItem": {
+            "type": "object",
+            "properties": {
+                "data": {
+                    "type": "array",
+                    "items": { "$ref": "#/components/schemas/OpenAIResponseInput" }
+                },
+                "object": { "type": "string", "const": "list", "default": "list" }
+            },
+            "additionalProperties": false,
+            "required": ["data", "object"],
+            "title": "ListOpenAIResponseInputItem"
+        },
+        "ListOpenAIResponseObject": {
+            "type": "object",
+            "properties": {
+                "data": {
+                    "type": "array",
+                    "items": { "$ref": "#/components/schemas/OpenAIResponseObjectWithInput" }
+                },
+                "has_more": { "type": "boolean" },
+                "first_id": { "type": "string" },
+                "last_id": { "type": "string" },
+                "object": { "type": "string", "const": "list", "default": "list" }
+            },
+            "additionalProperties": false,
+            "required": ["data", "has_more", "first_id", "last_id", "object"],
+            "title": "ListOpenAIResponseObject"
+        },
+        "OpenAIResponseObjectWithInput": {
+            "type": "object",
+            "properties": {
+                "created_at": { "type": "integer" },
+                "error": { "$ref": "#/components/schemas/OpenAIResponseError" },
+                "id": { "type": "string" },
+                "model": { "type": "string" },
+                "object": { "type": "string", "const": "response", "default": "response" },
+                "output": {
+                    "type": "array",
+                    "items": { "$ref": "#/components/schemas/OpenAIResponseOutput" }
+                },
+                "parallel_tool_calls": { "type": "boolean", "default": false },
+                "previous_response_id": { "type": "string" },
+                "status": { "type": "string" },
+                "temperature": { "type": "number" },
+                "top_p": { "type": "number" },
+                "truncation": { "type": "string" },
+                "user": { "type": "string" },
+                "input": {
+                    "type": "array",
+                    "items": { "$ref": "#/components/schemas/OpenAIResponseInput" }
+                }
+            },
+            "additionalProperties": false,
+            "required": ["created_at", "id", "model", "object", "output", "parallel_tool_calls", "status", "input"],
+            "title": "OpenAIResponseObjectWithInput"
+        },
         "ListProvidersResponse": {
             "type": "object",
             "properties": {
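Putting the list schemas together, a ListOpenAIResponseObject body would take roughly this shape (values are hypothetical and only required fields are shown):

    data:
      - created_at: 1747421890               # hypothetical Unix timestamp
        id: resp_abc123                      # hypothetical response ID
        model: meta-llama/Llama-3.2-3B-Instruct
        object: response
        output: []                           # OpenAIResponseOutput items
        parallel_tool_calls: false
        status: completed                    # hypothetical status value
        input: []                            # the OpenAIResponseInput items echoed back
    has_more: false
    first_id: resp_abc123
    last_id: resp_abc123
    object: list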
@@ -11605,6 +12137,10 @@
                     "type": "string",
                     "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
                     "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
+                },
+                "mode": {
+                    "type": "string",
+                    "description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"."
                 }
             },
             "additionalProperties": false,
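With the new mode field, a RAG query configuration can opt into keyword retrieval alongside the existing chunk template, for example (a sketch; other RAGQueryConfig fields are omitted):

    mode: keyword        # or "vector", the default
    chunk_template: |
      Result {index}
      Content: {chunk.content}
      Metadata: {metadata}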
docs/_static/llama-stack-spec.yaml (398 lines changed)

@@ -349,6 +349,53 @@ paths:
           $ref: '#/components/schemas/CreateAgentTurnRequest'
       required: true
   /v1/openai/v1/responses:
+    get:
+      responses:
+        '200':
+          description: A ListOpenAIResponseObject.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ListOpenAIResponseObject'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Agents
+      description: List all OpenAI responses.
+      parameters:
+        - name: after
+          in: query
+          description: The ID of the last response to return.
+          required: false
+          schema:
+            type: string
+        - name: limit
+          in: query
+          description: The number of responses to return.
+          required: false
+          schema:
+            type: integer
+        - name: model
+          in: query
+          description: The model to filter responses by.
+          required: false
+          schema:
+            type: string
+        - name: order
+          in: query
+          description: >-
+            The order to sort responses by when sorted by created_at ('asc' or 'desc').
+          required: false
+          schema:
+            $ref: '#/components/schemas/Order'
     post:
       responses:
         '200':

@@ -963,7 +1010,7 @@ paths:
       required: true
       schema:
         type: string
-  /v1/openai/v1/responses/{id}:
+  /v1/openai/v1/responses/{response_id}:
     get:
       responses:
         '200':

@@ -986,7 +1033,7 @@ paths:
         - Agents
       description: Retrieve an OpenAI response by its ID.
       parameters:
-        - name: id
+        - name: response_id
          in: path
          description: >-
            The ID of the OpenAI response to retrieve.

@@ -2038,6 +2085,75 @@ paths:
          schema:
            $ref: '#/components/schemas/RegisterModelRequest'
      required: true
+  /v1/openai/v1/responses/{response_id}/input_items:
+    get:
+      responses:
+        '200':
+          description: An ListOpenAIResponseInputItem.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ListOpenAIResponseInputItem'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Agents
+      description: >-
+        List input items for a given OpenAI response.
+      parameters:
+        - name: response_id
+          in: path
+          description: >-
+            The ID of the response to retrieve input items for.
+          required: true
+          schema:
+            type: string
+        - name: after
+          in: query
+          description: >-
+            An item ID to list items after, used for pagination.
+          required: false
+          schema:
+            type: string
+        - name: before
+          in: query
+          description: >-
+            An item ID to list items before, used for pagination.
+          required: false
+          schema:
+            type: string
+        - name: include
+          in: query
+          description: >-
+            Additional fields to include in the response.
+          required: false
+          schema:
+            type: array
+            items:
+              type: string
+        - name: limit
+          in: query
+          description: >-
+            A limit on the number of objects to be returned. Limit can range between
+            1 and 100, and the default is 20.
+          required: false
+          schema:
+            type: integer
+        - name: order
+          in: query
+          description: >-
+            The order to return the input items in. Default is desc.
+          required: false
+          schema:
+            $ref: '#/components/schemas/Order'
   /v1/providers:
     get:
       responses:

@@ -4762,12 +4878,14 @@ components:
         - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'
         - $ref: '#/components/schemas/OpenAIResponseInputToolFileSearch'
         - $ref: '#/components/schemas/OpenAIResponseInputToolFunction'
+        - $ref: '#/components/schemas/OpenAIResponseInputToolMCP'
       discriminator:
         propertyName: type
         mapping:
           web_search: '#/components/schemas/OpenAIResponseInputToolWebSearch'
           file_search: '#/components/schemas/OpenAIResponseInputToolFileSearch'
           function: '#/components/schemas/OpenAIResponseInputToolFunction'
+          mcp: '#/components/schemas/OpenAIResponseInputToolMCP'
     OpenAIResponseInputToolFileSearch:
       type: object
       properties:

@@ -4822,6 +4940,66 @@ components:
         - type
         - name
       title: OpenAIResponseInputToolFunction
+    OpenAIResponseInputToolMCP:
+      type: object
+      properties:
+        type:
+          type: string
+          const: mcp
+          default: mcp
+        server_label:
+          type: string
+        server_url:
+          type: string
+        headers:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+        require_approval:
+          oneOf:
+            - type: string
+              const: always
+            - type: string
+              const: never
+            - type: object
+              properties:
+                always:
+                  type: array
+                  items:
+                    type: string
+                never:
+                  type: array
+                  items:
+                    type: string
+              additionalProperties: false
+              title: ApprovalFilter
+          default: never
+        allowed_tools:
+          oneOf:
+            - type: array
+              items:
+                type: string
+            - type: object
+              properties:
+                tool_names:
+                  type: array
+                  items:
+                    type: string
+              additionalProperties: false
+              title: AllowedToolsFilter
+      additionalProperties: false
+      required:
+        - type
+        - server_label
+        - server_url
+        - require_approval
+      title: OpenAIResponseInputToolMCP
     OpenAIResponseInputToolWebSearch:
       type: object
       properties:

@@ -4897,12 +5075,12 @@ components:
     "OpenAIResponseOutputMessageFunctionToolCall":
       type: object
       properties:
-        arguments:
-          type: string
         call_id:
           type: string
         name:
           type: string
+        arguments:
+          type: string
         type:
           type: string
           const: function_call

@@ -4913,12 +5091,10 @@ components:
           type: string
       additionalProperties: false
       required:
-        - arguments
         - call_id
         - name
+        - arguments
         - type
-        - id
-        - status
       title: >-
         OpenAIResponseOutputMessageFunctionToolCall
     "OpenAIResponseOutputMessageWebSearchToolCall":

@@ -4952,6 +5128,8 @@ components:
       model:
         type: string
         description: The underlying LLM used for completions.
|
description: The underlying LLM used for completions.
|
||||||
|
instructions:
|
||||||
|
type: string
|
||||||
previous_response_id:
|
previous_response_id:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
|
@ -5034,20 +5212,95 @@ components:
|
||||||
- $ref: '#/components/schemas/OpenAIResponseMessage'
|
- $ref: '#/components/schemas/OpenAIResponseMessage'
|
||||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
|
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
|
||||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||||
|
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||||
|
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||||
discriminator:
|
discriminator:
|
||||||
propertyName: type
|
propertyName: type
|
||||||
mapping:
|
mapping:
|
||||||
message: '#/components/schemas/OpenAIResponseMessage'
|
message: '#/components/schemas/OpenAIResponseMessage'
|
||||||
web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
|
web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
|
||||||
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||||
|
mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||||
|
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||||
|
OpenAIResponseOutputMessageMCPCall:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: mcp_call
|
||||||
|
default: mcp_call
|
||||||
|
arguments:
|
||||||
|
type: string
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
server_label:
|
||||||
|
type: string
|
||||||
|
error:
|
||||||
|
type: string
|
||||||
|
output:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- type
|
||||||
|
- arguments
|
||||||
|
- name
|
||||||
|
- server_label
|
||||||
|
title: OpenAIResponseOutputMessageMCPCall
|
||||||
|
OpenAIResponseOutputMessageMCPListTools:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: mcp_list_tools
|
||||||
|
default: mcp_list_tools
|
||||||
|
server_label:
|
||||||
|
type: string
|
||||||
|
tools:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
input_schema:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
description:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- input_schema
|
||||||
|
- name
|
||||||
|
title: MCPListToolsTool
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- type
|
||||||
|
- server_label
|
||||||
|
- tools
|
||||||
|
title: OpenAIResponseOutputMessageMCPListTools
|
||||||
OpenAIResponseObjectStream:
|
OpenAIResponseObjectStream:
|
||||||
oneOf:
|
oneOf:
|
||||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||||
|
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
|
||||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
||||||
discriminator:
|
discriminator:
|
||||||
propertyName: type
|
propertyName: type
|
||||||
mapping:
|
mapping:
|
||||||
response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||||
|
response.output_text.delta: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
|
||||||
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
||||||
"OpenAIResponseObjectStreamResponseCompleted":
|
"OpenAIResponseObjectStreamResponseCompleted":
|
||||||
type: object
|
type: object
|
||||||
|
@ -5079,6 +5332,33 @@ components:
|
||||||
- type
|
- type
|
||||||
title: >-
|
title: >-
|
||||||
OpenAIResponseObjectStreamResponseCreated
|
OpenAIResponseObjectStreamResponseCreated
|
||||||
|
"OpenAIResponseObjectStreamResponseOutputTextDelta":
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
content_index:
|
||||||
|
type: integer
|
||||||
|
delta:
|
||||||
|
type: string
|
||||||
|
item_id:
|
||||||
|
type: string
|
||||||
|
output_index:
|
||||||
|
type: integer
|
||||||
|
sequence_number:
|
||||||
|
type: integer
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: response.output_text.delta
|
||||||
|
default: response.output_text.delta
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- content_index
|
||||||
|
- delta
|
||||||
|
- item_id
|
||||||
|
- output_index
|
||||||
|
- sequence_number
|
||||||
|
- type
|
||||||
|
title: >-
|
||||||
|
OpenAIResponseObjectStreamResponseOutputTextDelta
|
||||||
CreateUploadSessionRequest:
|
CreateUploadSessionRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -6462,8 +6742,6 @@ components:
|
||||||
default: tool
|
default: tool
|
||||||
toolgroup_id:
|
toolgroup_id:
|
||||||
type: string
|
type: string
|
||||||
tool_host:
|
|
||||||
$ref: '#/components/schemas/ToolHost'
|
|
||||||
description:
|
description:
|
||||||
type: string
|
type: string
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -6486,17 +6764,9 @@ components:
|
||||||
- provider_id
|
- provider_id
|
||||||
- type
|
- type
|
||||||
- toolgroup_id
|
- toolgroup_id
|
||||||
- tool_host
|
|
||||||
- description
|
- description
|
||||||
- parameters
|
- parameters
|
||||||
title: Tool
|
title: Tool
|
||||||
ToolHost:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- distribution
|
|
||||||
- client
|
|
||||||
- model_context_protocol
|
|
||||||
title: ToolHost
|
|
||||||
ToolGroup:
|
ToolGroup:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -7042,6 +7312,96 @@ components:
|
||||||
required:
|
required:
|
||||||
- data
|
- data
|
||||||
title: ListModelsResponse
|
title: ListModelsResponse
|
||||||
|
ListOpenAIResponseInputItem:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseInput'
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: list
|
||||||
|
default: list
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- data
|
||||||
|
- object
|
||||||
|
title: ListOpenAIResponseInputItem
|
||||||
|
ListOpenAIResponseObject:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseObjectWithInput'
|
||||||
|
has_more:
|
||||||
|
type: boolean
|
||||||
|
first_id:
|
||||||
|
type: string
|
||||||
|
last_id:
|
||||||
|
type: string
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: list
|
||||||
|
default: list
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- data
|
||||||
|
- has_more
|
||||||
|
- first_id
|
||||||
|
- last_id
|
||||||
|
- object
|
||||||
|
title: ListOpenAIResponseObject
|
||||||
|
OpenAIResponseObjectWithInput:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
created_at:
|
||||||
|
type: integer
|
||||||
|
error:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseError'
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: response
|
||||||
|
default: response
|
||||||
|
output:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseOutput'
|
||||||
|
parallel_tool_calls:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
previous_response_id:
|
||||||
|
type: string
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
temperature:
|
||||||
|
type: number
|
||||||
|
top_p:
|
||||||
|
type: number
|
||||||
|
truncation:
|
||||||
|
type: string
|
||||||
|
user:
|
||||||
|
type: string
|
||||||
|
input:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseInput'
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- created_at
|
||||||
|
- id
|
||||||
|
- model
|
||||||
|
- object
|
||||||
|
- output
|
||||||
|
- parallel_tool_calls
|
||||||
|
- status
|
||||||
|
- input
|
||||||
|
title: OpenAIResponseObjectWithInput
|
||||||
ListProvidersResponse:
|
ListProvidersResponse:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -8084,6 +8444,10 @@ components:
|
||||||
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
|
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
|
||||||
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
|
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
|
||||||
{chunk.content}\nMetadata: {metadata}\n"
|
{chunk.content}\nMetadata: {metadata}\n"
|
||||||
|
mode:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
Search mode for retrieval—either "vector" or "keyword". Default "vector".
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- query_generator_config
|
- query_generator_config
|
||||||
|
|
File diff suppressed because it is too large
|
@@ -3,10 +3,10 @@
 Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html).
 
 ## Render locally
 
+From the llama-stack root directory, run the following command to render the docs locally:
 ```bash
-pip install -r requirements.txt
-cd docs
-python -m sphinx_autobuild source _build
+uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
 ```
 You can open up the docs in your browser at http://localhost:8000
 
@@ -1,16 +0,0 @@
-linkify
-myst-parser
--e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-sphinx==8.1.3
-sphinx-copybutton
-sphinx-design
-sphinx-pdj-theme
-sphinx-rtd-theme>=1.0.0
-sphinx-tabs
-sphinx_autobuild
-sphinx_rtd_dark_mode
-sphinxcontrib-mermaid
-sphinxcontrib-openapi
-sphinxcontrib-redoc
-sphinxcontrib-video
-tomli
@@ -22,7 +22,11 @@ from docutils import nodes
 # Read version from pyproject.toml
 with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
 pypi_url = "https://pypi.org/pypi/llama-stack/json"
-version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
+headers = {
+    'User-Agent': 'pip/23.0.1 (python 3.11)',  # Mimic pip's user agent
+    'Accept': 'application/json'
+}
+version_tag = json.loads(requests.get(pypi_url, headers=headers).text)["info"]["version"]
 print(f"{version_tag=}")
 
 # generate the full link including text and url here
@@ -53,14 +57,6 @@ myst_enable_extensions = ["colon_fence"]
 
 html_theme = "sphinx_rtd_theme"
 html_use_relative_paths = True
-
-# html_theme = "sphinx_pdj_theme"
-# html_theme_path = [sphinx_pdj_theme.get_html_theme_path()]
-
-# html_theme = "pytorch_sphinx_theme"
-# html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
-
-
 
 templates_path = ["_templates"]
 exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
@@ -338,6 +338,48 @@ INFO: Application startup complete.
 INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
 INFO: 2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK
 ```
+### Listing Distributions
+Using the list command, you can view all existing Llama Stack distributions, including stacks built from templates, from scratch, or using custom configuration files.
+
+```
+llama stack list -h
+usage: llama stack list [-h]
+
+list the build stacks
+
+options:
+  -h, --help  show this help message and exit
+```
+
+Example Usage
+
+```
+llama stack list
+```
+
+### Removing a Distribution
+Use the remove command to delete a distribution you've previously built.
+
+```
+llama stack rm -h
+usage: llama stack rm [-h] [--all] [name]
+
+Remove the build stack
+
+positional arguments:
+  name        Name of the stack to delete (default: None)
+
+options:
+  -h, --help  show this help message and exit
+  --all, -a   Delete all stacks (use with caution) (default: False)
+```
+
+Example
+```
+llama stack rm llamastack-test
+```
+
+To keep your environment organized and avoid clutter, consider using `llama stack list` to review old or unused distributions and `llama stack rm <name>` to delete them when they’re no longer needed.
+
 ### Troubleshooting
 
@@ -118,11 +118,6 @@ server:
   port: 8321 # Port to listen on (default: 8321)
   tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
   tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
-  auth: # Optional: Authentication configuration
-    provider_type: "kubernetes" # Type of auth provider
-    config: # Provider-specific configuration
-      api_server_url: "https://kubernetes.default.svc"
-      ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate
 ```
 
 ### Authentication Configuration
@@ -135,7 +130,7 @@ Authorization: Bearer <token>
 
 The server supports multiple authentication providers:
 
-#### Kubernetes Provider
+#### OAuth 2.0/OpenID Connect Provider with Kubernetes
 
 The Kubernetes cluster must be configured to use a service account for authentication.
 
@@ -146,14 +141,67 @@ kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --se
 kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
 ```
 
-Validates tokens against the Kubernetes API server:
+Make sure the `kube-apiserver` runs with `--anonymous-auth=true` so that the JWKS endpoint can be
+fetched without a token, and that a RoleBinding grants the anonymous user access to it. If that is
+not already the case, you can create one:
+
+```yaml
+# allow-anonymous-openid.yaml
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRole
+metadata:
+  name: allow-anonymous-openid
+rules:
+- nonResourceURLs: ["/openid/v1/jwks"]
+  verbs: ["get"]
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRoleBinding
+metadata:
+  name: allow-anonymous-openid
+roleRef:
+  apiGroup: rbac.authorization.k8s.io
+  kind: ClusterRole
+  name: allow-anonymous-openid
+subjects:
+- kind: User
+  name: system:anonymous
+  apiGroup: rbac.authorization.k8s.io
+```
+
+And then apply the configuration:
+```bash
+kubectl apply -f allow-anonymous-openid.yaml
+```
+
+Validates tokens against the Kubernetes API server through the OIDC provider:
 ```yaml
 server:
   auth:
-    provider_type: "kubernetes"
+    provider_type: "oauth2_token"
     config:
-      api_server_url: "https://kubernetes.default.svc" # URL of the Kubernetes API server
-      ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate
+      jwks:
+        uri: "https://kubernetes.default.svc"
+        key_recheck_period: 3600
+      tls_cafile: "/path/to/ca.crt"
+      issuer: "https://kubernetes.default.svc"
+      audience: "https://kubernetes.default.svc"
+```
+
+To find your cluster's audience, run:
+```bash
+kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
+```
+
+For the issuer, you can use the OIDC provider's URL:
+```bash
+kubectl get --raw /.well-known/openid-configuration | jq .issuer
+```
+
+For the tls_cafile, you can use the CA certificate of the OIDC provider:
+```bash
+kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
 ```
 
 The provider extracts user information from the JWT token:
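
Once this is in place, every request to the stack must carry the service-account token in the `Authorization: Bearer` header. The following is only a minimal sketch of a smoke test, assuming the server listens on localhost:8321 and the token file created above sits in the current directory:

```python
# Hedged sketch: verify the OAuth token flow end to end. The port, route, and
# token file name are assumptions carried over from the examples above.
import requests

with open("llama-stack-auth-token") as f:
    token = f.read().strip()

resp = requests.get(
    "http://localhost:8321/v1/models",
    headers={"Authorization": f"Bearer {token}"},
    timeout=30,
)
print(resp.status_code)  # expect 200 with a valid token; an auth error otherwise
```
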
@@ -208,6 +256,80 @@ And must respond with:
 
 If no access attributes are returned, the token is used as a namespace.
 
+### Quota Configuration
+
+The `quota` section allows you to enable server-side request throttling for both
+authenticated and anonymous clients. This is useful for preventing abuse, enforcing
+fairness across tenants, and controlling infrastructure costs without requiring
+client-side rate limiting or external proxies.
+
+Quotas are disabled by default. When enabled, each client is tracked using either:
+
+* Their authenticated `client_id` (derived from the Bearer token), or
+* Their IP address (fallback for anonymous requests)
+
+Quota state is stored in a SQLite-backed key-value store, and rate limits are applied
+within a configurable time window (currently only `day` is supported).
+
+#### Example
+
+```yaml
+server:
+  quota:
+    kvstore:
+      type: sqlite
+      db_path: ./quotas.db
+    anonymous_max_requests: 100
+    authenticated_max_requests: 1000
+    period: day
+```
+
+#### Configuration Options
+
+| Field                        | Description                                                                 |
+| ---------------------------- | --------------------------------------------------------------------------- |
+| `kvstore`                    | Required. Backend storage config for tracking request counts.               |
+| `kvstore.type`               | Must be `"sqlite"` for now. Other backends may be supported in the future.  |
+| `kvstore.db_path`            | File path to the SQLite database.                                           |
+| `anonymous_max_requests`     | Max requests per period for unauthenticated clients.                        |
+| `authenticated_max_requests` | Max requests per period for authenticated clients.                          |
+| `period`                     | Time window for quota enforcement. Only `"day"` is supported.               |
+
+> Note: if `authenticated_max_requests` is set but no authentication provider is
+configured, the server will fall back to applying `anonymous_max_requests` to all
+clients.
+
+#### Example with Authentication Enabled
+
+```yaml
+server:
+  port: 8321
+  auth:
+    provider_type: custom
+    config:
+      endpoint: https://auth.example.com/validate
+  quota:
+    kvstore:
+      type: sqlite
+      db_path: ./quotas.db
+    anonymous_max_requests: 100
+    authenticated_max_requests: 1000
+    period: day
+```
+
+If a client exceeds their limit, the server responds with:
+
+```http
+HTTP/1.1 429 Too Many Requests
+Content-Type: application/json
+
+{
+  "error": {
+    "message": "Quota exceeded"
+  }
+}
+```
+
 ## Extending to handle Safety
 
 Configuring Safety can be a little involved so it is instructive to go through an example.
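
From the client's side, the only visible effect of the quota is the 429 response shown above. A rough sketch of handling it, with the base URL and endpoint as assumptions:

```python
# Illustrative only: detect the "Quota exceeded" error the server emits once the
# per-day budget is spent. The URL is an assumed local deployment.
import requests

def get_with_quota_check(url: str, token: str | None = None) -> dict:
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    resp = requests.get(url, headers=headers, timeout=30)
    if resp.status_code == 429:
        # Matches the example payload: {"error": {"message": "Quota exceeded"}}
        raise RuntimeError(resp.json()["error"]["message"])
    resp.raise_for_status()
    return resp.json()

get_with_quota_check("http://localhost:8321/v1/models")
```
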
@@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 |-----|-------------|
 | agents | `inline::meta-reference` |
 | inference | `remote::sambanova`, `inline::sentence-transformers` |
-| safety | `inline::llama-guard` |
+| safety | `remote::sambanova` |
 | telemetry | `inline::meta-reference` |
 | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -48,33 +48,44 @@ The following models are available by default:
 
 ### Prerequisite: API Keys
 
-Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/).
+Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup).
 
 
 ## Running Llama Stack with SambaNova
 
 You can do this via Conda (build code) or Docker which has a pre-built image.
 
-### Via Docker
 
-This method allows you to get started quickly without having to build the distribution code.
+### Via Docker
 
 ```bash
 LLAMA_STACK_PORT=8321
+llama stack build --template sambanova --image-type container
 docker run \
   -it \
-  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-  llamastack/distribution-sambanova \
+  -v ~/.llama:/root/.llama \
+  distribution-sambanova \
   --port $LLAMA_STACK_PORT \
   --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
 ```
 
 
+### Via Venv
+
+```bash
+llama stack build --template sambanova --image-type venv
+llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \
+  --port $LLAMA_STACK_PORT \
+  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
+```
+
+
 ### Via Conda
 
 ```bash
 llama stack build --template sambanova --image-type conda
-llama stack run ./run.yaml \
+llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
 ```
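
Whichever of the methods above you use to start the distribution, a quick sanity check from Python might look like the sketch below; the model id is only an example and should be swapped for one the server actually reports.

```python
# Hedged sketch: exercise the running SambaNova distribution via the Python client.
# The model id is a placeholder; list the server's models first and pick one of those.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(response.completion_message.content)
```
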
@@ -66,6 +66,25 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
 2. Configure your Llama Stack project to use SQLite-Vec.
 3. Start storing and querying vectors.
 
+## Supported Search Modes
+
+The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
+
+When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
+`RAGQueryConfig`. For example:
+
+```python
+from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
+
+query_config = RAGQueryConfig(max_chunks=6, mode="vector")
+
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="what is torchtune",
+    query_config=query_config,
+)
+```
+
 ## Installation
 
 You can install SQLite-Vec using pip:
@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import Order, PaginatedResponse
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
|
@ -31,6 +31,8 @@ from llama_stack.apis.tools import ToolDef
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
from .openai_responses import (
|
from .openai_responses import (
|
||||||
|
ListOpenAIResponseInputItem,
|
||||||
|
ListOpenAIResponseObject,
|
||||||
OpenAIResponseInput,
|
OpenAIResponseInput,
|
||||||
OpenAIResponseInputTool,
|
OpenAIResponseInputTool,
|
||||||
OpenAIResponseObject,
|
OpenAIResponseObject,
|
||||||
|
@ -579,14 +581,14 @@ class Agents(Protocol):
|
||||||
#
|
#
|
||||||
# Both of these APIs are inherently stateful.
|
# Both of these APIs are inherently stateful.
|
||||||
|
|
||||||
@webmethod(route="/openai/v1/responses/{id}", method="GET")
|
@webmethod(route="/openai/v1/responses/{response_id}", method="GET")
|
||||||
async def get_openai_response(
|
async def get_openai_response(
|
||||||
self,
|
self,
|
||||||
id: str,
|
response_id: str,
|
||||||
) -> OpenAIResponseObject:
|
) -> OpenAIResponseObject:
|
||||||
"""Retrieve an OpenAI response by its ID.
|
"""Retrieve an OpenAI response by its ID.
|
||||||
|
|
||||||
:param id: The ID of the OpenAI response to retrieve.
|
:param response_id: The ID of the OpenAI response to retrieve.
|
||||||
:returns: An OpenAIResponseObject.
|
:returns: An OpenAIResponseObject.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
@ -596,6 +598,7 @@ class Agents(Protocol):
|
||||||
self,
|
self,
|
||||||
input: str | list[OpenAIResponseInput],
|
input: str | list[OpenAIResponseInput],
|
||||||
model: str,
|
model: str,
|
||||||
|
instructions: str | None = None,
|
||||||
previous_response_id: str | None = None,
|
previous_response_id: str | None = None,
|
||||||
store: bool | None = True,
|
store: bool | None = True,
|
||||||
stream: bool | None = False,
|
stream: bool | None = False,
|
||||||
|
@ -610,3 +613,43 @@ class Agents(Protocol):
|
||||||
:returns: An OpenAIResponseObject.
|
:returns: An OpenAIResponseObject.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/openai/v1/responses", method="GET")
|
||||||
|
async def list_openai_responses(
|
||||||
|
self,
|
||||||
|
after: str | None = None,
|
||||||
|
limit: int | None = 50,
|
||||||
|
model: str | None = None,
|
||||||
|
order: Order | None = Order.desc,
|
||||||
|
) -> ListOpenAIResponseObject:
|
||||||
|
"""List all OpenAI responses.
|
||||||
|
|
||||||
|
:param after: The ID of the last response to return.
|
||||||
|
:param limit: The number of responses to return.
|
||||||
|
:param model: The model to filter responses by.
|
||||||
|
:param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
|
||||||
|
:returns: A ListOpenAIResponseObject.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET")
|
||||||
|
async def list_openai_response_input_items(
|
||||||
|
self,
|
||||||
|
response_id: str,
|
||||||
|
after: str | None = None,
|
||||||
|
before: str | None = None,
|
||||||
|
include: list[str] | None = None,
|
||||||
|
limit: int | None = 20,
|
||||||
|
order: Order | None = Order.desc,
|
||||||
|
) -> ListOpenAIResponseInputItem:
|
||||||
|
"""List input items for a given OpenAI response.
|
||||||
|
|
||||||
|
:param response_id: The ID of the response to retrieve input items for.
|
||||||
|
:param after: An item ID to list items after, used for pagination.
|
||||||
|
:param before: An item ID to list items before, used for pagination.
|
||||||
|
:param include: Additional fields to include in the response.
|
||||||
|
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
|
||||||
|
:param order: The order to return the input items in. Default is desc.
|
||||||
|
:returns: A ListOpenAIResponseInputItem.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
|
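
Over HTTP, these two methods surface as `GET` routes under the same `/v1/openai/v1/responses` prefix seen in the OpenAPI paths earlier in this change. A rough sketch of paging through them from a client; the server address, and the exact list route (not shown in this hunk), are assumptions:

```python
# Sketch only: list stored responses, then fetch the input items of the first one.
# Field access mirrors ListOpenAIResponseObject / OpenAIResponseObjectWithInput above.
import requests

BASE = "http://localhost:8321/v1/openai/v1"  # assumed local deployment

page = requests.get(f"{BASE}/responses", params={"limit": 10, "order": "desc"}, timeout=30).json()
for item in page["data"]:
    print(item["id"], item["model"], item["status"])

if page["data"]:
    first_id = page["data"][0]["id"]
    items = requests.get(
        f"{BASE}/responses/{first_id}/input_items", params={"limit": 20}, timeout=30
    ).json()
    print(f"{first_id} has {len(items['data'])} input item(s)")
```
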
@ -10,6 +10,9 @@ from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||||
|
|
||||||
|
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
|
||||||
|
# take their YAML and generate this file automatically. Their YAML is available.
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseError(BaseModel):
|
class OpenAIResponseError(BaseModel):
|
||||||
|
@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
||||||
arguments: str
|
|
||||||
call_id: str
|
call_id: str
|
||||||
name: str
|
name: str
|
||||||
|
arguments: str
|
||||||
type: Literal["function_call"] = "function_call"
|
type: Literal["function_call"] = "function_call"
|
||||||
|
id: str | None = None
|
||||||
|
status: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseOutputMessageMCPCall(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
status: str
|
type: Literal["mcp_call"] = "mcp_call"
|
||||||
|
arguments: str
|
||||||
|
name: str
|
||||||
|
server_label: str
|
||||||
|
error: str | None = None
|
||||||
|
output: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MCPListToolsTool(BaseModel):
|
||||||
|
input_schema: dict[str, Any]
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
|
||||||
|
id: str
|
||||||
|
type: Literal["mcp_list_tools"] = "mcp_list_tools"
|
||||||
|
server_label: str
|
||||||
|
tools: list[MCPListToolsTool]
|
||||||
|
|
||||||
|
|
||||||
OpenAIResponseOutput = Annotated[
|
OpenAIResponseOutput = Annotated[
|
||||||
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
|
OpenAIResponseMessage
|
||||||
|
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||||
|
| OpenAIResponseOutputMessageFunctionToolCall
|
||||||
|
| OpenAIResponseOutputMessageMCPCall
|
||||||
|
| OpenAIResponseOutputMessageMCPListTools,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||||
|
@ -117,6 +149,16 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
||||||
type: Literal["response.created"] = "response.created"
|
type: Literal["response.created"] = "response.created"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
||||||
|
content_index: int
|
||||||
|
delta: str
|
||||||
|
item_id: str
|
||||||
|
output_index: int
|
||||||
|
sequence_number: int
|
||||||
|
type: Literal["response.output_text.delta"] = "response.output_text.delta"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||||
response: OpenAIResponseObject
|
response: OpenAIResponseObject
|
||||||
|
@ -124,7 +166,9 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
OpenAIResponseObjectStream = Annotated[
|
OpenAIResponseObjectStream = Annotated[
|
||||||
OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
|
OpenAIResponseObjectStreamResponseCreated
|
||||||
|
| OpenAIResponseObjectStreamResponseOutputTextDelta
|
||||||
|
| OpenAIResponseObjectStreamResponseCompleted,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
||||||
|
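
A sketch of what a consumer of this stream union might do; the `stream` iterator itself (as produced by a streaming response creation call) and the public module path are assumed here rather than shown:

```python
# Assumed-context sketch: accumulate streamed text from the event union above.
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseOutputTextDelta,
)

async def collect_output_text(stream) -> str:
    chunks: list[str] = []
    async for event in stream:
        if isinstance(event, OpenAIResponseObjectStreamResponseOutputTextDelta):
            chunks.append(event.delta)  # incremental text for one output item
        elif isinstance(event, OpenAIResponseObjectStreamResponseCompleted):
            break  # the final OpenAIResponseObject is available as event.response
    return "".join(chunks)
```
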
@ -186,13 +230,50 @@ class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||||
# TODO: add filters
|
# TODO: add filters
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalFilter(BaseModel):
|
||||||
|
always: list[str] | None = None
|
||||||
|
never: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AllowedToolsFilter(BaseModel):
|
||||||
|
tool_names: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseInputToolMCP(BaseModel):
|
||||||
|
type: Literal["mcp"] = "mcp"
|
||||||
|
server_label: str
|
||||||
|
server_url: str
|
||||||
|
headers: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
|
||||||
|
allowed_tools: list[str] | AllowedToolsFilter | None = None
|
||||||
|
|
||||||
|
|
||||||
OpenAIResponseInputTool = Annotated[
|
OpenAIResponseInputTool = Annotated[
|
||||||
OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
|
OpenAIResponseInputToolWebSearch
|
||||||
|
| OpenAIResponseInputToolFileSearch
|
||||||
|
| OpenAIResponseInputToolFunction
|
||||||
|
| OpenAIResponseInputToolMCP,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||||
|
|
||||||
|
|
||||||
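
For illustration, a small hedged sketch of constructing the new MCP tool declaration directly from these models; the server label, URL, and allowed tool name are placeholders, and the module path is assumed to match this file's package:

```python
# Placeholder values throughout; this only shows the shape of the new input tool.
from llama_stack.apis.agents.openai_responses import OpenAIResponseInputToolMCP

mcp_tool = OpenAIResponseInputToolMCP(
    server_label="docs-mcp",               # placeholder label
    server_url="https://example.com/mcp",  # placeholder MCP server URL
    require_approval="never",              # or "always", or an ApprovalFilter
    allowed_tools=["search_docs"],         # or an AllowedToolsFilter
)
print(mcp_tool.model_dump_json(indent=2))
```
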
class OpenAIResponseInputItemList(BaseModel):
|
class ListOpenAIResponseInputItem(BaseModel):
|
||||||
data: list[OpenAIResponseInput]
|
data: list[OpenAIResponseInput]
|
||||||
object: Literal["list"] = "list"
|
object: Literal["list"] = "list"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
||||||
|
input: list[OpenAIResponseInput]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ListOpenAIResponseObject(BaseModel):
|
||||||
|
data: list[OpenAIResponseObjectWithInput]
|
||||||
|
has_more: bool
|
||||||
|
first_id: str
|
||||||
|
last_id: str
|
||||||
|
object: Literal["list"] = "list"
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RestAPIMethod(Enum):
|
|
||||||
GET = "GET"
|
|
||||||
POST = "POST"
|
|
||||||
PUT = "PUT"
|
|
||||||
DELETE = "DELETE"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RestAPIExecutionConfig(BaseModel):
|
|
||||||
url: URL
|
|
||||||
method: RestAPIMethod
|
|
||||||
params: dict[str, Any] | None = None
|
|
||||||
headers: dict[str, Any] | None = None
|
|
||||||
body: dict[str, Any] | None = None
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -11,6 +12,11 @@ from pydantic import BaseModel
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
class Order(Enum):
|
||||||
|
asc = "asc"
|
||||||
|
desc = "desc"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PaginatedResponse(BaseModel):
|
class PaginatedResponse(BaseModel):
|
||||||
"""A generic paginated response that follows a simple format.
|
"""A generic paginated response that follows a simple format.
|
||||||
|
|
|
@ -19,6 +19,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||||
|
from llama_stack.apis.common.responses import Order
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
@ -833,11 +834,6 @@ class ListOpenAIChatCompletionResponse(BaseModel):
|
||||||
object: Literal["list"] = "list"
|
object: Literal["list"] = "list"
|
||||||
|
|
||||||
|
|
||||||
class Order(Enum):
|
|
||||||
asc = "asc"
|
|
||||||
desc = "desc"
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
@trace_protocol
|
@trace_protocol
|
||||||
class InferenceProvider(Protocol):
|
class InferenceProvider(Protocol):
|
||||||
|
|
|
@ -76,6 +76,7 @@ class RAGQueryConfig(BaseModel):
|
||||||
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
||||||
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
||||||
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
||||||
|
:param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This config defines how a query is generated using the messages
|
# This config defines how a query is generated using the messages
|
||||||
|
@ -84,6 +85,7 @@ class RAGQueryConfig(BaseModel):
|
||||||
max_tokens_in_context: int = 4096
|
max_tokens_in_context: int = 4096
|
||||||
max_chunks: int = 5
|
max_chunks: int = 5
|
||||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||||
|
mode: str | None = None
|
||||||
|
|
||||||
@field_validator("chunk_template")
|
@field_validator("chunk_template")
|
||||||
def validate_chunk_template(cls, v: str) -> str:
|
def validate_chunk_template(cls, v: str) -> str:
|
||||||
|
|
|
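
This new field is what the sqlite-vec documentation above feeds through `RAGQueryConfig(mode=...)`; a keyword-mode variant, using the same import path as that doc example, is simply:

```python
# Mirrors the docs example earlier in this change, switched to keyword (full-text) search.
# The import path is taken from that example and assumed to be correct here.
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig

keyword_config = RAGQueryConfig(max_chunks=5, mode="keyword")
```
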
@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
|
||||||
default: Any | None = None
|
default: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolHost(Enum):
|
|
||||||
distribution = "distribution"
|
|
||||||
client = "client"
|
|
||||||
model_context_protocol = "model_context_protocol"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Tool(Resource):
|
class Tool(Resource):
|
||||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||||
toolgroup_id: str
|
toolgroup_id: str
|
||||||
tool_host: ToolHost
|
|
||||||
description: str
|
description: str
|
||||||
parameters: list[ToolParameter]
|
parameters: list[ToolParameter]
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
|
@ -76,8 +68,8 @@ class ToolInvocationResult(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ToolStore(Protocol):
|
class ToolStore(Protocol):
|
||||||
def get_tool(self, tool_name: str) -> Tool: ...
|
async def get_tool(self, tool_name: str) -> Tool: ...
|
||||||
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
||||||
|
|
||||||
|
|
||||||
class ListToolGroupsResponse(BaseModel):
|
class ListToolGroupsResponse(BaseModel):
|
||||||
|
|
|
@ -9,6 +9,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -377,14 +378,15 @@ def _meta_download(
|
||||||
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
||||||
asyncio.run(downloader.download_all(tasks))
|
asyncio.run(downloader.download_all(tasks))
|
||||||
|
|
||||||
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
|
cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
|
||||||
cprint(
|
cprint(
|
||||||
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
|
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
|
||||||
"white",
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
cprint(
|
cprint(
|
||||||
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
|
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
|
||||||
"yellow",
|
color="yellow",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -79,6 +79,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
|
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
build_config = available_templates[args.template]
|
build_config = available_templates[args.template]
|
||||||
|
@ -88,6 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
|
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
elif args.providers:
|
elif args.providers:
|
||||||
|
@ -97,6 +99,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
api, provider = api_provider.split("=")
|
api, provider = api_provider.split("=")
|
||||||
|
@ -105,6 +108,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"{api} is not a valid API.",
|
f"{api} is not a valid API.",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
if provider in providers_for_api:
|
if provider in providers_for_api:
|
||||||
|
@ -113,6 +117,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"{provider} is not a valid provider for the {api} API.",
|
f"{provider} is not a valid provider for the {api} API.",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
distribution_spec = DistributionSpec(
|
distribution_spec = DistributionSpec(
|
||||||
|
@ -123,6 +128,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Please specify a image-type (container | conda | venv) for {args.template}",
|
f"Please specify a image-type (container | conda | venv) for {args.template}",
|
||||||
color="red",
|
color="red",
|
||||||
|
+            file=sys.stderr,
         )
         sys.exit(1)

@@ -151,12 +157,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
                 color="yellow",
+                file=sys.stderr,
             )
             image_name = f"llamastack-{name}"
         else:
             cprint(
                 f"Using conda environment {image_name}",
                 color="green",
+                file=sys.stderr,
             )
     else:
         image_name = f"llamastack-{name}"
@@ -169,9 +177,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 """,
             ),
             color="green",
+            file=sys.stderr,
         )

-        print("Tip: use <TAB> to see options for the providers.\n")
+        cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)

         providers = dict()
         for api, providers_for_api in get_provider_registry().items():
@@ -213,6 +222,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 f"Could not parse config file {args.config}: {e}",
                 color="red",
+                file=sys.stderr,
             )
             sys.exit(1)

@@ -239,22 +249,25 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         cprint(
             f"Error building stack: {exc}",
             color="red",
+            file=sys.stderr,
         )
-        cprint("Stack trace:", color="red")
+        cprint("Stack trace:", color="red", file=sys.stderr)
         traceback.print_exc()
         sys.exit(1)

     if run_config is None:
         cprint(
             "Run config path is empty",
             color="red",
+            file=sys.stderr,
         )
         sys.exit(1)

     if args.run:
         config_dict = yaml.safe_load(run_config.read_text())
         config = parse_and_maybe_upgrade_config(config_dict)
-        if not os.path.exists(config.external_providers_dir):
-            os.makedirs(config.external_providers_dir, exist_ok=True)
+        if config.external_providers_dir and not config.external_providers_dir.exists():
+            config.external_providers_dir.mkdir(exist_ok=True)
         run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
         run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
         run_command(run_args)
@@ -304,6 +317,7 @@ def _generate_run_config(
             cprint(
                 f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
                 color="yellow",
+                file=sys.stderr,
             )
             # Set config_type to None to avoid UnboundLocalError
             config_type = None
@@ -331,10 +345,7 @@ def _generate_run_config(
     # For non-container builds, the run.yaml is generated at the very end of the build process so it
     # makes sense to display this message
     if build_config.image_type != LlamaStackImageType.CONTAINER.value:
-        cprint(
-            f"You can now run your stack with `llama stack run {run_config_file}`",
-            color="green",
-        )
+        cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
     return run_config_file

@@ -372,7 +383,7 @@ def _run_stack_build_command_from_build_config(
     # Generate the run.yaml so it can be included in the container image with the proper entrypoint
     # Only do this if we're building a container image and we're not using a template
     if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
-        cprint("Generating run.yaml file", color="green")
+        cprint("Generating run.yaml file", color="yellow", file=sys.stderr)
         run_config_file = _generate_run_config(build_config, build_dir, image_name)

     with open(build_file_path, "w") as f:
@@ -396,11 +407,13 @@ def _run_stack_build_command_from_build_config(
         run_config_file = build_dir / f"{template_name}-run.yaml"
         shutil.copy(path, run_config_file)

-    cprint("Build Successful!", color="green")
-    cprint("You can find the newly-built template here: " + colored(template_path, "light_blue"))
+    cprint("Build Successful!", color="green", file=sys.stderr)
+    cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr)
     cprint(
         "You can run the new Llama Stack distro via: "
-        + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue")
+        + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"),
+        color="green",
+        file=sys.stderr,
     )
         return template_path
     else:
56  llama_stack/cli/stack/list_stacks.py  Normal file
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+from pathlib import Path
+
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.cli.table import print_table
+
+
+class StackListBuilds(Subcommand):
+    """List built stacks in .llama/distributions directory"""
+
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "list",
+            prog="llama stack list",
+            description="list the build stacks",
+            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._list_stack_command)
+
+    def _get_distribution_dirs(self) -> dict[str, Path]:
+        """Return a dictionary of distribution names and their paths"""
+        distributions = {}
+        dist_dir = Path.home() / ".llama" / "distributions"
+
+        if dist_dir.exists():
+            for stack_dir in dist_dir.iterdir():
+                if stack_dir.is_dir():
+                    distributions[stack_dir.name] = stack_dir
+        return distributions
+
+    def _list_stack_command(self, args: argparse.Namespace) -> None:
+        distributions = self._get_distribution_dirs()
+
+        if not distributions:
+            print("No stacks found in ~/.llama/distributions")
+            return
+
+        headers = ["Stack Name", "Path"]
+        headers.extend(["Build Config", "Run Config"])
+        rows = []
+        for name, path in distributions.items():
+            row = [name, str(path)]
+            # Check for build and run config files
+            build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
+            run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
+            row.extend([build_config, run_config])
+            rows.append(row)
+        print_table(rows, headers, separate_rows=True)
115  llama_stack/cli/stack/remove.py  Normal file
@@ -0,0 +1,115 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import shutil
+import sys
+from pathlib import Path
+
+from termcolor import cprint
+
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.cli.table import print_table
+
+
+class StackRemove(Subcommand):
+    """Remove the build stack"""
+
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "rm",
+            prog="llama stack rm",
+            description="Remove the build stack",
+            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._remove_stack_build_command)
+
+    def _add_arguments(self) -> None:
+        self.parser.add_argument(
+            "name",
+            type=str,
+            nargs="?",
+            help="Name of the stack to delete",
+        )
+        self.parser.add_argument(
+            "--all",
+            "-a",
+            action="store_true",
+            help="Delete all stacks (use with caution)",
+        )
+
+    def _get_distribution_dirs(self) -> dict[str, Path]:
+        """Return a dictionary of distribution names and their paths"""
+        distributions = {}
+        dist_dir = Path.home() / ".llama" / "distributions"
+
+        if dist_dir.exists():
+            for stack_dir in dist_dir.iterdir():
+                if stack_dir.is_dir():
+                    distributions[stack_dir.name] = stack_dir
+        return distributions
+
+    def _list_stacks(self) -> None:
+        """Display available stacks in a table"""
+        distributions = self._get_distribution_dirs()
+        if not distributions:
+            cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr)
+            sys.exit(1)
+
+        headers = ["Stack Name", "Path"]
+        rows = [[name, str(path)] for name, path in distributions.items()]
+        print_table(rows, headers, separate_rows=True)
+
+    def _remove_stack_build_command(self, args: argparse.Namespace) -> None:
+        distributions = self._get_distribution_dirs()
+
+        if args.all:
+            confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
+            if confirm != "yes-i-really-want":
+                cprint("Deletion cancelled.", color="green", file=sys.stderr)
+                return
+
+            for name, path in distributions.items():
+                try:
+                    shutil.rmtree(path)
+                    cprint(f"Deleted stack: {name}", color="green", file=sys.stderr)
+                except Exception as e:
+                    cprint(
+                        f"Failed to delete stack {name}: {e}",
+                        color="red",
+                        file=sys.stderr,
+                    )
+                    sys.exit(1)
+
+        if not args.name:
+            self._list_stacks()
+            if not args.name:
+                return
+
+        if args.name not in distributions:
+            self._list_stacks()
+            cprint(
+                f"Stack not found: {args.name}",
+                color="red",
+                file=sys.stderr,
+            )
+            sys.exit(1)
+
+        stack_path = distributions[args.name]
+
+        confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
+        if confirm != "y":
+            cprint("Deletion cancelled.", color="green", file=sys.stderr)
+            return
+
+        try:
+            shutil.rmtree(stack_path)
+            cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr)
+        except Exception as e:
+            cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr)
+            sys.exit(1)
llama_stack/cli/stack/run.py
@@ -6,6 +6,7 @@
 import argparse
 import os
+import subprocess
 from pathlib import Path

 from llama_stack.cli.stack.utils import ImageType
@@ -60,6 +61,11 @@ class StackRun(Subcommand):
             help="Image Type used during the build. This can be either conda or container or venv.",
             choices=[e.value for e in ImageType],
         )
+        self.parser.add_argument(
+            "--enable-ui",
+            action="store_true",
+            help="Start the UI server",
+        )

         # If neither image type nor image name is provided, but at the same time
         # the current environment has conda breadcrumbs, then assume what the user
@@ -83,6 +89,8 @@ class StackRun(Subcommand):
         from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
         from llama_stack.distribution.utils.exec import formulate_run_args, run_command

+        if args.enable_ui:
+            self._start_ui_development_server(args.port)
         image_type, image_name = self._get_image_type_and_name(args)

         # Check if config is required based on image type
@@ -170,3 +178,44 @@ class StackRun(Subcommand):
                 run_args.extend(["--env", f"{key}={value}"])

         run_command(run_args)
+
+    def _start_ui_development_server(self, stack_server_port: int):
+        logger.info("Attempting to start UI development server...")
+        # Check if npm is available
+        npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
+        if npm_check.returncode != 0:
+            logger.warning(
+                f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
+            )
+            return
+
+        ui_dir = REPO_ROOT / "llama_stack" / "ui"
+        logs_dir = Path("~/.llama/ui/logs").expanduser()
+        try:
+            # Create logs directory if it doesn't exist
+            logs_dir.mkdir(parents=True, exist_ok=True)
+
+            ui_stdout_log_path = logs_dir / "stdout.log"
+            ui_stderr_log_path = logs_dir / "stderr.log"
+
+            # Open log files in append mode
+            stdout_log_file = open(ui_stdout_log_path, "a")
+            stderr_log_file = open(ui_stderr_log_path, "a")
+
+            process = subprocess.Popen(
+                ["npm", "run", "dev"],
+                cwd=str(ui_dir),
+                stdout=stdout_log_file,
+                stderr=stderr_log_file,
+                env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
+            )
+            logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
+            logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
+            logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
+
+        except FileNotFoundError:
+            logger.error(
+                "Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
+            )
+        except Exception as e:
+            logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
llama_stack/cli/stack/stack.py
@@ -7,12 +7,14 @@
 import argparse
 from importlib.metadata import version

+from llama_stack.cli.stack.list_stacks import StackListBuilds
 from llama_stack.cli.stack.utils import print_subcommand_description
 from llama_stack.cli.subcommand import Subcommand

 from .build import StackBuild
 from .list_apis import StackListApis
 from .list_providers import StackListProviders
+from .remove import StackRemove
 from .run import StackRun

@@ -41,5 +43,6 @@ class StackParser(Subcommand):
         StackListApis.create(subparsers)
         StackListProviders.create(subparsers)
         StackRun.create(subparsers)
+        StackRemove.create(subparsers)
+        StackListBuilds.create(subparsers)
         print_subcommand_description(self.parser, subparsers)
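With StackRemove and StackListBuilds registered on the parser, the CLI gains two new entry points. A minimal usage sketch, based only on the argparse definitions above (the stack name is a placeholder, and the output is not reproduced here):

    llama stack list                 # table of built stacks with Build Config / Run Config columns
    llama stack rm my-old-stack      # prompts "Are you sure ... [y/N]" before deleting
    llama stack rm --all             # requires typing "yes-i-really-want" to confirm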
llama_stack/distribution/build.py
@@ -6,6 +6,7 @@
 import importlib.resources
 import logging
+import sys
 from pathlib import Path

 from pydantic import BaseModel
@@ -43,8 +44,20 @@ def get_provider_dependencies(
     # Extract providers based on config type
     if isinstance(config, DistributionTemplate):
         providers = config.providers
+
+        # TODO: This is a hack to get the dependencies for internal APIs into build
+        # We should have a better way to do this by formalizing the concept of "internal" APIs
+        # and providers, with a way to specify dependencies for them.
+        run_configs = config.run_configs
+        additional_pip_packages: list[str] = []
+        if run_configs:
+            for run_config in run_configs.values():
+                run_config_ = run_config.run_config(name="", providers={}, container_image=None)
+                if run_config_.inference_store:
+                    additional_pip_packages.extend(run_config_.inference_store.pip_packages)
     elif isinstance(config, BuildConfig):
         providers = config.distribution_spec.providers
+        additional_pip_packages = config.additional_pip_packages
     deps = []
     registry = get_provider_registry(config)
     for api_str, provider_or_providers in providers.items():
@@ -72,6 +85,9 @@ def get_provider_dependencies(
         else:
             normal_deps.append(package)

+    if additional_pip_packages:
+        normal_deps.extend(additional_pip_packages)
+
     return list(set(normal_deps)), list(set(special_deps))

@@ -80,10 +96,11 @@ def print_pip_install_help(config: BuildConfig):

     cprint(
         f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
-        "yellow",
+        color="yellow",
+        file=sys.stderr,
     )
     for special_dep in special_deps:
-        cprint(f"uv pip install {special_dep}", "yellow")
+        cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
     print()
llama_stack/distribution/client.py
@@ -6,6 +6,7 @@
 import inspect
 import json
+import sys
 from collections.abc import AsyncIterator
 from enum import Enum
 from typing import Any, Union, get_args, get_origin
@@ -96,13 +97,13 @@ def create_api_client_class(protocol) -> type:
                     try:
                         data = json.loads(data)
                         if "error" in data:
-                            cprint(data, "red")
+                            cprint(data, color="red", file=sys.stderr)
                             continue

                         yield parse_obj_as(return_type, data)
                     except Exception as e:
-                        print(f"Error with parsing or validation: {e}")
-                        print(data)
+                        cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
+                        cprint(data, color="red", file=sys.stderr)

     def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
         webmethod, sig = self.routes[method_name]
llama_stack/distribution/datatypes.py
@@ -25,7 +25,8 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.datatypes import Api, ProviderSpec
-from llama_stack.providers.utils.kvstore.config import KVStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
+from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig

 LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
 LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@@ -220,21 +221,38 @@ class LoggingConfig(BaseModel):
 class AuthProviderType(str, Enum):
     """Supported authentication provider types."""

-    KUBERNETES = "kubernetes"
+    OAUTH2_TOKEN = "oauth2_token"
     CUSTOM = "custom"


 class AuthenticationConfig(BaseModel):
     provider_type: AuthProviderType = Field(
         ...,
-        description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
+        description="Type of authentication provider",
     )
-    config: dict[str, str] = Field(
+    config: dict[str, Any] = Field(
         ...,
         description="Provider-specific configuration",
     )


+class AuthenticationRequiredError(Exception):
+    pass
+
+
+class QuotaPeriod(str, Enum):
+    DAY = "day"
+
+
+class QuotaConfig(BaseModel):
+    kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
+    anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
+    authenticated_max_requests: int = Field(
+        default=1000, description="Max requests for authenticated clients per period"
+    )
+    period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
+
+
 class ServerConfig(BaseModel):
     port: int = Field(
         default=8321,
@@ -262,6 +280,10 @@ class ServerConfig(BaseModel):
         default=None,
         description="The host the server should listen on",
     )
+    quota: QuotaConfig | None = Field(
+        default=None,
+        description="Per client quota request configuration",
+    )


 class StackRunConfig(BaseModel):
@@ -297,6 +319,13 @@ Configuration for the persistence store used by the distribution registry. If no
 a default SQLite store will be used.""",
     )

+    inference_store: SqlStoreConfig | None = Field(
+        default=None,
+        description="""
+Configuration for the persistence store used by the inference API. If not specified,
+a default SQLite store will be used.""",
+    )
+
     # registry of "resources" in the distribution
     models: list[ModelInput] = Field(default_factory=list)
     shields: list[ShieldInput] = Field(default_factory=list)
@@ -345,6 +374,10 @@ class BuildConfig(BaseModel):
         description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
         "pip_packages MUST contain the provider package name.",
     )
+    additional_pip_packages: list[str] = Field(
+        default_factory=list,
+        description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
+    )

     @field_validator("external_providers_dir")
     @classmethod
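The new quota and inference-store settings are plain Pydantic models, so they can be constructed directly. A minimal sketch using only field names from the diff above; the db_path argument of SqliteKVStoreConfig is an assumption about that config class, and StackRunConfig similarly gains an optional inference_store of type SqlStoreConfig:

from llama_stack.distribution.datatypes import QuotaConfig, QuotaPeriod, ServerConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

# Hypothetical values for illustration only.
server = ServerConfig(
    port=8321,
    quota=QuotaConfig(
        kvstore=SqliteKVStoreConfig(db_path="/tmp/quota.db"),  # assumed field name
        anonymous_max_requests=100,
        authenticated_max_requests=1000,
        period=QuotaPeriod.DAY,
    ),
)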
llama_stack/distribution/inspect.py
@@ -31,7 +31,7 @@ async def get_provider_impl(config, deps):

 class DistributionInspectImpl(Inspect):
-    def __init__(self, config, deps):
+    def __init__(self, config: DistributionInspectConfig, deps):
         self.config = config
         self.deps = deps

@@ -39,12 +39,26 @@ class DistributionInspectImpl(Inspect):
         pass

     async def list_routes(self) -> ListRoutesResponse:
-        run_config = self.config.run_config
+        run_config: StackRunConfig = self.config.run_config

         ret = []
         all_endpoints = get_all_api_endpoints()
         for api, endpoints in all_endpoints.items():
+            # Always include provider and inspect APIs, filter others based on run config
+            if api.value in ["providers", "inspect"]:
+                ret.extend(
+                    [
+                        RouteInfo(
+                            route=e.route,
+                            method=e.method,
+                            provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
+                        )
+                        for e in endpoints
+                    ]
+                )
+            else:
                 providers = run_config.providers.get(api.value, [])
+                if providers:  # Only process if there are providers for this API
                     ret.extend(
                         [
                             RouteInfo(
llama_stack/distribution/library_client.py
@@ -9,6 +9,7 @@ import inspect
 import json
 import logging
 import os
+import sys
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
@@ -210,10 +211,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             self.endpoint_impls = None
             self.impls = await construct_stack(self.config, self.custom_provider_registry)
         except ModuleNotFoundError as _e:
-            cprint(_e.msg, "red")
+            cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
                 "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
-                "yellow",
+                color="yellow",
+                file=sys.stderr,
             )
             if self.config_path_or_template_name.endswith(".yaml"):
                 # Convert Provider objects to their types
@@ -234,6 +236,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 cprint(
                     f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
                     "yellow",
+                    file=sys.stderr,
+                )
+                cprint(
+                    "Please check your internet connection and try again.",
+                    "red",
+                    file=sys.stderr,
                 )
             raise _e

@@ -261,8 +269,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             raise ValueError("Client not initialized")

         # Create headers with provider data if available
-        headers = {}
+        headers = options.headers or {}
         if self.provider_data:
+            keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
+            if all(key not in headers for key in keys):
                 headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)

         # Use context manager for provider data
llama_stack/distribution/resolver.py
@@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
     RemoteProviderSpec,
     ScoringFunctionsProtocolPrivate,
     ShieldsProtocolPrivate,
-    ToolsProtocolPrivate,
+    ToolGroupsProtocolPrivate,
     VectorDBsProtocolPrivate,
 )

@@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
 def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
-        Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
+        Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
         Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
@@ -140,7 +140,7 @@ async def resolve_impls(

     sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

-    return await instantiate_providers(sorted_providers, router_apis, dist_registry)
+    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)


 def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@@ -243,7 +243,10 @@ def sort_providers_by_deps(


 async def instantiate_providers(
-    sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
+    sorted_providers: list[tuple[str, ProviderWithSpec]],
+    router_apis: set[Api],
+    dist_registry: DistributionRegistry,
+    run_config: StackRunConfig,
 ) -> dict:
     """Instantiates providers asynchronously while managing dependencies."""
     impls: dict[Api, Any] = {}
@@ -258,7 +261,7 @@ async def instantiate_providers(
         if isinstance(provider.spec, RoutingTableProviderSpec):
             inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

-        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
+        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)

         if api_str.startswith("inner-"):
             inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@@ -308,6 +311,7 @@ async def instantiate_provider(
     deps: dict[Api, Any],
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
+    run_config: StackRunConfig,
 ):
     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):
@@ -327,7 +331,7 @@ async def instantiate_provider(
         method = "get_auto_router_impl"

         config = None
-        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
+        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
     elif isinstance(provider_spec, RoutingTableProviderSpec):
         method = "get_routing_table_impl"
llama_stack/distribution/routers/__init__.py
@@ -7,18 +7,10 @@
 from typing import Any

 from llama_stack.distribution.datatypes import RoutedProtocol
+from llama_stack.distribution.stack import StackRunConfig
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable
+from llama_stack.providers.utils.inference.inference_store import InferenceStore
-
-from .routing_tables import (
-    BenchmarksRoutingTable,
-    DatasetsRoutingTable,
-    ModelsRoutingTable,
-    ScoringFunctionsRoutingTable,
-    ShieldsRoutingTable,
-    ToolGroupsRoutingTable,
-    VectorDBsRoutingTable,
-)


 async def get_routing_table_impl(
@@ -27,6 +19,14 @@ async def get_routing_table_impl(
     _deps,
     dist_registry: DistributionRegistry,
 ) -> Any:
+    from ..routing_tables.benchmarks import BenchmarksRoutingTable
+    from ..routing_tables.datasets import DatasetsRoutingTable
+    from ..routing_tables.models import ModelsRoutingTable
+    from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
+    from ..routing_tables.shields import ShieldsRoutingTable
+    from ..routing_tables.toolgroups import ToolGroupsRoutingTable
+    from ..routing_tables.vector_dbs import VectorDBsRoutingTable
+
     api_to_tables = {
         "vector_dbs": VectorDBsRoutingTable,
         "models": ModelsRoutingTable,
@@ -45,16 +45,15 @@ async def get_routing_table_impl(
     return impl


-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
-    from .routers import (
-        DatasetIORouter,
-        EvalRouter,
-        InferenceRouter,
-        SafetyRouter,
-        ScoringRouter,
-        ToolRuntimeRouter,
-        VectorIORouter,
-    )
+async def get_auto_router_impl(
+    api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
+) -> Any:
+    from .datasets import DatasetIORouter
+    from .eval_scoring import EvalRouter, ScoringRouter
+    from .inference import InferenceRouter
+    from .safety import SafetyRouter
+    from .tool_runtime import ToolRuntimeRouter
+    from .vector_io import VectorIORouter

     api_to_routers = {
         "vector_io": VectorIORouter,
@@ -76,6 +75,12 @@ async def get_auto_router_impl(
         if dep_api in deps:
             api_to_dep_impl[dep_name] = deps[dep_api]

+    # TODO: move pass configs to routers instead
+    if api == Api.inference and run_config.inference_store:
+        inference_store = InferenceStore(run_config.inference_store)
+        await inference_store.initialize()
+        api_to_dep_impl["store"] = inference_store
+
     impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
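In effect, the monolithic routers.py import is replaced by lazy per-API imports, and the inference router now optionally receives a persistence store. A sketch of the wiring that get_auto_router_impl performs when run_config.inference_store is set, restated outside the factory (this runs in an async context; names are taken from the diff above, nothing else is assumed):

from llama_stack.providers.utils.inference.inference_store import InferenceStore

store = InferenceStore(run_config.inference_store)
await store.initialize()
router = InferenceRouter(routing_table, telemetry=None, store=store)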
71  llama_stack/distribution/routers/datasets.py  Normal file
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import DatasetPurpose, DataSource
+from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import RoutingTable
+
+logger = get_logger(name=__name__, category="core")
+
+
+class DatasetIORouter(DatasetIO):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing DatasetIORouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("DatasetIORouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("DatasetIORouter.shutdown")
+        pass
+
+    async def register_dataset(
+        self,
+        purpose: DatasetPurpose,
+        source: DataSource,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
+    ) -> None:
+        logger.debug(
+            f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
+        )
+        await self.routing_table.register_dataset(
+            purpose=purpose,
+            source=source,
+            metadata=metadata,
+            dataset_id=dataset_id,
+        )
+
+    async def iterrows(
+        self,
+        dataset_id: str,
+        start_index: int | None = None,
+        limit: int | None = None,
+    ) -> PaginatedResponse:
+        logger.debug(
+            f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
+        )
+        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
+            dataset_id=dataset_id,
+            start_index=start_index,
+            limit=limit,
+        )
+
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
+        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
+        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
+            dataset_id=dataset_id,
+            rows=rows,
+        )
148  llama_stack/distribution/routers/eval_scoring.py  Normal file
@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
+from llama_stack.apis.scoring import (
+    ScoreBatchResponse,
+    ScoreResponse,
+    Scoring,
+    ScoringFnParams,
+)
+from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import RoutingTable
+
+logger = get_logger(name=__name__, category="core")
+
+
+class ScoringRouter(Scoring):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing ScoringRouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("ScoringRouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("ScoringRouter.shutdown")
+        pass
+
+    async def score_batch(
+        self,
+        dataset_id: str,
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
+        save_results_dataset: bool = False,
+    ) -> ScoreBatchResponse:
+        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
+        res = {}
+        for fn_identifier in scoring_functions.keys():
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+                dataset_id=dataset_id,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
+            )
+            res.update(score_response.results)
+
+        if save_results_dataset:
+            raise NotImplementedError("Save results dataset not implemented yet")
+
+        return ScoreBatchResponse(
+            results=res,
+        )
+
+    async def score(
+        self,
+        input_rows: list[dict[str, Any]],
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
+    ) -> ScoreResponse:
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        res = {}
+        # look up and map each scoring function to its provider impl
+        for fn_identifier in scoring_functions.keys():
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+                input_rows=input_rows,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
+            )
+            res.update(score_response.results)
+
+        return ScoreResponse(results=res)
+
+
+class EvalRouter(Eval):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing EvalRouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("EvalRouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("EvalRouter.shutdown")
+        pass
+
+    async def run_eval(
+        self,
+        benchmark_id: str,
+        benchmark_config: BenchmarkConfig,
+    ) -> Job:
+        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
+        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
+            benchmark_id=benchmark_id,
+            benchmark_config=benchmark_config,
+        )
+
+    async def evaluate_rows(
+        self,
+        benchmark_id: str,
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
+        benchmark_config: BenchmarkConfig,
+    ) -> EvaluateResponse:
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
+            benchmark_id=benchmark_id,
+            input_rows=input_rows,
+            scoring_functions=scoring_functions,
+            benchmark_config=benchmark_config,
+        )
+
+    async def job_status(
+        self,
+        benchmark_id: str,
+        job_id: str,
+    ) -> Job:
+        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+
+    async def job_cancel(
+        self,
+        benchmark_id: str,
+        job_id: str,
+    ) -> None:
+        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
+        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
+            benchmark_id,
+            job_id,
+        )
+
+    async def job_result(
+        self,
+        benchmark_id: str,
+        job_id: str,
+    ) -> EvaluateResponse:
+        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
+        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
+            benchmark_id,
+            job_id,
+        )
@@ -14,14 +14,9 @@ from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToo
 from pydantic import Field, TypeAdapter

 from llama_stack.apis.common.content_types import (
-    URL,
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.responses import PaginatedResponse
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -32,8 +27,11 @@ from llama_stack.apis.inference import (
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
+    ListOpenAIChatCompletionResponse,
     LogProbConfig,
     Message,
+    OpenAICompletionWithInputMessages,
+    Order,
     ResponseFormat,
     SamplingParams,
     StopReason,
@@ -51,89 +49,18 @@ from llama_stack.apis.inference.inference import (
     OpenAIResponseFormatParam,
 )
 from llama_stack.apis.models import Model, ModelType
-from llama_stack.apis.safety import RunShieldResponse, Safety
-from llama_stack.apis.scoring import (
-    ScoreBatchResponse,
-    ScoreResponse,
-    Scoring,
-    ScoringFnParams,
-)
-from llama_stack.apis.shields import Shield
 from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
-from llama_stack.apis.tools import (
-    ListToolDefsResponse,
-    RAGDocument,
-    RAGQueryConfig,
-    RAGQueryResult,
-    RAGToolRuntime,
-    ToolRuntime,
-)
-from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
+from llama_stack.providers.utils.inference.inference_store import InferenceStore
+from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
 from llama_stack.providers.utils.telemetry.tracing import get_current_span

 logger = get_logger(name=__name__, category="core")

-
-class VectorIORouter(VectorIO):
-    """Routes to an provider based on the vector db identifier"""
-
-    def __init__(
-        self,
-        routing_table: RoutingTable,
-    ) -> None:
-        logger.debug("Initializing VectorIORouter")
-        self.routing_table = routing_table
-
-    async def initialize(self) -> None:
-        logger.debug("VectorIORouter.initialize")
-        pass
-
-    async def shutdown(self) -> None:
-        logger.debug("VectorIORouter.shutdown")
-        pass
-
-    async def register_vector_db(
-        self,
-        vector_db_id: str,
-        embedding_model: str,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        provider_vector_db_id: str | None = None,
-    ) -> None:
-        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
-        await self.routing_table.register_vector_db(
-            vector_db_id,
-            embedding_model,
-            embedding_dimension,
-            provider_id,
-            provider_vector_db_id,
-        )
-
-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
-        logger.debug(
-            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
-        )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
-
-    async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
-    ) -> QueryChunksResponse:
-        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
-
-
 class InferenceRouter(Inference):
     """Routes to an provider based on the model"""

@@ -141,10 +68,12 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
         telemetry: Telemetry | None = None,
+        store: InferenceStore | None = None,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
         self.telemetry = telemetry
+        self.store = store
         if self.telemetry:
             self.tokenizer = Tokenizer.get_instance()
             self.formatter = ChatFormat(self.tokenizer)
@@ -607,9 +536,31 @@ class InferenceRouter(Inference):

         provider = self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
-            return await provider.openai_chat_completion(**params)
+            response_stream = await provider.openai_chat_completion(**params)
+            if self.store:
+                return stream_and_store_openai_completion(response_stream, model, self.store, messages)
+            return response_stream
         else:
-            return await self._nonstream_openai_chat_completion(provider, params)
+            response = await self._nonstream_openai_chat_completion(provider, params)
+            if self.store:
+                await self.store.store_chat_completion(response, messages)
+            return response
+
+    async def list_chat_completions(
+        self,
+        after: str | None = None,
+        limit: int | None = 20,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIChatCompletionResponse:
+        if self.store:
+            return await self.store.list_chat_completions(after, limit, model, order)
+        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
+
+    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+        if self.store:
+            return await self.store.get_chat_completion(completion_id)
+        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")

     async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
         response = await provider.openai_chat_completion(**params)
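A minimal sketch of how the two new router methods might be exercised once an inference store is configured (async context assumed); the shape of the returned page (.data, .id) is an assumption, not something this diff defines:

from llama_stack.apis.inference import Order

page = await inference_router.list_chat_completions(limit=20, order=Order.desc)
first = await inference_router.get_chat_completion(page.data[0].id)  # hypothetical response fields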
@ -642,295 +593,3 @@ class InferenceRouter(Inference):
                    status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
                )
        return health_statuses


class SafetyRouter(Safety):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("SafetyRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("SafetyRouter.shutdown")
        pass

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any] = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        return await self.routing_table.get_provider_impl(shield_id).run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )


class DatasetIORouter(DatasetIO):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing DatasetIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("DatasetIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("DatasetIORouter.shutdown")
        pass

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> None:
        logger.debug(
            f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
        )
        await self.routing_table.register_dataset(
            purpose=purpose,
            source=source,
            metadata=metadata,
            dataset_id=dataset_id,
        )

    async def iterrows(
        self,
        dataset_id: str,
        start_index: int | None = None,
        limit: int | None = None,
    ) -> PaginatedResponse:
        logger.debug(
            f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
        )
        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
            dataset_id=dataset_id,
            start_index=start_index,
            limit=limit,
        )

    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
            dataset_id=dataset_id,
            rows=rows,
        )


class ScoringRouter(Scoring):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing ScoringRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("ScoringRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ScoringRouter.shutdown")
        pass

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: dict[str, ScoringFnParams | None] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
        res = {}
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                dataset_id=dataset_id,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None] = None,
    ) -> ScoreResponse:
        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                input_rows=input_rows,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        return ScoreResponse(results=res)


class EvalRouter(Eval):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing EvalRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("EvalRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("EvalRouter.shutdown")
        pass

    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
            benchmark_id=benchmark_id,
            benchmark_config=benchmark_config,
        )

    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: list[dict[str, Any]],
        scoring_functions: list[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
            benchmark_id=benchmark_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            benchmark_config=benchmark_config,
        )

    async def job_status(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> Job:
        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

    async def job_cancel(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> None:
        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
            benchmark_id,
            job_id,
        )

    async def job_result(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
            benchmark_id,
            job_id,
        )


class ToolRuntimeRouter(ToolRuntime):
    class RagToolImpl(RAGToolRuntime):
        def __init__(
            self,
            routing_table: RoutingTable,
        ) -> None:
            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
            self.routing_table = routing_table

        async def query(
            self,
            content: InterleavedContent,
            vector_db_ids: list[str],
            query_config: RAGQueryConfig | None = None,
        ) -> RAGQueryResult:
            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
            return await self.routing_table.get_provider_impl("knowledge_search").query(
                content, vector_db_ids, query_config
            )

        async def insert(
            self,
            documents: list[RAGDocument],
            vector_db_id: str,
            chunk_size_in_tokens: int = 512,
        ) -> None:
            logger.debug(
                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                documents, vector_db_id, chunk_size_in_tokens
            )

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing ToolRuntimeRouter")
        self.routing_table = routing_table

        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
        self.rag_tool = self.RagToolImpl(routing_table)
        for method in ("query", "insert"):
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

    async def initialize(self) -> None:
        logger.debug("ToolRuntimeRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:
        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
@ -1,634 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import time
import uuid
from typing import Any

from pydantic import TypeAdapter

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import (
    Dataset,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
    ListToolGroupsResponse,
    ListToolsResponse,
    Tool,
    ToolGroup,
    ToolGroups,
    ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
    AccessAttributes,
    BenchmarkWithACL,
    DatasetWithACL,
    ModelWithACL,
    RoutableObject,
    RoutableObjectWithProvider,
    RoutedProtocol,
    ScoringFnWithACL,
    ShieldWithACL,
    ToolGroupWithACL,
    ToolWithACL,
    VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable

logger = logging.getLogger(__name__)


def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api


# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    api = get_impl_api(p)

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    if api == Api.inference:
        return await p.register_model(obj)
    elif api == Api.safety:
        return await p.register_shield(obj)
    elif api == Api.vector_io:
        return await p.register_vector_db(obj)
    elif api == Api.datasetio:
        return await p.register_dataset(obj)
    elif api == Api.scoring:
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_benchmark(obj)
    elif api == Api.tool_runtime:
        return await p.register_tool(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")


async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    api = get_impl_api(p)
    if api == Api.vector_io:
        return await p.unregister_vector_db(obj.identifier)
    elif api == Api.inference:
        return await p.unregister_model(obj.identifier)
    elif api == Api.datasetio:
        return await p.unregister_dataset(obj.identifier)
    elif api == Api.tool_runtime:
        return await p.unregister_tool(obj.identifier)
    else:
        raise ValueError(f"Unregister not supported for {api}")


Registry = dict[str, list[RoutableObjectWithProvider]]


class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        impls_by_provider_id: dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
    ) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.dist_registry = dist_registry

    async def initialize(self) -> None:
        async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
            for obj in objs:
                if cls is None:
                    obj.provider_id = provider_id
                else:
                    # Create a copy of the model data and explicitly set provider_id
                    model_data = obj.model_dump()
                    model_data["provider_id"] = provider_id
                    obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            if api == Api.inference:
                p.model_store = self
            elif api == Api.safety:
                p.shield_store = self
            elif api == Api.vector_io:
                p.vector_db_store = self
            elif api == Api.datasetio:
                p.dataset_store = self
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)
            elif api == Api.eval:
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        def apiname_object():
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorDBsRoutingTable):
                return ("VectorIO", "vector_db")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, BenchmarksRoutingTable):
                return ("Eval", "benchmark")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("Tools", "tool")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # Check if user has permission to access this object
        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        # If object supports access control but no attributes set, use creator's attributes
        if not obj.access_attributes:
            creator_attributes = get_auth_attributes()
            if creator_attributes:
                obj.access_attributes = AccessAttributes(**creator_attributes)
                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj

        else:
            await self.dist_registry.register(obj)
            return obj

    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
        objs = await self.dist_registry.get_all()
        filtered_objs = [obj for obj in objs if obj.type == type]

        # Apply attribute-based access control filtering
        if filtered_objs:
            filtered_objs = [
                obj
                for obj in filtered_objs
                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
            ]

        return filtered_objs


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> ListShieldsResponse:
        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

    async def get_shield(self, identifier: str) -> Shield:
        shield = await self.get_object_by_identifier("shield", identifier)
        if shield is None:
            raise ValueError(f"Shield '{identifier}' not found")
        return shield

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = ShieldWithACL(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield


class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
        }
        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        existing_vector_db = await self.get_vector_db(vector_db_id)
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

    async def get_dataset(self, dataset_id: str) -> Dataset:
        dataset = await self.get_object_by_identifier("dataset", dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset '{dataset_id}' not found")
        return dataset

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        if isinstance(source, dict):
            if source["type"] == "uri":
                source = URIDataSource.parse_obj(source)
            elif source["type"] == "rows":
                source = RowsDataSource.parse_obj(source)

        if not dataset_id:
            dataset_id = f"dataset-{str(uuid.uuid4())}"

        provider_dataset_id = dataset_id

        # infer provider from source
        if metadata:
            if metadata.get("provider_id"):
                provider_id = metadata.get("provider_id")  # pass through from nvidia datasetio
        elif source.type == DatasetType.rows.value:
            provider_id = "localfs"
        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
            else:
                provider_id = "localfs"
        else:
            raise ValueError(f"Unknown data source type: {source.type}")

        if metadata is None:
            metadata = {}

        dataset = DatasetWithACL(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            purpose=purpose,
            source=source,
            metadata=metadata,
        )

        await self.register_object(dataset)
        return dataset

    async def unregister_dataset(self, dataset_id: str) -> None:
        dataset = await self.get_dataset(dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset {dataset_id} not found")
        await self.unregister_object(dataset)


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithACL(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)


class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
        if benchmark is None:
            raise ValueError(f"Benchmark '{benchmark_id}' not found")
        return benchmark

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        metadata: dict[str, Any] | None = None,
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = BenchmarkWithACL(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        tools = await self.get_all_with_type("tool")
        if toolgroup_id:
            tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
        return ListToolsResponse(data=tools)

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group '{toolgroup_id}' not found")
        return tool_group

    async def get_tool(self, tool_name: str) -> Tool:
        return await self.get_object_by_identifier("tool", tool_name)

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        tools = []
        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

        for tool_def in tool_defs.data:
            tools.append(
                ToolWithACL(
                    identifier=tool_def.name,
                    toolgroup_id=toolgroup_id,
                    description=tool_def.description or "",
                    parameters=tool_def.parameters or [],
                    provider_id=provider_id,
                    provider_resource_id=tool_def.name,
                    metadata=tool_def.metadata,
                    tool_host=tool_host,
                )
            )
        for tool in tools:
            existing_tool = await self.get_tool(tool.identifier)
            # Compare existing and new object if one exists
            if existing_tool:
                existing_dict = existing_tool.model_dump()
                new_dict = tool.model_dump()

                if existing_dict != new_dict:
                    raise ValueError(
                        f"Object {tool.identifier} already exists in registry. Please use a different identifier."
                    )
            await self.register_object(tool)

        await self.dist_registry.register(
            ToolGroupWithACL(
                identifier=toolgroup_id,
                provider_id=provider_id,
                provider_resource_id=toolgroup_id,
                mcp_endpoint=mcp_endpoint,
                args=args,
            )
        )

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
        tools = await self.list_tools(toolgroup_id)
        for tool in getattr(tools, "data", []):
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        pass
57
llama_stack/distribution/routers/safety.py
Normal file
@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.inference import (
    Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class SafetyRouter(Safety):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("SafetyRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("SafetyRouter.shutdown")
        pass

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any] = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        return await self.routing_table.get_provider_impl(shield_id).run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )
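Note: a minimal usage sketch of this new router, illustrative only and not part of the diff. The `routing_table` is assumed to be an already-initialized ShieldsRoutingTable with a safety provider registered, "content-shield" is a hypothetical shield identifier, and the `violation` field on the response is assumed from the Safety API.

async def demo_safety_router(routing_table) -> None:
    safety = SafetyRouter(routing_table)
    # register_shield delegates to the routing table; run_shield then resolves
    # the provider for that shield id via get_provider_impl, as in the file above.
    shield = await safety.register_shield("content-shield")
    response = await safety.run_shield(
        shield_id=shield.identifier,
        messages=[],  # a real call would pass user/assistant Message objects
        params={},
    )
    print(response.violation)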
92
llama_stack/distribution/routers/tool_runtime.py
Normal file
@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,
)
from llama_stack.apis.tools import (
    ListToolsResponse,
    RAGDocument,
    RAGQueryConfig,
    RAGQueryResult,
    RAGToolRuntime,
    ToolRuntime,
)
from llama_stack.log import get_logger

from ..routing_tables.toolgroups import ToolGroupsRoutingTable

logger = get_logger(name=__name__, category="core")


class ToolRuntimeRouter(ToolRuntime):
    class RagToolImpl(RAGToolRuntime):
        def __init__(
            self,
            routing_table: ToolGroupsRoutingTable,
        ) -> None:
            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
            self.routing_table = routing_table

        async def query(
            self,
            content: InterleavedContent,
            vector_db_ids: list[str],
            query_config: RAGQueryConfig | None = None,
        ) -> RAGQueryResult:
            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
            return await self.routing_table.get_provider_impl("knowledge_search").query(
                content, vector_db_ids, query_config
            )

        async def insert(
            self,
            documents: list[RAGDocument],
            vector_db_id: str,
            chunk_size_in_tokens: int = 512,
        ) -> None:
            logger.debug(
                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                documents, vector_db_id, chunk_size_in_tokens
            )

    def __init__(
        self,
        routing_table: ToolGroupsRoutingTable,
    ) -> None:
        logger.debug("Initializing ToolRuntimeRouter")
        self.routing_table = routing_table

        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
        self.rag_tool = self.RagToolImpl(routing_table)
        for method in ("query", "insert"):
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

    async def initialize(self) -> None:
        logger.debug("ToolRuntimeRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolsResponse:
        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.list_tools(tool_group_id)
72
llama_stack/distribution/routers/vector_io.py
Normal file
@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class VectorIORouter(VectorIO):
    """Routes to an provider based on the vector db identifier"""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("VectorIORouter.shutdown")
        pass

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> None:
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
        )
        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
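Note: an illustrative sketch of this router, not part of the diff. The vector DB id "docs", the embedding model name, and the chunk contents are hypothetical; the `routing_table` is assumed to already have a vector_io provider and the embedding model registered, and `result.chunks` is assumed from the QueryChunksResponse type.

async def demo_vector_io_router(routing_table) -> None:
    vector_io = VectorIORouter(routing_table)
    await vector_io.register_vector_db("docs", embedding_model="all-MiniLM-L6-v2")
    # insert_chunks logs the first few document_ids, so each chunk carries one
    # in its metadata; the routing table picks the provider for "docs".
    await vector_io.insert_chunks(
        "docs",
        chunks=[Chunk(content="hello world", metadata={"document_id": "doc-1"})],
    )
    result = await vector_io.query_chunks("docs", query="hello")
    print(len(result.chunks))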
5
llama_stack/distribution/routing_tables/__init__.py
Normal file
@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
58
llama_stack/distribution/routing_tables/benchmarks.py
Normal file
@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import (
    BenchmarkWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
        if benchmark is None:
            raise ValueError(f"Benchmark '{benchmark_id}' not found")
        return benchmark

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        metadata: dict[str, Any] | None = None,
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = BenchmarkWithACL(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)
218
llama_stack/distribution/routing_tables/common.py
Normal file
218
llama_stack/distribution/routing_tables/common.py
Normal file
|
@ -0,0 +1,218 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.apis.resource import ResourceType
|
||||||
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
|
from llama_stack.distribution.access_control import check_access
|
||||||
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AccessAttributes,
|
||||||
|
RoutableObject,
|
||||||
|
RoutableObjectWithProvider,
|
||||||
|
RoutedProtocol,
|
||||||
|
)
|
||||||
|
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||||
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
|
def get_impl_api(p: Any) -> Api:
|
||||||
|
return p.__provider_spec__.api
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: this should return the registered object for all APIs
|
||||||
|
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
|
||||||
|
api = get_impl_api(p)
|
||||||
|
|
||||||
|
assert obj.provider_id != "remote", "Remote provider should not be registered"
|
||||||
|
|
||||||
|
if api == Api.inference:
|
||||||
|
return await p.register_model(obj)
|
||||||
|
elif api == Api.safety:
|
||||||
|
return await p.register_shield(obj)
|
||||||
|
elif api == Api.vector_io:
|
||||||
|
return await p.register_vector_db(obj)
|
||||||
|
elif api == Api.datasetio:
|
||||||
|
return await p.register_dataset(obj)
|
||||||
|
elif api == Api.scoring:
|
||||||
|
return await p.register_scoring_function(obj)
|
||||||
|
elif api == Api.eval:
|
||||||
|
return await p.register_benchmark(obj)
|
||||||
|
elif api == Api.tool_runtime:
|
||||||
|
return await p.register_toolgroup(obj)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||||
|
|
||||||
|
|
||||||
|
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
|
api = get_impl_api(p)
|
||||||
|
if api == Api.vector_io:
|
||||||
|
return await p.unregister_vector_db(obj.identifier)
|
||||||
|
elif api == Api.inference:
|
||||||
|
return await p.unregister_model(obj.identifier)
|
||||||
|
elif api == Api.datasetio:
|
||||||
|
return await p.unregister_dataset(obj.identifier)
|
||||||
|
elif api == Api.tool_runtime:
|
||||||
|
return await p.unregister_toolgroup(obj.identifier)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unregister not supported for {api}")
|
||||||
|
|
||||||
|
|
||||||
|
Registry = dict[str, list[RoutableObjectWithProvider]]
|
||||||
|
|
||||||
|
|
||||||
|
class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||||
|
dist_registry: DistributionRegistry,
|
||||||
|
) -> None:
|
||||||
|
self.impls_by_provider_id = impls_by_provider_id
|
||||||
|
self.dist_registry = dist_registry
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||||
|
for obj in objs:
|
||||||
|
if cls is None:
|
||||||
|
obj.provider_id = provider_id
|
||||||
|
else:
|
||||||
|
# Create a copy of the model data and explicitly set provider_id
|
||||||
|
model_data = obj.model_dump()
|
||||||
|
model_data["provider_id"] = provider_id
|
||||||
|
obj = cls(**model_data)
|
||||||
|
await self.dist_registry.register(obj)
|
||||||
|
|
||||||
|
# Register all objects from providers
|
||||||
|
for pid, p in self.impls_by_provider_id.items():
|
||||||
|
api = get_impl_api(p)
|
||||||
|
if api == Api.inference:
|
||||||
|
p.model_store = self
|
||||||
|
elif api == Api.safety:
|
||||||
|
p.shield_store = self
|
||||||
|
elif api == Api.vector_io:
|
||||||
|
p.vector_db_store = self
|
||||||
|
elif api == Api.datasetio:
|
||||||
|
p.dataset_store = self
|
||||||
|
elif api == Api.scoring:
|
||||||
|
p.scoring_function_store = self
|
||||||
|
scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)

            elif api == Api.eval:
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        from .benchmarks import BenchmarksRoutingTable
        from .datasets import DatasetsRoutingTable
        from .models import ModelsRoutingTable
        from .scoring_functions import ScoringFunctionsRoutingTable
        from .shields import ShieldsRoutingTable
        from .toolgroups import ToolGroupsRoutingTable
        from .vector_dbs import VectorDBsRoutingTable

        def apiname_object():
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorDBsRoutingTable):
                return ("VectorIO", "vector_db")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, BenchmarksRoutingTable):
                return ("Eval", "benchmark")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("ToolGroups", "tool_group")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # Check if user has permission to access this object
        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        # If object supports access control but no attributes set, use creator's attributes
        if not obj.access_attributes:
            creator_attributes = get_auth_attributes()
            if creator_attributes:
                obj.access_attributes = AccessAttributes(**creator_attributes)
                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj
        else:
            await self.dist_registry.register(obj)
            return obj

    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
        objs = await self.dist_registry.get_all()
        filtered_objs = [obj for obj in objs if obj.type == type]

        # Apply attribute-based access control filtering
        if filtered_objs:
            filtered_objs = [
                obj
                for obj in filtered_objs
                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
            ]

        return filtered_objs
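The lookup above resolves a routing key to a provider implementation via the cached registry. A minimal standalone sketch of that resolution logic, with plain dicts standing in for `dist_registry` and `impls_by_provider_id` (hypothetical data, for illustration only), is:

# Sketch only: plain dicts in place of the real registry and provider impls.
registry = {("model", "llama3-8b"): {"provider_id": "ollama"}}
impls = {"ollama": "<ollama inference impl>"}

def resolve(objtype: str, routing_key: str, provider_id: str | None = None) -> str:
    obj = registry.get((objtype, routing_key))
    if not obj:
        raise ValueError(f"{objtype.capitalize()} `{routing_key}` not served by any provider")
    if provider_id and provider_id != obj["provider_id"]:
        raise ValueError(f"Provider not found for `{routing_key}`")
    return impls[obj["provider_id"]]

print(resolve("model", "llama3-8b"))  # -> "<ollama inference impl>"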
93
llama_stack/distribution/routing_tables/datasets.py
Normal file
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import Any

from llama_stack.apis.datasets import (
    Dataset,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import (
    DatasetWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

    async def get_dataset(self, dataset_id: str) -> Dataset:
        dataset = await self.get_object_by_identifier("dataset", dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset '{dataset_id}' not found")
        return dataset

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        if isinstance(source, dict):
            if source["type"] == "uri":
                source = URIDataSource.parse_obj(source)
            elif source["type"] == "rows":
                source = RowsDataSource.parse_obj(source)

        if not dataset_id:
            dataset_id = f"dataset-{str(uuid.uuid4())}"

        provider_dataset_id = dataset_id

        # infer provider from source
        if metadata:
            if metadata.get("provider_id"):
                provider_id = metadata.get("provider_id")  # pass through from nvidia datasetio
        elif source.type == DatasetType.rows.value:
            provider_id = "localfs"
        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
            else:
                provider_id = "localfs"
        else:
            raise ValueError(f"Unknown data source type: {source.type}")

        if metadata is None:
            metadata = {}

        dataset = DatasetWithACL(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            purpose=purpose,
            source=source,
            metadata=metadata,
        )

        await self.register_object(dataset)
        return dataset

    async def unregister_dataset(self, dataset_id: str) -> None:
        dataset = await self.get_dataset(dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset {dataset_id} not found")
        await self.unregister_object(dataset)
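The provider inference above keys off an explicit `provider_id` in the metadata, then the source type, then the URI prefix. A standalone sketch of that branch, using plain values instead of the real DataSource models, reads:

# Sketch of the provider inference in register_dataset above; plain strings, not the real types.
def infer_provider(source_type: str, uri: str | None = None, metadata: dict | None = None) -> str:
    if metadata and metadata.get("provider_id"):
        return metadata["provider_id"]  # explicit pass-through wins
    if source_type == "rows":
        return "localfs"  # inline rows are kept locally
    if source_type == "uri":
        return "huggingface" if (uri or "").startswith("huggingface") else "localfs"
    raise ValueError(f"Unknown data source type: {source_type}")

assert infer_provider("uri", "huggingface://datasets/simpleqa?split=train") == "huggingface"
assert infer_provider("rows") == "localfs"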
82
llama_stack/distribution/routing_tables/models.py
Normal file
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from typing import Any

from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
    ModelWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)
62
llama_stack/distribution/routing_tables/scoring_functions.py
Normal file
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.distribution.datatypes import (
    ScoringFnWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithACL(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)
57
llama_stack/distribution/routing_tables/shields.py
Normal file
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import (
    ShieldWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> ListShieldsResponse:
        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

    async def get_shield(self, identifier: str) -> Shield:
        shield = await self.get_object_by_identifier("shield", identifier)
        if shield is None:
            raise ValueError(f"Shield '{identifier}' not found")
        return shield

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = ShieldWithACL(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield
132
llama_stack/distribution/routing_tables/toolgroups.py
Normal file
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
    # handle the funny case like "builtin::rag/knowledge_search"
    parts = toolgroup_name_with_maybe_tool_name.split("/")
    if len(parts) == 2:
        return parts[0]
    else:
        return None


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    toolgroups_to_tools: dict[str, list[Tool]] = {}
    tool_to_toolgroup: dict[str, str] = {}

    # overridden
    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
        # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?

        toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
        if toolgroup_id:
            routing_key = toolgroup_id

        if routing_key in self.tool_to_toolgroup:
            routing_key = self.tool_to_toolgroup[routing_key]
        return super().get_provider_impl(routing_key, provider_id)

    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        if toolgroup_id:
            if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
                toolgroup_id = group_id
            toolgroups = [await self.get_tool_group(toolgroup_id)]
        else:
            toolgroups = await self.get_all_with_type("tool_group")

        all_tools = []
        for toolgroup in toolgroups:
            if toolgroup.identifier not in self.toolgroups_to_tools:
                await self._index_tools(toolgroup)
            all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])

        return ListToolsResponse(data=all_tools)

    async def _index_tools(self, toolgroup: ToolGroup):
        provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
        tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)

        # TODO: kill this Tool vs ToolDef distinction
        tooldefs = tooldefs_response.data
        tools = []
        for t in tooldefs:
            tools.append(
                Tool(
                    identifier=t.name,
                    toolgroup_id=toolgroup.identifier,
                    description=t.description or "",
                    parameters=t.parameters or [],
                    metadata=t.metadata,
                    provider_id=toolgroup.provider_id,
                )
            )

        self.toolgroups_to_tools[toolgroup.identifier] = tools
        for tool in tools:
            self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group '{toolgroup_id}' not found")
        return tool_group

    async def get_tool(self, tool_name: str) -> Tool:
        if tool_name in self.tool_to_toolgroup:
            toolgroup_id = self.tool_to_toolgroup[tool_name]
            tools = self.toolgroups_to_tools[toolgroup_id]
            for tool in tools:
                if tool.identifier == tool_name:
                    return tool
        raise ValueError(f"Tool '{tool_name}' not found")

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        toolgroup = ToolGroupWithACL(
            identifier=toolgroup_id,
            provider_id=provider_id,
            provider_resource_id=toolgroup_id,
            mcp_endpoint=mcp_endpoint,
            args=args,
        )
        await self.register_object(toolgroup)

        # ideally, indexing of the tools should not be necessary because anyone using
        # the tools should first list the tools and then use them. but there are assumptions
        # baked in some of the code and tests right now.
        if not toolgroup.mcp_endpoint:
            await self._index_tools(toolgroup)
        return toolgroup

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        pass
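Because the routing key may carry a tool name after a slash, the helper above only strips the suffix when exactly one `/` is present. A quick illustration of that behavior:

# Illustration of parse_toolgroup_from_toolgroup_name_pair semantics (standalone copy).
def parse_toolgroup(name: str) -> str | None:
    parts = name.split("/")
    return parts[0] if len(parts) == 2 else None

assert parse_toolgroup("builtin::rag/knowledge_search") == "builtin::rag"
assert parse_toolgroup("builtin::websearch") is None  # plain toolgroup id, nothing to strip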
74
llama_stack/distribution/routing_tables/vector_dbs.py
Normal file
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import TypeAdapter

from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
    VectorDBWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
        }
        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        existing_vector_db = await self.get_vector_db(vector_db_id)
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)
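register_vector_db above refuses to proceed unless the referenced model is an embedding model carrying an `embedding_dimension` in its metadata. A standalone sketch of that guard, using plain dicts instead of the Model type, looks like this:

# Sketch of the embedding-model checks performed before a vector DB is registered.
def check_embedding_model(model: dict) -> int:
    if model.get("model_type") != "embedding":
        raise ValueError(f"Model {model['identifier']} is not an embedding model")
    if "embedding_dimension" not in model.get("metadata", {}):
        raise ValueError(f"Model {model['identifier']} does not have an embedding dimension")
    return model["metadata"]["embedding_dimension"]

ok = {"identifier": "all-MiniLM-L6-v2", "model_type": "embedding", "metadata": {"embedding_dimension": 384}}
assert check_embedding_model(ok) == 384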
@@ -8,7 +8,8 @@ import json

 import httpx

-from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
+from llama_stack.distribution.datatypes import AuthenticationConfig
+from llama_stack.distribution.server.auth_providers import create_auth_provider
 from llama_stack.log import get_logger

 logger = get_logger(name=__name__, category="auth")
@@ -77,7 +78,7 @@ class AuthenticationMiddleware:
     access resources that don't have access_attributes defined.
     """

-    def __init__(self, app, auth_config: AuthProviderConfig):
+    def __init__(self, app, auth_config: AuthenticationConfig):
         self.app = app
         self.auth_provider = create_auth_provider(auth_config)
@@ -113,6 +114,10 @@ class AuthenticationMiddleware:
                     "roles": [token],
                 }

+            # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
+            # can identify the requester and enforce per-client rate limits.
+            scope["authenticated_client_id"] = token
+
             # Store attributes in request scope
             scope["user_attributes"] = user_attributes
             scope["principal"] = validation_result.principal
@@ -4,17 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import json
+import ssl
 import time
 from abc import ABC, abstractmethod
-from enum import Enum
+from asyncio import Lock
+from pathlib import Path
 from urllib.parse import parse_qs

 import httpx
 from jose import jwt
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
+from typing_extensions import Self

-from llama_stack.distribution.datatypes import AccessAttributes
+from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
 from llama_stack.log import get_logger

 logger = get_logger(name=__name__, category="auth")
@@ -72,21 +74,6 @@ class AuthRequest(BaseModel):
     request: AuthRequestContext = Field(description="Context information about the request being authenticated")


-class AuthProviderType(str, Enum):
-    """Supported authentication provider types."""
-
-    KUBERNETES = "kubernetes"
-    CUSTOM = "custom"
-    OAUTH2_TOKEN = "oauth2_token"
-
-
-class AuthProviderConfig(BaseModel):
-    """Base configuration for authentication providers."""
-
-    provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
-    config: dict[str, str] = Field(..., description="Provider-specific configuration")
-
-
 class AuthProvider(ABC):
     """Abstract base class for authentication providers."""
@@ -101,83 +88,6 @@ class AuthProvider(ABC):
         pass


-class KubernetesAuthProviderConfig(BaseModel):
-    api_server_url: str
-    ca_cert_path: str | None = None
-
-
-class KubernetesAuthProvider(AuthProvider):
-    """Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
-
-    def __init__(self, config: KubernetesAuthProviderConfig):
-        self.config = config
-        self._client = None
-
-    async def _get_client(self):
-        """Get or create a Kubernetes client."""
-        if self._client is None:
-            # kubernetes-client has not async support, see:
-            # https://github.com/kubernetes-client/python/issues/323
-            from kubernetes import client
-            from kubernetes.client import ApiClient
-
-            # Configure the client
-            configuration = client.Configuration()
-            configuration.host = self.config.api_server_url
-            if self.config.ca_cert_path:
-                configuration.ssl_ca_cert = self.config.ca_cert_path
-            configuration.verify_ssl = bool(self.config.ca_cert_path)
-
-            # Create API client
-            self._client = ApiClient(configuration)
-        return self._client
-
-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
-        """Validate a Kubernetes token and return access attributes."""
-        try:
-            client = await self._get_client()
-
-            # Set the token in the client
-            client.set_default_header("Authorization", f"Bearer {token}")
-
-            # Make a request to validate the token
-            # We use the /api endpoint which requires authentication
-            from kubernetes.client import CoreV1Api
-
-            api = CoreV1Api(client)
-            api.get_api_resources(_request_timeout=3.0)  # Set timeout for this specific request
-
-            # If we get here, the token is valid
-            # Extract user info from the token claims
-            import base64
-
-            # Decode the token (without verification since we've already validated it)
-            token_parts = token.split(".")
-            payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
-
-            # Extract user information from the token
-            username = payload.get("sub", "")
-            groups = payload.get("groups", [])
-
-            return TokenValidationResult(
-                principal=username,
-                access_attributes=AccessAttributes(
-                    roles=[username],  # Use username as a role
-                    teams=groups,  # Use Kubernetes groups as teams
-                ),
-            )
-
-        except Exception as e:
-            logger.exception("Failed to validate Kubernetes token")
-            raise ValueError("Invalid or expired token") from e
-
-    async def close(self):
-        """Close the HTTP client."""
-        if self._client:
-            self._client.close()
-            self._client = None
-
-
 def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
     attributes = AccessAttributes()
     for claim_key, attribute_key in mapping.items():
@@ -197,11 +107,24 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
     return attributes


-class OAuth2TokenAuthProviderConfig(BaseModel):
+class OAuth2JWKSConfig(BaseModel):
     # The JWKS URI for collecting public keys
-    jwks_uri: str
-    cache_ttl: int = 3600
+    uri: str
+    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
+
+
+class OAuth2IntrospectionConfig(BaseModel):
+    url: str
+    client_id: str
+    client_secret: str
+    send_secret_in_body: bool = False
+
+
+class OAuth2TokenAuthProviderConfig(BaseModel):
     audience: str = "llama-stack"
+    verify_tls: bool = True
+    tls_cafile: Path | None = None
+    issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
     claims_mapping: dict[str, str] = Field(
         default_factory=lambda: {
             "sub": "roles",
@@ -213,6 +136,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
             "namespace": "namespaces",
         },
     )
+    jwks: OAuth2JWKSConfig | None
+    introspection: OAuth2IntrospectionConfig | None = None

     @classmethod
     @field_validator("claims_mapping")
@@ -224,6 +149,14 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
                 raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
         return v

+    @model_validator(mode="after")
+    def validate_mode(self) -> Self:
+        if not self.jwks and not self.introspection:
+            raise ValueError("One of jwks or introspection must be configured")
+        if self.jwks and self.introspection:
+            raise ValueError("At present only one of jwks or introspection should be configured")
+        return self
+

 class OAuth2TokenAuthProvider(AuthProvider):
     """
@@ -236,8 +169,16 @@ class OAuth2TokenAuthProvider(AuthProvider):
         self.config = config
         self._jwks_at: float = 0.0
         self._jwks: dict[str, str] = {}
+        self._jwks_lock = Lock()

     async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+        if self.config.jwks:
+            return await self.validate_jwt_token(token, scope)
+        if self.config.introspection:
+            return await self.introspect_token(token, scope)
+        raise ValueError("One of jwks or introspection must be configured")
+
+    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
         """Validate a token using the JWT token."""
         await self._refresh_jwks()
@@ -253,7 +194,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
                 key_data,
                 algorithms=[algorithm],
                 audience=self.config.audience,
-                options={"verify_exp": True},
+                issuer=self.config.issuer,
             )
         except Exception as exc:
             raise ValueError(f"Invalid JWT token: {token}") from exc
@@ -267,20 +208,83 @@ class OAuth2TokenAuthProvider(AuthProvider):
             access_attributes=access_attributes,
         )

+    async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+        """Validate a token using token introspection as defined by RFC 7662."""
+        form = {
+            "token": token,
+        }
+        if self.config.introspection is None:
+            raise ValueError("Introspection is not configured")
+
+        if self.config.introspection.send_secret_in_body:
+            form["client_id"] = self.config.introspection.client_id
+            form["client_secret"] = self.config.introspection.client_secret
+            auth = None
+        else:
+            auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
+        ssl_ctxt = None
+        if self.config.tls_cafile:
+            ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
+        try:
+            async with httpx.AsyncClient(verify=ssl_ctxt) as client:
+                response = await client.post(
+                    self.config.introspection.url,
+                    data=form,
+                    auth=auth,
+                    timeout=10.0,  # Add a reasonable timeout
+                )
+                if response.status_code != 200:
+                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
+                    raise ValueError(f"Token introspection failed: {response.status_code}")
+
+                fields = response.json()
+                if not fields["active"]:
+                    raise ValueError("Token not active")
+                principal = fields["sub"] or fields["username"]
+                access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
+                return TokenValidationResult(
+                    principal=principal,
+                    access_attributes=access_attributes,
+                )
+        except httpx.TimeoutException:
+            logger.exception("Token introspection request timed out")
+            raise
+        except ValueError:
+            # Re-raise ValueError exceptions to preserve their message
+            raise
+        except Exception as e:
+            logger.exception("Error during token introspection")
+            raise ValueError("Token introspection error") from e
+
     async def close(self):
-        """Close the HTTP client."""
+        pass

     async def _refresh_jwks(self) -> None:
-        if time.time() - self._jwks_at > self.config.cache_ttl:
-            async with httpx.AsyncClient() as client:
-                res = await client.get(self.config.jwks_uri, timeout=5)
-                res.raise_for_status()
-                jwks_data = res.json()["keys"]
-                self._jwks = {}
-                for k in jwks_data:
-                    kid = k["kid"]
-                    # Store the entire key object as it may be needed for different algorithms
-                    self._jwks[kid] = k
-                self._jwks_at = time.time()
+        """
+        Refresh the JWKS cache.
+
+        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
+        If the cache is expired, we refresh the JWKS from the JWKS URI.
+
+        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
+            * It doesn't have user authentication flows
+            * It doesn't have refresh tokens
+        """
+        async with self._jwks_lock:
+            if self.config.jwks is None:
+                raise ValueError("JWKS is not configured")
+            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
+                verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
+                async with httpx.AsyncClient(verify=verify) as client:
+                    res = await client.get(self.config.jwks.uri, timeout=5)
+                    res.raise_for_status()
+                    jwks_data = res.json()["keys"]
+                    updated = {}
+                    for k in jwks_data:
+                        kid = k["kid"]
+                        # Store the entire key object as it may be needed for different algorithms
+                        updated[kid] = k
+                    self._jwks = updated
+                    self._jwks_at = time.time()
@@ -359,13 +363,11 @@ class CustomAuthProvider(AuthProvider):
             self._client = None


-def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
+def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
     """Factory function to create the appropriate auth provider."""
     provider_type = config.provider_type.lower()

-    if provider_type == "kubernetes":
-        return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
-    elif provider_type == "custom":
+    if provider_type == "custom":
         return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
     elif provider_type == "oauth2_token":
         return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
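The new model validator makes `jwks` and `introspection` mutually exclusive and requires exactly one of them. Assuming the module path stays `llama_stack.distribution.server.auth_providers`, configuring the JWKS mode might look like the following sketch (issuer and JWKS URI are made-up values):

# Hedged sketch: construct the oauth2_token provider config in JWKS mode.
# Field names are taken from the diff above; the import path is an assumption.
from llama_stack.distribution.server.auth_providers import (
    OAuth2JWKSConfig,
    OAuth2TokenAuthProviderConfig,
)

config = OAuth2TokenAuthProviderConfig(
    audience="llama-stack",
    issuer="https://keycloak.example.com/realms/llama",  # hypothetical issuer URL
    jwks=OAuth2JWKSConfig(uri="https://keycloak.example.com/realms/llama/protocol/openid-connect/certs"),
    introspection=None,
)
# Supplying both jwks and introspection, or neither, raises a validation error
# because of the @model_validator added above.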
110
llama_stack/distribution/server/quota.py
Normal file
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import time
from datetime import datetime, timedelta, timezone

from starlette.types import ASGIApp, Receive, Scope, Send

from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl

logger = get_logger(name=__name__, category="quota")


class QuotaMiddleware:
    """
    ASGI middleware that enforces separate quotas for authenticated and anonymous clients
    within a configurable time window.

    - For authenticated requests, it reads the client ID from the
      `Authorization: Bearer <client_id>` header.
    - For anonymous requests, it falls back to the IP address of the client.
    Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
    once a client exceeds its quota.
    """

    def __init__(
        self,
        app: ASGIApp,
        kv_config: KVStoreConfig,
        anonymous_max_requests: int,
        authenticated_max_requests: int,
        window_seconds: int = 86400,
    ):
        self.app = app
        self.kv_config = kv_config
        self.kv: KVStore | None = None
        self.anonymous_max_requests = anonymous_max_requests
        self.authenticated_max_requests = authenticated_max_requests
        self.window_seconds = window_seconds

        if isinstance(self.kv_config, SqliteKVStoreConfig):
            logger.warning(
                "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
                f"window_seconds={self.window_seconds}"
            )

    async def _get_kv(self) -> KVStore:
        if self.kv is None:
            self.kv = await kvstore_impl(self.kv_config)
        return self.kv

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] == "http":
            # pick key & limit based on auth
            auth_id = scope.get("authenticated_client_id")
            if auth_id:
                key_id = auth_id
                limit = self.authenticated_max_requests
            else:
                # fallback to IP
                client = scope.get("client")
                key_id = client[0] if client else "anonymous"
                limit = self.anonymous_max_requests

            current_window = int(time.time() // self.window_seconds)
            key = f"quota:{key_id}:{current_window}"

            try:
                kv = await self._get_kv()
                prev = await kv.get(key) or "0"
                count = int(prev) + 1

                if int(prev) == 0:
                    # Set with expiration datetime when it is the first request in the window.
                    expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
                    await kv.set(key, str(count), expiration=expiration)
                else:
                    await kv.set(key, str(count))
            except Exception:
                logger.exception("Failed to access KV store for quota")
                return await self._send_error(send, 500, "Quota service error")

            if count > limit:
                logger.warning(
                    "Quota exceeded for client %s: %d/%d",
                    key_id,
                    count,
                    limit,
                )
                return await self._send_error(send, 429, "Quota exceeded")

        return await self.app(scope, receive, send)

    async def _send_error(self, send: Send, status: int, message: str):
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        body = json.dumps({"error": {"message": message}}).encode()
        await send({"type": "http.response.body", "body": body})
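A minimal wiring sketch for the middleware above, attaching it to a bare FastAPI app with a SQLite KV store (the database path and limits are illustrative assumptions, not values from this change):

# Hedged sketch: attach QuotaMiddleware to a standalone FastAPI app.
from fastapi import FastAPI

from llama_stack.distribution.server.quota import QuotaMiddleware
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

app = FastAPI()
app.add_middleware(
    QuotaMiddleware,
    kv_config=SqliteKVStoreConfig(db_path="/tmp/quota.db"),  # illustrative path
    anonymous_max_requests=100,
    authenticated_max_requests=1000,
    window_seconds=86400,  # one day, matching the server's window_map
)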
@@ -23,11 +23,12 @@ import yaml
 from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError

-from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
+from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
@@ -60,6 +61,7 @@ from llama_stack.providers.utils.telemetry.tracing import (

 from .auth import AuthenticationMiddleware
 from .endpoints import get_all_api_endpoints
+from .quota import QuotaMiddleware

 REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -120,6 +122,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
         return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
         return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
+    elif isinstance(exc, AuthenticationRequiredError):
+        return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
     else:
         return HTTPException(
             status_code=500,
@@ -280,7 +284,18 @@ class TracingMiddleware:
             logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
             return await self.app(scope, receive, send)

-        trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
+        trace_attributes = {"__location__": "server", "raw_path": path}
+
+        # Extract W3C trace context headers and store as trace attributes
+        headers = dict(scope.get("headers", []))
+        traceparent = headers.get(b"traceparent", b"").decode()
+        if traceparent:
+            trace_attributes["traceparent"] = traceparent
+        tracestate = headers.get(b"tracestate", b"").decode()
+        if tracestate:
+            trace_attributes["tracestate"] = tracestate
+
+        trace_context = await start_trace(trace_path, trace_attributes)

         async def send_with_trace_id(message):
             if message["type"] == "http.response.start":
@@ -370,14 +385,6 @@ def main(args: argparse.Namespace | None = None):
     if args is None:
         args = parser.parse_args()

-    # Check for deprecated argument usage
-    if "--config" in sys.argv:
-        warnings.warn(
-            "The '--config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
     log_line = ""
     if args.config:
         # if the user provided a config file, use it, even if template was specified
@@ -431,6 +438,46 @@ def main(args: argparse.Namespace | None = None):
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
         app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
+    else:
+        if config.server.quota:
+            quota = config.server.quota
+            logger.warning(
+                "Configured authenticated_max_requests (%d) but no auth is enabled; "
+                "falling back to anonymous_max_requests (%d) for all the requests",
+                quota.authenticated_max_requests,
+                quota.anonymous_max_requests,
+            )
+
+    if config.server.quota:
+        logger.info("Enabling quota middleware for authenticated and anonymous clients")
+
+        quota = config.server.quota
+        anonymous_max_requests = quota.anonymous_max_requests
+        # if auth is disabled, use the anonymous max requests
+        authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
+
+        kv_config = quota.kvstore
+        window_map = {"day": 86400}
+        window_seconds = window_map[quota.period.value]
+
+        app.add_middleware(
+            QuotaMiddleware,
+            kv_config=kv_config,
+            anonymous_max_requests=anonymous_max_requests,
+            authenticated_max_requests=authenticated_max_requests,
+            window_seconds=window_seconds,
+        )
+
+    # --- CORS middleware for local development ---
+    # TODO: move to reverse proxy
+    ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322)
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=[f"http://localhost:{ui_port}"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
     try:
         impls = asyncio.run(construct_stack(config))
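With the tracing change above, a caller can propagate W3C trace context simply by forwarding the standard headers on any request. A hedged client-side sketch (the port and trace IDs below are made-up examples, not values from this change):

# Forward W3C trace-context headers so TracingMiddleware records them as trace attributes.
import httpx

headers = {
    "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",  # example IDs
    "tracestate": "vendor=example",
}
resp = httpx.get("http://localhost:8321/v1/models", headers=headers)
print(resp.status_code)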
@@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v8"
+KEY_VERSION = "v9"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
@@ -8,6 +8,7 @@ import logging
 import os
 import signal
 import subprocess
+import sys

 from termcolor import cprint
@@ -33,6 +34,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
             cprint(
                 "No current conda environment detected, please specify a conda environment name with --image-name",
                 color="red",
+                file=sys.stderr,
             )
             return
@@ -49,12 +51,13 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
                     return envpath
             return None

-        print(f"Using conda environment: {env_name}")
+        cprint(f"Using conda environment: {env_name}", color="green", file=sys.stderr)
         conda_prefix = get_conda_prefix(env_name)
         if not conda_prefix:
             cprint(
                 f"Conda environment {env_name} does not exist.",
                 color="red",
+                file=sys.stderr,
             )
             return
@@ -63,6 +66,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
             cprint(
                 f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
                 color="red",
+                file=sys.stderr,
             )
             return
     else:
@@ -73,9 +77,10 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
             cprint(
                 "No current virtual environment detected, please specify a virtual environment name with --image-name",
                 color="red",
+                file=sys.stderr,
             )
             return
-        print(f"Using virtual environment: {env_name}")
+        cprint(f"Using virtual environment: {env_name}", file=sys.stderr)

     script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh"
     run_args = [
@@ -6,6 +6,7 @@

 import logging
 import os
+import sys
 from logging.config import dictConfig

 from rich.console import Console
@@ -234,7 +235,7 @@ def get_logger(

     env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
     if env_config:
-        cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
+        cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", color="yellow", file=sys.stderr)
         _category_levels.update(parse_environment_config(env_config))

     log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
|
@@ -174,6 +174,7 @@ class Llama3:
                 cprint(
                     "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
                     "red",
+                    file=sys.stderr,
                 )
         prompt_tokens = [inp.tokens for inp in llm_inputs]

@@ -184,7 +185,11 @@ class Llama3:
         max_prompt_len = max(len(t) for t in prompt_tokens)

         if max_prompt_len >= params.max_seq_len:
-            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
+            cprint(
+                f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
+                color="red",
+                file=sys.stderr,
+            )
             return

         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
@@ -133,9 +133,9 @@ class Llama4:

         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
         if print_model_input:
-            cprint("Input to model:\n", "yellow")
+            cprint("Input to model:\n", color="yellow", file=sys.stderr)
             for inp in llm_inputs:
-                cprint(self.tokenizer.decode(inp.tokens), "grey")
+                cprint(self.tokenizer.decode(inp.tokens), color="grey", file=sys.stderr)
         prompt_tokens = [inp.tokens for inp in llm_inputs]

         bsz = len(llm_inputs)
@@ -145,7 +145,7 @@ class Llama4:
         max_prompt_len = max(len(t) for t in prompt_tokens)

         if max_prompt_len >= params.max_seq_len:
-            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
+            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", color="red", file=sys.stderr)
             return

         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
@@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import Tool
+from llama_stack.apis.tools import ToolGroup
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.schema_utils import json_schema_type

@@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
     async def register_benchmark(self, benchmark: Benchmark) -> None: ...


-class ToolsProtocolPrivate(Protocol):
-    async def register_tool(self, tool: Tool) -> None: ...
+class ToolGroupsProtocolPrivate(Protocol):
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...

-    async def unregister_tool(self, tool_id: str) -> None: ...
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...


 @json_schema_type
@@ -20,9 +20,12 @@ from llama_stack.apis.agents import (
     AgentTurnCreateRequest,
     AgentTurnResumeRequest,
     Document,
+    ListOpenAIResponseInputItem,
+    ListOpenAIResponseObject,
     OpenAIResponseInput,
     OpenAIResponseInputTool,
     OpenAIResponseObject,
+    Order,
     Session,
     Turn,
 )
@@ -39,6 +42,7 @@ from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
+from llama_stack.providers.utils.responses.responses_store import ResponsesStore

 from .agent_instance import ChatAgent
 from .config import MetaReferenceAgentsImplConfig
@@ -66,15 +70,17 @@ class MetaReferenceAgentsImpl(Agents):
         self.tool_groups_api = tool_groups_api

         self.in_memory_store = InmemoryKVStoreImpl()
-        self.openai_responses_impl = None
+        self.openai_responses_impl: OpenAIResponsesImpl | None = None

     async def initialize(self) -> None:
         self.persistence_store = await kvstore_impl(self.config.persistence_store)
+        self.responses_store = ResponsesStore(self.config.responses_store)
+        await self.responses_store.initialize()
         self.openai_responses_impl = OpenAIResponsesImpl(
-            self.persistence_store,
             inference_api=self.inference_api,
             tool_groups_api=self.tool_groups_api,
             tool_runtime_api=self.tool_runtime_api,
+            responses_store=self.responses_store,
         )

     async def create_agent(
@@ -305,14 +311,15 @@ class MetaReferenceAgentsImpl(Agents):
     # OpenAI responses
     async def get_openai_response(
         self,
-        id: str,
+        response_id: str,
     ) -> OpenAIResponseObject:
-        return await self.openai_responses_impl.get_openai_response(id)
+        return await self.openai_responses_impl.get_openai_response(response_id)

     async def create_openai_response(
         self,
         input: str | list[OpenAIResponseInput],
         model: str,
+        instructions: str | None = None,
         previous_response_id: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
@@ -320,5 +327,27 @@ class MetaReferenceAgentsImpl(Agents):
         tools: list[OpenAIResponseInputTool] | None = None,
     ) -> OpenAIResponseObject:
         return await self.openai_responses_impl.create_openai_response(
-            input, model, previous_response_id, store, stream, temperature, tools
+            input, model, instructions, previous_response_id, store, stream, temperature, tools
+        )
+
+    async def list_openai_responses(
+        self,
+        after: str | None = None,
+        limit: int | None = 50,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIResponseObject:
+        return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
+
+    async def list_openai_response_input_items(
+        self,
+        response_id: str,
+        after: str | None = None,
+        before: str | None = None,
+        include: list[str] | None = None,
+        limit: int | None = 20,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIResponseInputItem:
+        return await self.openai_responses_impl.list_openai_response_input_items(
+            response_id, after, before, include, limit, order
         )
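The two listing endpoints added above are thin delegations to the new responses store; a rough usage sketch follows (the `impl` wiring and the `.data` field on the returned list objects are assumptions in the style of OpenAI list objects, not shown in this diff):

    async def show_recent_responses(impl) -> None:
        # Page through the most recently stored responses and their input items.
        listing = await impl.list_openai_responses(limit=10)
        for response in listing.data:  # assumed: list objects expose a `data` field
            items = await impl.list_openai_response_input_items(response_id=response.id, limit=5)
            print(response.id, response.status, "inputs:", len(items.data))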
@@ -10,10 +10,12 @@ from pydantic import BaseModel

 from llama_stack.providers.utils.kvstore import KVStoreConfig
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig


 class MetaReferenceAgentsImplConfig(BaseModel):
     persistence_store: KVStoreConfig
+    responses_store: SqlStoreConfig

     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
@@ -21,5 +23,9 @@ class MetaReferenceAgentsImplConfig(BaseModel):
             "persistence_store": SqliteKVStoreConfig.sample_run_config(
                 __distro_dir__=__distro_dir__,
                 db_name="agents_store.db",
-            )
+            ),
+            "responses_store": SqliteSqlStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="responses_store.db",
+            ),
         }
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import json
+import time
 import uuid
 from collections.abc import AsyncIterator
 from typing import Any, cast
@@ -12,24 +13,29 @@ from typing import Any, cast
 from openai.types.chat import ChatCompletionToolParam
 from pydantic import BaseModel

+from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
+    AllowedToolsFilter,
+    ListOpenAIResponseInputItem,
+    ListOpenAIResponseObject,
     OpenAIResponseInput,
     OpenAIResponseInputFunctionToolCallOutput,
-    OpenAIResponseInputItemList,
     OpenAIResponseInputMessageContent,
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
-    OpenAIResponseInputToolFunction,
+    OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
     OpenAIResponseObjectStreamResponseCreated,
+    OpenAIResponseObjectStreamResponseOutputTextDelta,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageFunctionToolCall,
+    OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
 )
 from llama_stack.apis.inference.inference import (
@@ -49,11 +55,12 @@ from llama_stack.apis.inference.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
-from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools

 logger = get_logger(name=__name__, category="openai_responses")

@@ -162,41 +169,43 @@ async def _get_message_type_by_role(role: str):


 class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
-    input_items: OpenAIResponseInputItemList
+    input_items: ListOpenAIResponseInputItem
     response: OpenAIResponseObject


+class ChatCompletionContext(BaseModel):
+    model: str
+    messages: list[OpenAIMessageParam]
+    tools: list[ChatCompletionToolParam] | None = None
+    mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
+    stream: bool
+    temperature: float | None
+
+
 class OpenAIResponsesImpl:
     def __init__(
         self,
-        persistence_store: KVStore,
         inference_api: Inference,
         tool_groups_api: ToolGroups,
         tool_runtime_api: ToolRuntime,
+        responses_store: ResponsesStore,
     ):
-        self.persistence_store = persistence_store
         self.inference_api = inference_api
         self.tool_groups_api = tool_groups_api
         self.tool_runtime_api = tool_runtime_api
+        self.responses_store = responses_store
-
-    async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
-        key = f"{OPENAI_RESPONSES_PREFIX}{id}"
-        response_json = await self.persistence_store.get(key=key)
-        if response_json is None:
-            raise ValueError(f"OpenAI response with id '{id}' not found")
-        return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)

     async def _prepend_previous_response(
         self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
     ):
         if previous_response_id:
-            previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
+            previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)

             # previous response input items
-            new_input_items = previous_response_with_input.input_items.data
+            new_input_items = previous_response_with_input.input

             # previous response output items
-            new_input_items.extend(previous_response_with_input.response.output)
+            new_input_items.extend(previous_response_with_input.output)

             # new input items from the current request
             if isinstance(input, str):
@@ -208,17 +217,116 @@ class OpenAIResponsesImpl:

         return input

+    async def _prepend_instructions(self, messages, instructions):
+        if instructions:
+            messages.insert(0, OpenAISystemMessageParam(content=instructions))
+
     async def get_openai_response(
         self,
-        id: str,
+        response_id: str,
     ) -> OpenAIResponseObject:
-        response_with_input = await self._get_previous_response_with_input(id)
-        return response_with_input.response
+        response_with_input = await self.responses_store.get_response_object(response_id)
+        return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
+
+    async def list_openai_responses(
+        self,
+        after: str | None = None,
+        limit: int | None = 50,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIResponseObject:
+        return await self.responses_store.list_responses(after, limit, model, order)
+
+    async def list_openai_response_input_items(
+        self,
+        response_id: str,
+        after: str | None = None,
+        before: str | None = None,
+        include: list[str] | None = None,
+        limit: int | None = 20,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIResponseInputItem:
+        """List input items for a given OpenAI response.
+
+        :param response_id: The ID of the response to retrieve input items for.
+        :param after: An item ID to list items after, used for pagination.
+        :param before: An item ID to list items before, used for pagination.
+        :param include: Additional fields to include in the response.
+        :param limit: A limit on the number of objects to be returned.
+        :param order: The order to return the input items in.
+        :returns: An ListOpenAIResponseInputItem.
+        """
+        return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
+
+    async def _process_response_choices(
+        self,
+        chat_response: OpenAIChatCompletion,
+        ctx: ChatCompletionContext,
+        tools: list[OpenAIResponseInputTool] | None,
+    ) -> list[OpenAIResponseOutput]:
+        """Handle tool execution and response message creation."""
+        output_messages: list[OpenAIResponseOutput] = []
+        # Execute tool calls if any
+        for choice in chat_response.choices:
+            if choice.message.tool_calls and tools:
+                # Assume if the first tool is a function, all tools are functions
+                if tools[0].type == "function":
+                    for tool_call in choice.message.tool_calls:
+                        output_messages.append(
+                            OpenAIResponseOutputMessageFunctionToolCall(
+                                arguments=tool_call.function.arguments or "",
+                                call_id=tool_call.id,
+                                name=tool_call.function.name or "",
+                                id=f"fc_{uuid.uuid4()}",
+                                status="completed",
+                            )
+                        )
+                else:
+                    tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
+                    output_messages.extend(tool_messages)
+            else:
+                output_messages.append(await _convert_chat_choice_to_response_message(choice))
+
+        return output_messages
+
+    async def _store_response(
+        self,
+        response: OpenAIResponseObject,
+        original_input: str | list[OpenAIResponseInput],
+    ) -> None:
+        new_input_id = f"msg_{uuid.uuid4()}"
+        if isinstance(original_input, str):
+            # synthesize a message from the input string
+            input_content = OpenAIResponseInputMessageContentText(text=original_input)
+            input_content_item = OpenAIResponseMessage(
+                role="user",
+                content=[input_content],
+                id=new_input_id,
+            )
+            input_items_data = [input_content_item]
+        else:
+            # we already have a list of messages
+            input_items_data = []
+            for input_item in original_input:
+                if isinstance(input_item, OpenAIResponseMessage):
+                    # These may or may not already have an id, so dump to dict, check for id, and add if missing
+                    input_item_dict = input_item.model_dump()
+                    if "id" not in input_item_dict:
+                        input_item_dict["id"] = new_input_id
+                    input_items_data.append(OpenAIResponseMessage(**input_item_dict))
+                else:
+                    input_items_data.append(input_item)
+
+        await self.responses_store.store_response_object(
+            response_object=response,
+            input=input_items_data,
+        )

     async def create_openai_response(
         self,
         input: str | list[OpenAIResponseInput],
         model: str,
+        instructions: str | None = None,
         previous_response_id: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
@@ -226,11 +334,32 @@
         tools: list[OpenAIResponseInputTool] | None = None,
     ):
         stream = False if stream is None else stream
+        original_input = input  # Keep reference for storage
+
+        output_messages: list[OpenAIResponseOutput] = []
+
+        # Input preprocessing
         input = await self._prepend_previous_response(input, previous_response_id)
         messages = await _convert_response_input_to_chat_messages(input)
-        chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
-        chat_response = await self.inference_api.openai_chat_completion(
+        await self._prepend_instructions(messages, instructions)
+
+        # Tool setup
+        chat_tools, mcp_tool_to_server, mcp_list_message = (
+            await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
+        )
+        if mcp_list_message:
+            output_messages.append(mcp_list_message)
+
+        ctx = ChatCompletionContext(
+            model=model,
+            messages=messages,
+            tools=chat_tools,
+            mcp_tool_to_server=mcp_tool_to_server,
+            stream=stream,
+            temperature=temperature,
+        )
+
+        inference_result = await self.inference_api.openai_chat_completion(
             model=model,
             messages=messages,
             tools=chat_tools,
@@ -239,20 +368,122 @@
         )

         if stream:
-            # TODO: refactor this into a separate method that handles streaming
+            return self._create_streaming_response(
+                inference_result=inference_result,
+                ctx=ctx,
+                output_messages=output_messages,
+                original_input=original_input,
+                model=model,
+                store=store,
+                tools=tools,
+            )
+        else:
+            return await self._create_non_streaming_response(
+                inference_result=inference_result,
+                ctx=ctx,
+                output_messages=output_messages,
+                original_input=original_input,
+                model=model,
+                store=store,
+                tools=tools,
+            )
+
+    async def _create_non_streaming_response(
+        self,
+        inference_result: Any,
+        ctx: ChatCompletionContext,
+        output_messages: list[OpenAIResponseOutput],
+        original_input: str | list[OpenAIResponseInput],
+        model: str,
+        store: bool | None,
+        tools: list[OpenAIResponseInputTool] | None,
+    ) -> OpenAIResponseObject:
+        chat_response = OpenAIChatCompletion(**inference_result.model_dump())
+
+        # Process response choices (tool execution and message creation)
+        output_messages.extend(
+            await self._process_response_choices(
+                chat_response=chat_response,
+                ctx=ctx,
+                tools=tools,
+            )
+        )
+
+        response = OpenAIResponseObject(
+            created_at=chat_response.created,
+            id=f"resp-{uuid.uuid4()}",
+            model=model,
+            object="response",
+            status="completed",
+            output=output_messages,
+        )
+        logger.debug(f"OpenAI Responses response: {response}")
+
+        # Store response if requested
+        if store:
+            await self._store_response(
+                response=response,
+                original_input=original_input,
+            )
+
+        return response
+
+    async def _create_streaming_response(
+        self,
+        inference_result: Any,
+        ctx: ChatCompletionContext,
+        output_messages: list[OpenAIResponseOutput],
+        original_input: str | list[OpenAIResponseInput],
+        model: str,
+        store: bool | None,
+        tools: list[OpenAIResponseInputTool] | None,
+    ) -> AsyncIterator[OpenAIResponseObjectStream]:
+        # Create initial response and emit response.created immediately
+        response_id = f"resp-{uuid.uuid4()}"
+        created_at = int(time.time())
+
+        initial_response = OpenAIResponseObject(
+            created_at=created_at,
+            id=response_id,
+            model=model,
+            object="response",
+            status="in_progress",
+            output=output_messages.copy(),
+        )
+
+        # Emit response.created immediately
+        yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
+
+        # For streaming, inference_result is an async iterator of chunks
+        # Stream chunks and emit delta events as they arrive
         chat_response_id = ""
         chat_response_content = []
         chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
-        # TODO: these chunk_ fields are hacky and only take the last chunk into account
         chunk_created = 0
         chunk_model = ""
         chunk_finish_reason = ""
-        async for chunk in chat_response:
+        sequence_number = 0
+
+        # Create a placeholder message item for delta events
+        message_item_id = f"msg_{uuid.uuid4()}"
+
+        async for chunk in inference_result:
             chat_response_id = chunk.id
             chunk_created = chunk.created
             chunk_model = chunk.model
             for chunk_choice in chunk.choices:
-                # TODO: this only works for text content
+                # Emit incremental text content as delta events
+                if chunk_choice.delta.content:
+                    sequence_number += 1
+                    yield OpenAIResponseObjectStreamResponseOutputTextDelta(
+                        content_index=0,
+                        delta=chunk_choice.delta.content,
+                        item_id=message_item_id,
+                        output_index=0,
+                        sequence_number=sequence_number,
+                    )
+
+                # Collect content for final response
                 chat_response_content.append(chunk_choice.delta.content or "")
                 if chunk_choice.finish_reason:
                     chunk_finish_reason = chunk_choice.finish_reason
@@ -265,13 +496,11 @@
                         response_tool_call.function.arguments += tool_call.function.arguments
                     else:
                         tool_call_dict: dict[str, Any] = tool_call.model_dump()
-                        # Ensure we don't have any empty type field in the tool call dict.
-                        # The OpenAI client used by providers often returns a type=None here.
                         tool_call_dict.pop("type", None)
                         response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
                     chat_response_tool_calls[tool_call.index] = response_tool_call

-        # Convert the dict of tool calls by index to a list of tool calls to pass back in our response
+        # Convert collected chunks to complete response
         if chat_response_tool_calls:
             tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
         else:
@@ -280,7 +509,7 @@
             content="".join(chat_response_content),
             tool_calls=tool_calls,
         )
-        chat_response = OpenAIChatCompletion(
+        chat_response_obj = OpenAIChatCompletion(
             id=chat_response_id,
             choices=[
                 OpenAIChoice(
@@ -292,100 +521,50 @@
             created=chunk_created,
             model=chunk_model,
         )
-        else:
-            # dump and reload to map to our pydantic types
-            chat_response = OpenAIChatCompletion(**chat_response.model_dump())

-        output_messages: list[OpenAIResponseOutput] = []
-        for choice in chat_response.choices:
-            if choice.message.tool_calls and tools:
-                # Assume if the first tool is a function, all tools are functions
-                if isinstance(tools[0], OpenAIResponseInputToolFunction):
-                    for tool_call in choice.message.tool_calls:
-                        output_messages.append(
-                            OpenAIResponseOutputMessageFunctionToolCall(
-                                arguments=tool_call.function.arguments or "",
-                                call_id=tool_call.id,
-                                name=tool_call.function.name or "",
-                                id=f"fc_{uuid.uuid4()}",
-                                status="completed",
-                            )
-                        )
-                else:
-                    output_messages.extend(
-                        await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
-                    )
-            else:
-                output_messages.append(await _convert_chat_choice_to_response_message(choice))
-        response = OpenAIResponseObject(
-            created_at=chat_response.created,
-            id=f"resp-{uuid.uuid4()}",
+        # Process response choices (tool execution and message creation)
+        output_messages.extend(
+            await self._process_response_choices(
+                chat_response=chat_response_obj,
+                ctx=ctx,
+                tools=tools,
+            )
+        )
+
+        # Create final response
+        final_response = OpenAIResponseObject(
+            created_at=created_at,
+            id=response_id,
             model=model,
             object="response",
             status="completed",
             output=output_messages,
         )
-        logger.debug(f"OpenAI Responses response: {response}")

         if store:
-            # Store in kvstore
-            new_input_id = f"msg_{uuid.uuid4()}"
-            if isinstance(input, str):
-                # synthesize a message from the input string
-                input_content = OpenAIResponseInputMessageContentText(text=input)
-                input_content_item = OpenAIResponseMessage(
-                    role="user",
-                    content=[input_content],
-                    id=new_input_id,
-                )
-                input_items_data = [input_content_item]
-            else:
-                # we already have a list of messages
-                input_items_data = []
-                for input_item in input:
-                    if isinstance(input_item, OpenAIResponseMessage):
-                        # These may or may not already have an id, so dump to dict, check for id, and add if missing
-                        input_item_dict = input_item.model_dump()
-                        if "id" not in input_item_dict:
-                            input_item_dict["id"] = new_input_id
-                        input_items_data.append(OpenAIResponseMessage(**input_item_dict))
-                    else:
-                        input_items_data.append(input_item)
-
-            input_items = OpenAIResponseInputItemList(data=input_items_data)
-            prev_response = OpenAIResponsePreviousResponseWithInputItems(
-                input_items=input_items,
-                response=response,
-            )
-            key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
-            await self.persistence_store.set(
-                key=key,
-                value=prev_response.model_dump_json(),
+            await self._store_response(
+                response=final_response,
+                original_input=original_input,
             )

-        if stream:
-
-            async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
-                # TODO: response created should actually get emitted much earlier in the process
-                yield OpenAIResponseObjectStreamResponseCreated(response=response)
-                yield OpenAIResponseObjectStreamResponseCompleted(response=response)
-
-            return async_response()
-
-        return response
+        # Emit response.completed
+        yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)

     async def _convert_response_tools_to_chat_tools(
         self, tools: list[OpenAIResponseInputTool]
-    ) -> list[ChatCompletionToolParam]:
-        chat_tools: list[ChatCompletionToolParam] = []
-        for input_tool in tools:
-            # TODO: Handle other tool types
-            if input_tool.type == "function":
-                chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
-            elif input_tool.type == "web_search":
-                tool_name = "web_search"
-                tool = await self.tool_groups_api.get_tool(tool_name)
+    ) -> tuple[
+        list[ChatCompletionToolParam],
+        dict[str, OpenAIResponseInputToolMCP],
+        OpenAIResponseOutput | None,
+    ]:
+        from llama_stack.apis.agents.openai_responses import (
+            MCPListToolsTool,
+        )
+        from llama_stack.apis.tools.tools import Tool
+
+        mcp_tool_to_server = {}
+
+        def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
             tool_def = ToolDefinition(
                 tool_name=tool_name,
                 description=tool.description,
@@ -399,78 +578,106 @@
                     for param in tool.parameters
                 },
             )
-            chat_tool = convert_tooldef_to_openai_tool(tool_def)
-            chat_tools.append(chat_tool)
+            return convert_tooldef_to_openai_tool(tool_def)
+
+        mcp_list_message = None
+        chat_tools: list[ChatCompletionToolParam] = []
+        for input_tool in tools:
+            # TODO: Handle other tool types
+            if input_tool.type == "function":
+                chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
+            elif input_tool.type == "web_search":
+                tool_name = "web_search"
+                tool = await self.tool_groups_api.get_tool(tool_name)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "mcp":
+                always_allowed = None
+                never_allowed = None
+                if input_tool.allowed_tools:
+                    if isinstance(input_tool.allowed_tools, list):
+                        always_allowed = input_tool.allowed_tools
+                    elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
+                        always_allowed = input_tool.allowed_tools.always
+                        never_allowed = input_tool.allowed_tools.never
+
+                tool_defs = await list_mcp_tools(
+                    endpoint=input_tool.server_url,
+                    headers=input_tool.headers or {},
+                )
+
+                mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
+                    id=f"mcp_list_{uuid.uuid4()}",
+                    status="completed",
+                    server_label=input_tool.server_label,
+                    tools=[],
+                )
+                for t in tool_defs.data:
+                    if never_allowed and t.name in never_allowed:
+                        continue
+                    if not always_allowed or t.name in always_allowed:
+                        chat_tools.append(make_openai_tool(t.name, t))
+                        if t.name in mcp_tool_to_server:
+                            raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
+                        mcp_tool_to_server[t.name] = input_tool
+                        mcp_list_message.tools.append(
+                            MCPListToolsTool(
+                                name=t.name,
+                                description=t.description,
+                                input_schema={
+                                    "type": "object",
+                                    "properties": {
+                                        p.name: {
+                                            "type": p.parameter_type,
+                                            "description": p.description,
+                                        }
+                                        for p in t.parameters
+                                    },
+                                    "required": [p.name for p in t.parameters if p.required],
+                                },
+                            )
+                        )
             else:
                 raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
-        return chat_tools
+        return chat_tools, mcp_tool_to_server, mcp_list_message

     async def _execute_tool_and_return_final_output(
         self,
-        model_id: str,
-        stream: bool,
         choice: OpenAIChoice,
-        messages: list[OpenAIMessageParam],
-        temperature: float,
+        ctx: ChatCompletionContext,
     ) -> list[OpenAIResponseOutput]:
         output_messages: list[OpenAIResponseOutput] = []

-        # If the choice is not an assistant message, we don't need to execute any tools
         if not isinstance(choice.message, OpenAIAssistantMessageParam):
             return output_messages

-        # If the assistant message doesn't have any tool calls, we don't need to execute any tools
         if not choice.message.tool_calls:
             return output_messages

-        # Copy the messages list to avoid mutating the original list
-        messages = messages.copy()
+        next_turn_messages = ctx.messages.copy()

         # Add the assistant message with tool_calls response to the messages list
-        messages.append(choice.message)
+        next_turn_messages.append(choice.message)

         for tool_call in choice.message.tool_calls:
-            tool_call_id = tool_call.id
-            function = tool_call.function
-
-            # If for some reason the tool call doesn't have a function or id, we can't execute it
-            if not function or not tool_call_id:
-                continue
-
             # TODO: telemetry spans for tool calls
-            result = await self._execute_tool_call(function)
-
-            # Handle tool call failure
-            if not result:
-                output_messages.append(
-                    OpenAIResponseOutputMessageWebSearchToolCall(
-                        id=tool_call_id,
-                        status="failed",
-                    )
-                )
-                continue
-
-            output_messages.append(
-                OpenAIResponseOutputMessageWebSearchToolCall(
-                    id=tool_call_id,
-                    status="completed",
-                ),
-            )
-
-            result_content = ""
-            # TODO: handle other result content types and lists
-            if isinstance(result.content, str):
-                result_content = result.content
-            messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
+            tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
+            if tool_call_log:
+                output_messages.append(tool_call_log)
+            if further_input:
+                next_turn_messages.append(further_input)

         tool_results_chat_response = await self.inference_api.openai_chat_completion(
-            model=model_id,
-            messages=messages,
-            stream=stream,
-            temperature=temperature,
+            model=ctx.model,
+            messages=next_turn_messages,
+            stream=ctx.stream,
+            temperature=ctx.temperature,
         )
-        # type cast to appease mypy
+        # type cast to appease mypy: this is needed because we don't handle streaming properly :)
         tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
+
+        # Huge TODO: these are NOT the final outputs, we must keep the loop going
         tool_final_outputs = [
             await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
         ]
@@ -480,15 +687,86 @@

     async def _execute_tool_call(
         self,
-        function: OpenAIChatCompletionToolCallFunction,
-    ) -> ToolInvocationResult | None:
-        if not function.name:
-            return None
-        function_args = json.loads(function.arguments) if function.arguments else {}
-        logger.info(f"executing tool call: {function.name} with args: {function_args}")
-        result = await self.tool_runtime_api.invoke_tool(
-            tool_name=function.name,
-            kwargs=function_args,
-        )
-        logger.debug(f"tool call {function.name} completed with result: {result}")
-        return result
+        tool_call: OpenAIChatCompletionToolCall,
+        ctx: ChatCompletionContext,
+    ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
+        from llama_stack.providers.utils.inference.prompt_adapter import (
+            interleaved_content_as_str,
+        )
+
+        tool_call_id = tool_call.id
+        function = tool_call.function
+
+        if not function or not tool_call_id or not function.name:
+            return None, None
+
+        error_exc = None
+        result = None
+        try:
+            if function.name in ctx.mcp_tool_to_server:
+                mcp_tool = ctx.mcp_tool_to_server[function.name]
+                result = await invoke_mcp_tool(
+                    endpoint=mcp_tool.server_url,
+                    headers=mcp_tool.headers or {},
+                    tool_name=function.name,
+                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                )
+            else:
+                result = await self.tool_runtime_api.invoke_tool(
+                    tool_name=function.name,
+                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                )
+        except Exception as e:
+            error_exc = e
+
+        if function.name in ctx.mcp_tool_to_server:
+            from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
+
+            message = OpenAIResponseOutputMessageMCPCall(
+                id=tool_call_id,
+                arguments=function.arguments,
+                name=function.name,
+                server_label=ctx.mcp_tool_to_server[function.name].server_label,
+            )
+            if error_exc:
+                message.error = str(error_exc)
+            elif (result.error_code and result.error_code > 0) or result.error_message:
+                message.error = f"Error (code {result.error_code}): {result.error_message}"
+            elif result.content:
+                message.output = interleaved_content_as_str(result.content)
+        else:
+            if function.name == "web_search":
+                message = OpenAIResponseOutputMessageWebSearchToolCall(
+                    id=tool_call_id,
+                    status="completed",
+                )
+                if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                    message.status = "failed"
+            else:
+                raise ValueError(f"Unknown tool {function.name} called")
+
+        input_message = None
+        if result and result.content:
+            if isinstance(result.content, str):
+                content = result.content
+            elif isinstance(result.content, list):
+                from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+
+                content = []
+                for item in result.content:
+                    if isinstance(item, TextContentItem):
+                        part = OpenAIChatCompletionContentPartTextParam(text=item.text)
+                    elif isinstance(item, ImageContentItem):
+                        if item.image.data:
+                            url = f"data:image;base64,{item.image.data}"
+                        else:
+                            url = item.image.url
+                        part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
+                    else:
+                        raise ValueError(f"Unknown result content type: {type(item)}")
+                    content.append(part)
+            else:
+                raise ValueError(f"Unknown result content type: {type(result.content)}")
+            input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
+
+        return message, input_message
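With streaming moved into `_create_streaming_response`, a caller now sees a `response.created` event first, then incremental output-text deltas, then `response.completed`. A minimal consumer sketch (the `impl` construction is assumed; the event classes are the ones imported in the diff above):

    async def print_stream(impl, model: str) -> None:
        # create_openai_response returns an async iterator of stream events when stream=True
        stream = await impl.create_openai_response(input="hello", model=model, stream=True)
        async for event in stream:
            if isinstance(event, OpenAIResponseObjectStreamResponseCreated):
                print("created:", event.response.id)
            elif isinstance(event, OpenAIResponseObjectStreamResponseOutputTextDelta):
                print(event.delta, end="")
            elif isinstance(event, OpenAIResponseObjectStreamResponseCompleted):
                print("\ncompleted:", event.response.status)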
@@ -6,6 +6,7 @@

 import asyncio
 import os
+import sys
 from collections.abc import AsyncGenerator

 from pydantic import BaseModel
@@ -455,9 +456,9 @@ class MetaReferenceInferenceImpl(
                 first = token_results[0]
                 if not first.finished and not first.ignore_token:
                     if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
-                        cprint(first.text, "cyan", end="")
+                        cprint(first.text, color="cyan", end="", file=sys.stderr)
                     if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
-                        cprint(f"<{first.token}>", "magenta", end="")
+                        cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)

                 for result in token_results:
                     idx = result.batch_idx
@@ -519,9 +520,9 @@ class MetaReferenceInferenceImpl(
             for token_results in self.generator.chat_completion([request]):
                 token_result = token_results[0]
                 if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
-                    cprint(token_result.text, "cyan", end="")
+                    cprint(token_result.text, color="cyan", end="", file=sys.stderr)
                 if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
-                    cprint(f"<{token_result.token}>", "magenta", end="")
+                    cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)

                 if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
@@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from opentelemetry.semconv.resource import ResourceAttributes
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

 from llama_stack.apis.telemetry import (
     Event,
@@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
 )
 from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
 from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
+from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS

 from .config import TelemetryConfig, TelemetrySink

@@ -146,7 +148,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         if span:
             timestamp_ns = int(event.timestamp.timestamp() * 1e9)
             span.add_event(
-                name=event.type,
+                name=event.type.value,
                 attributes={
                     "message": event.message,
                     "severity": event.severity.value,
@@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             event.attributes = {}
         event.attributes["__ttl__"] = ttl_seconds

+        # Extract these W3C trace context attributes so they are not written to
+        # underlying storage, as we just need them to propagate the trace context.
+        traceparent = event.attributes.pop("traceparent", None)
+        tracestate = event.attributes.pop("tracestate", None)
+        if traceparent:
+            # If we have a traceparent header value, we're not the root span.
+            for root_attribute in ROOT_SPAN_MARKERS:
+                event.attributes.pop(root_attribute, None)
+
         if isinstance(event.payload, SpanStartPayload):
             # Check if span already exists to prevent duplicates
             if span_id in _GLOBAL_STORAGE["active_spans"]:
@@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
                 parent_span_id = int(event.payload.parent_span_id, 16)
                 parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
                 context = trace.set_span_in_context(parent_span)
-            else:
-                event.attributes["__root_span__"] = "true"
+            elif traceparent:
+                carrier = {
+                    "traceparent": traceparent,
+                    "tracestate": tracestate,
+                }
+                context = TraceContextTextMapPropagator().extract(carrier=carrier)

             span = tracer.start_span(
                 name=event.payload.name,
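The traceparent/tracestate handling above follows the W3C Trace Context convention: when an incoming event already carries a trace, the adapter joins it instead of starting a new root span. A small illustrative sketch of the same propagator API (the carrier contents and span name here are hypothetical, not taken from this diff):

    from opentelemetry import trace
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    def continue_trace(carrier: dict[str, str]) -> None:
        # carrier would typically hold {"traceparent": "...", "tracestate": "..."} from an incoming request
        ctx = TraceContextTextMapPropagator().extract(carrier=carrier)
        tracer = trace.get_tracer(__name__)
        with tracer.start_as_current_span("child-operation", context=ctx):
            pass  # spans created here attach to the propagated trace rather than starting a new root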
@@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
     RAGQueryConfig,
     RAGQueryResult,
     RAGToolRuntime,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.memory.vector_store import (
     content_from_doc,
@@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
     return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))


-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
+class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
         config: RagToolRuntimeConfig,
@@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     async def shutdown(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     async def insert(
@@ -122,6 +122,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                 query=query,
                 params={
                     "max_chunks": query_config.max_chunks,
+                    "mode": query_config.mode,
                 },
             )
             for vector_db_id in vector_db_ids
@@ -146,7 +147,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         for i, chunk in enumerate(chunks):
             metadata = chunk.metadata
             tokens += metadata["token_count"]
-            tokens += metadata["metadata_token_count"]
+            tokens += metadata.get("metadata_token_count", 0)

             if tokens > query_config.max_tokens_in_context:
                 log.error(
@ -99,9 +99,13 @@ class FaissIndex(EmbeddingIndex):
|
||||||
# Save updated index
|
# Save updated index
|
||||||
await self._save_index()
|
await self._save_index()
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
scores = []
|
scores = []
|
||||||
for d, i in zip(distances[0], indices[0], strict=False):
|
for d, i in zip(distances[0], indices[0], strict=False):
|
||||||
|
@ -112,6 +116,14 @@ class FaissIndex(EmbeddingIndex):
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def query_keyword(
|
||||||
|
self,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Keyword search is not supported in FAISS")
|
||||||
|
|
||||||
|
|
||||||
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
|
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
|
||||||
|
|
|
@ -24,6 +24,11 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Specifying search mode is dependent on the VectorIO provider.
|
||||||
|
VECTOR_SEARCH = "vector"
|
||||||
|
KEYWORD_SEARCH = "keyword"
|
||||||
|
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
||||||
|
|
||||||
|
|
||||||
def serialize_vector(vector: list[float]) -> bytes:
|
def serialize_vector(vector: list[float]) -> bytes:
|
||||||
"""Serialize a list of floats into a compact binary representation."""
|
"""Serialize a list of floats into a compact binary representation."""
|
||||||
|
@ -45,6 +50,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
Two tables are used:
|
Two tables are used:
|
||||||
- A metadata table (chunks_{bank_id}) that holds the chunk JSON.
|
- A metadata table (chunks_{bank_id}) that holds the chunk JSON.
|
||||||
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
|
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
|
||||||
|
- An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dimension: int, db_path: str, bank_id: str):
|
def __init__(self, dimension: int, db_path: str, bank_id: str):
|
||||||
|
@ -53,6 +59,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
self.bank_id = bank_id
|
self.bank_id = bank_id
|
||||||
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
||||||
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
||||||
|
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(cls, dimension: int, db_path: str, bank_id: str):
|
async def create(cls, dimension: int, db_path: str, bank_id: str):
|
||||||
|
@ -78,6 +85,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||||
""")
|
""")
|
||||||
connection.commit()
|
connection.commit()
|
||||||
|
# FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one
|
||||||
|
# based on query. Implementation of the change on client side will allow passing the search_mode option
|
||||||
|
# during initialization to make it easier to create the table that is required.
|
||||||
|
cur.execute(f"""
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
|
||||||
|
USING fts5(id, content);
|
||||||
|
""")
|
||||||
|
connection.commit()
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
connection.close()
|
connection.close()
|
||||||
|
@ -91,6 +106,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
try:
|
try:
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||||
|
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
|
||||||
connection.commit()
|
connection.commit()
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
|
@ -104,6 +120,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
For each chunk, we insert its JSON into the metadata table and then insert its
|
For each chunk, we insert its JSON into the metadata table and then insert its
|
||||||
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
||||||
If any insert fails, the transaction is rolled back to maintain consistency.
|
If any insert fails, the transaction is rolled back to maintain consistency.
|
||||||
|
Also inserts chunk content into FTS table for keyword search support.
|
||||||
"""
|
"""
|
||||||
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
|
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
|
||||||
|
|
||||||
|
@ -112,18 +129,16 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
cur = connection.cursor()
|
cur = connection.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Start transaction a single transcation for all batches
|
|
||||||
cur.execute("BEGIN TRANSACTION")
|
cur.execute("BEGIN TRANSACTION")
|
||||||
for i in range(0, len(chunks), batch_size):
|
for i in range(0, len(chunks), batch_size):
|
||||||
batch_chunks = chunks[i : i + batch_size]
|
batch_chunks = chunks[i : i + batch_size]
|
||||||
batch_embeddings = embeddings[i : i + batch_size]
|
batch_embeddings = embeddings[i : i + batch_size]
|
||||||
# Prepare metadata inserts
|
|
||||||
|
# Insert metadata
|
||||||
metadata_data = [
|
metadata_data = [
|
||||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
||||||
for chunk in batch_chunks
|
for chunk in batch_chunks
|
||||||
if isinstance(chunk.content, str)
|
|
||||||
]
|
]
|
||||||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
|
||||||
cur.executemany(
|
cur.executemany(
|
||||||
f"""
|
f"""
|
||||||
INSERT INTO {self.metadata_table} (id, chunk)
|
INSERT INTO {self.metadata_table} (id, chunk)
|
||||||
|
@ -132,21 +147,43 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
""",
|
""",
|
||||||
metadata_data,
|
metadata_data,
|
||||||
)
|
)
|
||||||
# Prepare embeddings inserts
|
|
||||||
|
# Insert vector embeddings
|
||||||
embedding_data = [
|
embedding_data = [
|
||||||
|
(
|
||||||
(
|
(
|
||||||
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
|
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
|
||||||
serialize_vector(emb.tolist()),
|
serialize_vector(emb.tolist()),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||||
if isinstance(chunk.content, str)
|
|
||||||
]
|
]
|
||||||
# Insert embeddings in batch
|
cur.executemany(
|
||||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
|
||||||
|
embedding_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Insert FTS content
|
||||||
|
fts_data = [
|
||||||
|
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content)
|
||||||
|
for chunk in batch_chunks
|
||||||
|
]
|
||||||
|
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||||
|
cur.executemany(
|
||||||
|
f"DELETE FROM {self.fts_table} WHERE id = ?;",
|
||||||
|
[(row[0],) for row in fts_data],
|
||||||
|
)
|
||||||
|
|
||||||
|
# INSERT new entries
|
||||||
|
cur.executemany(
|
||||||
|
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
|
||||||
|
fts_data,
|
||||||
|
)
|
||||||
|
|
||||||
connection.commit()
|
connection.commit()
|
||||||
|
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
connection.rollback() # Rollback on failure
|
connection.rollback()
|
||||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@ -154,22 +191,25 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
cur.close()
|
cur.close()
|
||||||
connection.close()
|
connection.close()
|
||||||
|
|
||||||
# Process all batches in a single thread
|
# Run batch insertion in a background thread
|
||||||
await asyncio.to_thread(_execute_all_batch_inserts)
|
await asyncio.to_thread(_execute_all_batch_inserts)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
"""
|
"""
|
||||||
Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query
|
Performs vector-based search using a virtual table for vector similarity.
|
||||||
against the virtual table. The SQL joins the metadata table to recover the chunk JSON.
|
|
||||||
"""
|
"""
|
||||||
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
|
|
||||||
emb_blob = serialize_vector(emb_list)
|
|
||||||
|
|
||||||
def _execute_query():
|
def _execute_query():
|
||||||
connection = _create_sqlite_connection(self.db_path)
|
connection = _create_sqlite_connection(self.db_path)
|
||||||
cur = connection.cursor()
|
cur = connection.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
|
||||||
|
emb_blob = serialize_vector(emb_list)
|
||||||
query_sql = f"""
|
query_sql = f"""
|
||||||
SELECT m.id, m.chunk, v.distance
|
SELECT m.id, m.chunk, v.distance
|
||||||
FROM {self.vector_table} AS v
|
FROM {self.vector_table} AS v
|
||||||
|
@ -184,17 +224,66 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
connection.close()
|
connection.close()
|
||||||
|
|
||||||
rows = await asyncio.to_thread(_execute_query)
|
rows = await asyncio.to_thread(_execute_query)
|
||||||
|
|
||||||
chunks, scores = [], []
|
chunks, scores = [], []
|
||||||
for _id, chunk_json, distance in rows:
|
for row in rows:
|
||||||
|
_id, chunk_json, distance = row
|
||||||
|
score = 1.0 / distance if distance != 0 else float("inf")
|
||||||
|
if score < score_threshold:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
chunk = Chunk.model_validate_json(chunk_json)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||||
|
continue
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(score)
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def query_keyword(
|
||||||
|
self,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
"""
|
||||||
|
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||||
|
"""
|
||||||
|
if query_string is None:
|
||||||
|
raise ValueError("query_string is required for keyword search.")
|
||||||
|
|
||||||
|
def _execute_query():
|
||||||
|
connection = _create_sqlite_connection(self.db_path)
|
||||||
|
cur = connection.cursor()
|
||||||
|
try:
|
||||||
|
query_sql = f"""
|
||||||
|
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
|
||||||
|
FROM {self.fts_table} AS f
|
||||||
|
JOIN {self.metadata_table} AS m ON m.id = f.id
|
||||||
|
WHERE f.content MATCH ?
|
||||||
|
ORDER BY score ASC
|
||||||
|
LIMIT ?;
|
||||||
|
"""
|
||||||
|
cur.execute(query_sql, (query_string, k))
|
||||||
|
return cur.fetchall()
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
rows = await asyncio.to_thread(_execute_query)
|
||||||
|
chunks, scores = [], []
|
||||||
|
for row in rows:
|
||||||
|
_id, chunk_json, score = row
|
||||||
|
# BM25 scores returned by sqlite-vec are NEGATED (i.e., more relevant = more negative).
|
||||||
|
# This design is intentional to simplify sorting by ascending score.
|
||||||
|
# Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html
|
||||||
|
if score > -score_threshold:
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
chunk = Chunk.model_validate_json(chunk_json)
|
chunk = Chunk.model_validate_json(chunk_json)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||||
continue
|
continue
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
# Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
|
|
||||||
score = 1.0 / distance if distance != 0 else float("inf")
|
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
|
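Note: the sqlite_vec changes above add an FTS5 table per bank and a query_keyword path ranked with bm25(), where more relevant rows get more negative scores (hence the ascending ORDER BY and the negated threshold check). A standalone, hedged sketch of that ranking behaviour follows; the table name and documents are illustrative, and it assumes the local SQLite build ships the FTS5 extension.

# Sketch: FTS5 keyword search ranked by bm25(), mirroring the query_keyword
# logic added above. Table name and rows are placeholders; needs FTS5 support.
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE VIRTUAL TABLE fts_chunks USING fts5(id, content);")
cur.executemany(
    "INSERT INTO fts_chunks (id, content) VALUES (?, ?);",
    [
        ("doc-1", "llama stack supports keyword search"),
        ("doc-2", "vector search uses embeddings"),
    ],
)
conn.commit()

# bm25() returns NEGATIVE scores: the best match is the most negative,
# so ordering ascending puts it first.
cur.execute(
    """
    SELECT id, bm25(fts_chunks) AS score
    FROM fts_chunks
    WHERE fts_chunks MATCH ?
    ORDER BY score ASC
    LIMIT ?;
    """,
    ("keyword", 5),
)
print(cur.fetchall())  # e.g. [('doc-1', <negative score>)]
conn.close()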
@ -63,4 +63,14 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.safety,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="sambanova",
|
||||||
|
pip_packages=["litellm"],
|
||||||
|
module="llama_stack.providers.remote.safety.sambanova",
|
||||||
|
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||||
|
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -80,8 +80,9 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_type="model-context-protocol",
|
adapter_type="model-context-protocol",
|
||||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
|
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||||
pip_packages=["mcp"],
|
pip_packages=["mcp"],
|
||||||
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -92,8 +92,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
if prompt_logprobs is not None:
|
if prompt_logprobs is not None:
|
||||||
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
|
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
|
||||||
|
|
||||||
|
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||||
|
if model_id.startswith("openai/"):
|
||||||
|
model_id = model_id[len("openai/") :]
|
||||||
params = await prepare_openai_completion_params(
|
params = await prepare_openai_completion_params(
|
||||||
model=(await self.model_store.get_model(model)).provider_resource_id,
|
model=model_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
best_of=best_of,
|
best_of=best_of,
|
||||||
echo=echo,
|
echo=echo,
|
||||||
|
@ -139,8 +142,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
top_p: float | None = None,
|
top_p: float | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||||
|
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||||
|
if model_id.startswith("openai/"):
|
||||||
|
model_id = model_id[len("openai/") :]
|
||||||
params = await prepare_openai_completion_params(
|
params = await prepare_openai_completion_params(
|
||||||
model=(await self.model_store.get_model(model)).provider_resource_id,
|
model=model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
function_call=function_call,
|
function_call=function_call,
|
||||||
|
|
|
@ -4,8 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
||||||
default="fake",
|
default="fake",
|
||||||
description="The API token",
|
description="The API token",
|
||||||
)
|
)
|
||||||
tls_verify: bool = Field(
|
tls_verify: bool | str = Field(
|
||||||
default=True,
|
default=True,
|
||||||
description="Whether to verify TLS certificates",
|
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator("tls_verify")
|
||||||
|
@classmethod
|
||||||
|
def validate_tls_verify(cls, v):
|
||||||
|
if isinstance(v, str):
|
||||||
|
# Check if it's a boolean string
|
||||||
|
if v.lower() in ("true", "false"):
|
||||||
|
return v.lower() == "true"
|
||||||
|
# Otherwise, treat it as a cert path
|
||||||
|
cert_path = Path(v).expanduser().resolve()
|
||||||
|
if not cert_path.exists():
|
||||||
|
raise ValueError(f"TLS certificate file does not exist: {v}")
|
||||||
|
if not cert_path.is_file():
|
||||||
|
raise ValueError(f"TLS certificate path is not a file: {v}")
|
||||||
|
return v
|
||||||
|
return v
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(
|
def sample_run_config(
|
||||||
cls,
|
cls,
|
||||||
|
|
|
@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
return AsyncOpenAI(
|
return AsyncOpenAI(
|
||||||
base_url=self.config.url,
|
base_url=self.config.url,
|
||||||
api_key=self.config.api_token,
|
api_key=self.config.api_token,
|
||||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
|
|
|
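Note: with the config change above, tls_verify may be either a boolean or a path to a CA bundle, and the vLLM adapter now passes it straight to httpx, which accepts both forms for its verify argument. A small hedged sketch of that flow; the CA path is a placeholder and must exist for httpx to load it.

# Sketch: forwarding the validated tls_verify value to httpx, as the adapter
# change above does. The CA bundle path is a placeholder.
import httpx

def make_client(tls_verify: bool | str) -> httpx.AsyncClient:
    # httpx accepts True/False or a filesystem path to a CA certificate bundle.
    return httpx.AsyncClient(verify=tls_verify)

system_cas = make_client(True)    # verify against system CAs
insecure = make_client(False)     # skip verification (dev/testing only)
custom_ca = make_client("/etc/ssl/certs/my-ca.pem")  # custom CA bundle (path must exist)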
@ -224,7 +224,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
training_config: TrainingConfig - Configuration for training
|
training_config: TrainingConfig - Configuration for training
|
||||||
model: str - Model identifier
|
model: str - NeMo Customizer configuration name
|
||||||
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
|
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
|
||||||
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
|
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
|
||||||
job_uuid: str - Unique identifier for the job, ignored atm
|
job_uuid: str - Unique identifier for the job, ignored atm
|
||||||
|
@ -299,9 +299,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
User is informed about unsupported parameters via warnings.
|
User is informed about unsupported parameters via warnings.
|
||||||
"""
|
"""
|
||||||
# Map model to nvidia model name
|
|
||||||
# See `_MODEL_ENTRIES` for supported models
|
|
||||||
nvidia_model = self.get_provider_model_id(model)
|
|
||||||
|
|
||||||
# Check for unsupported method parameters
|
# Check for unsupported method parameters
|
||||||
unsupported_method_params = []
|
unsupported_method_params = []
|
||||||
|
@ -347,7 +344,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
# Prepare base job configuration
|
# Prepare base job configuration
|
||||||
job_config = {
|
job_config = {
|
||||||
"config": nvidia_model,
|
"config": model,
|
||||||
"dataset": {
|
"dataset": {
|
||||||
"name": training_config["data_config"]["dataset_id"],
|
"name": training_config["data_config"]["dataset_id"],
|
||||||
"namespace": self.config.dataset_namespace,
|
"namespace": self.config.dataset_namespace,
|
||||||
|
|
18
llama_stack/providers/remote/safety/sambanova/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .config import SambaNovaSafetyConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any:
|
||||||
|
from .sambanova import SambaNovaSafetyAdapter
|
||||||
|
|
||||||
|
impl = SambaNovaSafetyAdapter(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
37
llama_stack/providers/remote/safety/sambanova/config.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
class SambaNovaProviderDataValidator(BaseModel):
|
||||||
|
sambanova_api_key: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Sambanova Cloud API key",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class SambaNovaSafetyConfig(BaseModel):
|
||||||
|
url: str = Field(
|
||||||
|
default="https://api.sambanova.ai/v1",
|
||||||
|
description="The URL for the SambaNova AI server",
|
||||||
|
)
|
||||||
|
api_key: SecretStr | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="The SambaNova cloud API Key",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"url": "https://api.sambanova.ai/v1",
|
||||||
|
"api_key": api_key,
|
||||||
|
}
|
100
llama_stack/providers/remote/safety/sambanova/sambanova.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import Message
|
||||||
|
from llama_stack.apis.safety import (
|
||||||
|
RunShieldResponse,
|
||||||
|
Safety,
|
||||||
|
SafetyViolation,
|
||||||
|
ViolationLevel,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.shields import Shield
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
|
||||||
|
|
||||||
|
from .config import SambaNovaSafetyConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||||
|
|
||||||
|
|
||||||
|
class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
|
||||||
|
def __init__(self, config: SambaNovaSafetyConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_api_key(self) -> str:
|
||||||
|
config_api_key = self.config.api_key if self.config.api_key else None
|
||||||
|
if config_api_key:
|
||||||
|
return config_api_key.get_secret_value()
|
||||||
|
else:
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.sambanova_api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
|
||||||
|
)
|
||||||
|
return provider_data.sambanova_api_key
|
||||||
|
|
||||||
|
async def register_shield(self, shield: Shield) -> None:
|
||||||
|
list_models_url = self.config.url + "/models"
|
||||||
|
try:
|
||||||
|
response = requests.get(list_models_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
||||||
|
available_models = [model.get("id") for model in response.json().get("data", {})]
|
||||||
|
if (
|
||||||
|
len(available_models) == 0
|
||||||
|
or "guard" not in shield.provider_resource_id.lower()
|
||||||
|
or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
|
||||||
|
):
|
||||||
|
raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")
|
||||||
|
|
||||||
|
async def run_shield(
|
||||||
|
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||||
|
) -> RunShieldResponse:
|
||||||
|
shield = await self.shield_store.get_shield(shield_id)
|
||||||
|
if not shield:
|
||||||
|
raise ValueError(f"Shield {shield_id} not found")
|
||||||
|
|
||||||
|
shield_params = shield.params
|
||||||
|
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||||
|
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
|
||||||
|
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||||
|
|
||||||
|
response = litellm.completion(
|
||||||
|
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
|
||||||
|
)
|
||||||
|
shield_message = response.choices[0].message.content
|
||||||
|
|
||||||
|
if "unsafe" in shield_message.lower():
|
||||||
|
user_message = CANNED_RESPONSE_TEXT
|
||||||
|
violation_type = shield_message.split("\n")[-1]
|
||||||
|
metadata = {"violation_type": violation_type}
|
||||||
|
|
||||||
|
return RunShieldResponse(
|
||||||
|
violation=SafetyViolation(
|
||||||
|
user_message=user_message,
|
||||||
|
violation_level=ViolationLevel.ERROR,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return RunShieldResponse()
|
|
@ -12,19 +12,19 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import BingSearchToolConfig
|
from .config import BingSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: BingSearchToolConfig):
|
def __init__(self, config: BingSearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
||||||
|
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -11,30 +11,30 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import BraveSearchToolConfig
|
from .config import BraveSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: BraveSearchToolConfig):
|
def __init__(self, config: BraveSearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -4,18 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from .config import MCPProviderConfig
|
||||||
|
|
||||||
from .config import ModelContextProtocolConfig
|
|
||||||
|
|
||||||
|
|
||||||
class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
async def get_adapter_impl(config: MCPProviderConfig, _deps):
|
||||||
api_key: str
|
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
|
|
||||||
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||||
|
|
||||||
impl = ModelContextProtocolToolRuntimeImpl(config)
|
impl = ModelContextProtocolToolRuntimeImpl(config, _deps)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -9,7 +9,12 @@ from typing import Any
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class ModelContextProtocolConfig(BaseModel):
|
class MCPProviderDataValidator(BaseModel):
|
||||||
|
# mcp_endpoint => dict of headers to send
|
||||||
|
mcp_headers: dict[str, dict[str, str]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MCPProviderConfig(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||||
return {}
|
return {}
|
||||||
|
|
|
@ -7,61 +7,45 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from mcp import ClientSession
|
|
||||||
from mcp.client.sse import sse_client
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
ToolDef,
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
|
||||||
|
|
||||||
from .config import ModelContextProtocolConfig
|
from .config import MCPProviderConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__, category="tools")
|
||||||
|
|
||||||
|
|
||||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: ModelContextProtocolConfig):
|
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolDefsResponse:
|
||||||
|
# this endpoint should be retrieved by getting the tool group right?
|
||||||
if mcp_endpoint is None:
|
if mcp_endpoint is None:
|
||||||
raise ValueError("mcp_endpoint is required")
|
raise ValueError("mcp_endpoint is required")
|
||||||
|
headers = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||||
tools = []
|
return await list_mcp_tools(mcp_endpoint.uri, headers)
|
||||||
async with sse_client(mcp_endpoint.uri) as streams:
|
|
||||||
async with ClientSession(*streams) as session:
|
|
||||||
await session.initialize()
|
|
||||||
tools_result = await session.list_tools()
|
|
||||||
for tool in tools_result.tools:
|
|
||||||
parameters = []
|
|
||||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
|
||||||
parameters.append(
|
|
||||||
ToolParameter(
|
|
||||||
name=param_name,
|
|
||||||
parameter_type=param_schema.get("type", "string"),
|
|
||||||
description=param_schema.get("description", ""),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tools.append(
|
|
||||||
ToolDef(
|
|
||||||
name=tool.name,
|
|
||||||
description=tool.description,
|
|
||||||
parameters=parameters,
|
|
||||||
metadata={
|
|
||||||
"endpoint": mcp_endpoint.uri,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return ListToolDefsResponse(data=tools)
|
|
||||||
|
|
||||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||||
tool = await self.tool_store.get_tool(tool_name)
|
tool = await self.tool_store.get_tool(tool_name)
|
||||||
|
@ -71,12 +55,19 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
if urlparse(endpoint).scheme not in ("http", "https"):
|
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||||
|
|
||||||
async with sse_client(endpoint) as streams:
|
headers = await self.get_headers_from_request(endpoint)
|
||||||
async with ClientSession(*streams) as session:
|
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
|
||||||
await session.initialize()
|
|
||||||
result = await session.call_tool(tool.identifier, kwargs)
|
|
||||||
|
|
||||||
return ToolInvocationResult(
|
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
|
||||||
content="\n".join([result.model_dump_json() for result in result.content]),
|
def canonicalize_uri(uri: str) -> str:
|
||||||
error_code=1 if result.isError else 0,
|
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
|
||||||
)
|
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data and provider_data.mcp_headers:
|
||||||
|
for uri, values in provider_data.mcp_headers.items():
|
||||||
|
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||||
|
continue
|
||||||
|
headers.update(values)
|
||||||
|
return headers
|
||||||
|
|
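Note: the MCP runtime above now pulls per-endpoint headers from request provider data (mcp_headers, keyed by the MCP endpoint URI) before listing or invoking tools. A hedged sketch of what a caller could send; the X-LlamaStack-Provider-Data header name is the provider-data convention referenced elsewhere in this diff, and the endpoint URL and token are placeholders.

# Sketch: building the provider-data header that carries per-endpoint MCP
# headers, matching the mcp_headers shape (endpoint URI -> header dict) above.
# The endpoint URL and bearer token are placeholders.
import json

provider_data = {
    "mcp_headers": {
        "http://localhost:8000/sse": {
            "Authorization": "Bearer <token>",
        },
    },
}

request_headers = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
print(request_headers)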
|
@ -12,29 +12,29 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import TavilySearchToolConfig
|
from .config import TavilySearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: TavilySearchToolConfig):
|
def __init__(self, config: TavilySearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -12,19 +12,19 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import WolframAlphaToolConfig
|
from .config import WolframAlphaToolConfig
|
||||||
|
|
||||||
|
|
||||||
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: WolframAlphaToolConfig):
|
def __init__(self, config: WolframAlphaToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.url = "https://api.wolframalpha.com/v2/query"
|
self.url = "https://api.wolframalpha.com/v2/query"
|
||||||
|
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -84,6 +84,14 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
await maybe_await(self.client.delete_collection(self.collection.name))
|
await maybe_await(self.client.delete_collection(self.collection.name))
|
||||||
|
|
||||||
|
async def query_keyword(
|
||||||
|
self,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Keyword search is not supported in Chroma")
|
||||||
|
|
||||||
|
|
||||||
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -73,7 +73,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
search_res = await asyncio.to_thread(
|
search_res = await asyncio.to_thread(
|
||||||
self.client.search,
|
self.client.search,
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
|
@ -86,6 +86,14 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
scores = [res["distance"] for res in search_res[0]]
|
scores = [res["distance"] for res in search_res[0]]
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def query_keyword(
|
||||||
|
self,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Keyword search is not supported in Milvus")
|
||||||
|
|
||||||
|
|
||||||
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
|
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"""
|
f"""
|
||||||
|
@ -120,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def query_keyword(
|
||||||
|
self,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Keyword search is not supported in PGVector")
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||||
|
|
|
@ -68,7 +68,7 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
|
|
||||||
await self.client.upsert(collection_name=self.collection_name, points=points)
|
await self.client.upsert(collection_name=self.collection_name, points=points)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
results = (
|
results = (
|
||||||
await self.client.query_points(
|
await self.client.query_points(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
|
@ -95,6 +95,14 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def query_keyword(
|
||||||
|
self,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
await self.client.delete_collection(collection_name=self.collection_name)
|
await self.client.delete_collection(collection_name=self.collection_name)
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
# TODO: make this async friendly
|
# TODO: make this async friendly
|
||||||
collection.data.insert_many(data_objects)
|
collection.data.insert_many(data_objects)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
collection = self.client.collections.get(self.collection_name)
|
collection = self.client.collections.get(self.collection_name)
|
||||||
|
|
||||||
results = collection.query.near_vector(
|
results = collection.query.near_vector(
|
||||||
|
@ -84,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
collection = self.client.collections.get(self.collection_name)
|
collection = self.client.collections.get(self.collection_name)
|
||||||
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
|
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
|
||||||
|
|
||||||
|
async def query_keyword(
|
||||||
|
self,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Keyword search is not supported in Weaviate")
|
||||||
|
|
||||||
|
|
||||||
class WeaviateVectorIOAdapter(
|
class WeaviateVectorIOAdapter(
|
||||||
VectorIO,
|
VectorIO,
|
||||||
|
|
Some files were not shown because too many files have changed in this diff.