Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)

Commit b7f16ac7a6: Merge branch 'main' into feat/litellm_sambanova_usage
535 changed files with 23539 additions and 8112 deletions
.coveragerc (new file, 6 lines)

[run]
omit =
    */tests/*
    */llama_stack/providers/*
    */llama_stack/templates/*
    .venv/*
.github/workflows/install-script-ci.yml (new file, 26 lines)

name: Installer CI

on:
  pull_request:
    paths:
      - 'install.sh'
  push:
    paths:
      - 'install.sh'
  schedule:
    - cron: '0 2 * * *'  # every day at 02:00 UTC

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
      - name: Run ShellCheck on install.sh
        run: shellcheck install.sh
  smoke-test:
    needs: lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
      - name: Run installer end-to-end
        run: ./install.sh
.github/workflows/integration-auth-tests.yml (new file, 136 lines)

name: Integration Auth Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
    paths:
      - 'distributions/**'
      - 'llama_stack/**'
      - 'tests/integration/**'
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
      - '.github/workflows/integration-auth-tests.yml' # This workflow

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test-matrix:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        auth-provider: [kubernetes]
      fail-fast: false # we want to run all tests regardless of failure

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          python-version: "3.10"

      - name: Set Up Environment and Install Dependencies
        run: |
          uv sync --extra dev --extra test
          uv pip install -e .
          llama stack build --template ollama --image-type venv

      - name: Install minikube
        if: ${{ matrix.auth-provider == 'kubernetes' }}
        uses: medyagh/setup-minikube@latest

      - name: Start minikube
        if: ${{ matrix.auth-provider == 'kubernetes' }}
        run: |
          minikube start
          kubectl get pods -A

      - name: Configure Kube Auth
        if: ${{ matrix.auth-provider == 'kubernetes' }}
        run: |
          kubectl create namespace llama-stack
          kubectl create serviceaccount llama-stack-auth -n llama-stack
          kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
          kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token

      - name: Set Kubernetes Config
        if: ${{ matrix.auth-provider == 'kubernetes' }}
        run: |
          echo "KUBERNETES_API_SERVER_URL=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.server}')" >> $GITHUB_ENV
          echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV

      - name: Set Kube Auth Config and run server
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        if: ${{ matrix.auth-provider == 'kubernetes' }}
        run: |
          run_dir=$(mktemp -d)
          cat <<'EOF' > $run_dir/run.yaml
          version: '2'
          image_name: kube
          apis:
          - agents
          - datasetio
          - eval
          - inference
          - safety
          - scoring
          - telemetry
          - tool_runtime
          - vector_io
          providers:
            agents:
            - provider_id: meta-reference
              provider_type: inline::meta-reference
              config:
                persistence_store:
                  type: sqlite
                  namespace: null
                  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
            telemetry:
            - provider_id: meta-reference
              provider_type: inline::meta-reference
              config:
                service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
                sinks: ${env.TELEMETRY_SINKS:console,sqlite}
                sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
          server:
            port: 8321
          EOF
          yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
          yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml
          cat $run_dir/run.yaml

          source .venv/bin/activate
          nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
          echo "Waiting for Llama Stack server..."
          for i in {1..30}; do
            if curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://localhost:8321/v1/health | grep -q "OK"; then
              echo "Llama Stack server is up!"
              if grep -q "Enabling authentication with provider: ${{ matrix.auth-provider }}" server.log; then
                echo "Llama Stack server is configured to use ${{ matrix.auth-provider }} auth"
                exit 0
              else
                echo "Llama Stack server is not configured to use ${{ matrix.auth-provider }} auth"
                cat server.log
                exit 1
              fi
            fi
            sleep 1
          done
          echo "Llama Stack server failed to start"
          cat server.log
          exit 1

      - name: Test auth
        run: |
          curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers | jq
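The job above authenticates every request with the Kubernetes service-account token it writes to `llama-stack-auth-token`. A minimal client-side sketch of the same checks, using the `requests` package (the token path, host, and port mirror the workflow and are otherwise assumptions):

```python
import requests

# Read the service-account token created by `kubectl create token ...`.
with open("llama-stack-auth-token", encoding="utf-8") as f:
    token = f.read().strip()

headers = {"Authorization": f"Bearer {token}"}

# Health check, then list providers: the same endpoints the workflow curls.
health = requests.get("http://127.0.0.1:8321/v1/health", headers=headers, timeout=10)
health.raise_for_status()

providers = requests.get("http://127.0.0.1:8321/v1/providers", headers=headers, timeout=10)
providers.raise_for_status()
print(providers.json())
```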
.github/workflows/integration-tests.yml (25 lines changed)

@@ -6,7 +6,6 @@ on:
   pull_request:
     branches: [ main ]
     paths:
-      - 'distributions/**'
       - 'llama_stack/**'
       - 'tests/integration/**'
       - 'uv.lock'

@@ -34,19 +33,24 @@ jobs:
       uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

     - name: Install uv
-      uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+      uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
       with:
         python-version: "3.10"
+        activate-environment: true

     - name: Install and start Ollama
       run: |
         # the ollama installer also starts the ollama service
         curl -fsSL https://ollama.com/install.sh | sh

-    - name: Pull Ollama image
+    # Do NOT cache models - pulling the cache is actually slower than just pulling the model.
+    # It takes ~45 seconds to pull the models from the cache and unpack it, but only 30 seconds to
+    # pull them directly.
+    # Maybe this is because the cache is being pulled at the same time by all the matrix jobs?
+    - name: Pull Ollama models (instruct and embed)
       run: |
-        # TODO: cache the model. OLLAMA_MODELS defaults to ~ollama/.ollama/models.
         ollama pull llama3.2:3b-instruct-fp16
+        ollama pull all-minilm:latest

     - name: Set Up Environment and Install Dependencies
       run: |

@@ -106,3 +110,16 @@ jobs:
         -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
         --text-model="meta-llama/Llama-3.2-3B-Instruct" \
         --embedding-model=all-MiniLM-L6-v2
+
+    - name: Write ollama logs to file
+      run: |
+        sudo journalctl -u ollama.service > ollama.log
+
+    - name: Upload all logs to artifacts
+      if: always()
+      uses: actions/upload-artifact@v4
+      with:
+        name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}
+        path: |
+          *.log
+        retention-days: 1
.github/workflows/pre-commit.yml (4 lines changed)

@@ -18,7 +18,7 @@ jobs:
       uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

     - name: Set up Python
-      uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+      uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
       with:
         python-version: '3.11'
         cache: pip

@@ -27,6 +27,8 @@ jobs:
           .pre-commit-config.yaml

     - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
+      env:
+        SKIP: no-commit-to-branch

     - name: Verify if there are any diff files after pre-commit
       run: |
.github/workflows/providers-build.yml (121 lines changed)

@@ -51,12 +51,12 @@ jobs:
       uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

     - name: Set up Python
-      uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+      uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
       with:
         python-version: '3.10'

     - name: Install uv
-      uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+      uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
       with:
         python-version: "3.10"

@@ -81,3 +81,120 @@ jobs:
       run: |
         source test/bin/activate
         uv pip list
+
+  build-single-provider:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .
+
+      - name: Build a single provider
+        run: |
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama
+
+  build-custom-container-distribution:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .
+
+      - name: Build a single provider
+        run: |
+          yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml
+          yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml
+
+      - name: Inspect the container image entrypoint
+        run: |
+          IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
+          echo "Entrypoint: $entrypoint"
+          if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then
+            echo "Entrypoint is not correct"
+            exit 1
+          fi
+
+  build-ubi9-container-distribution:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .
+
+      - name: Pin template to UBI9 base
+        run: |
+          yq -i '
+            .image_type = "container" |
+            .image_name = "ubi9-test" |
+            .distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest"
+          ' llama_stack/templates/dev/build.yaml
+
+      - name: Build dev container (UBI9)
+        env:
+          USE_COPY_NOT_MOUNT: "true"
+          LLAMA_STACK_DIR: "."
+        run: |
+          uv run llama stack build --config llama_stack/templates/dev/build.yaml
+
+      - name: Inspect UBI9 image
+        run: |
+          IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
+          echo "Entrypoint: $entrypoint"
+          if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then
+            echo "Entrypoint is not correct"
+            exit 1
+          fi
+
+          echo "Checking /etc/os-release in $IMAGE_ID"
+          docker run --rm --entrypoint sh "$IMAGE_ID" -c \
+            'source /etc/os-release && echo "$ID"' \
+            | grep -qE '^(rhel|ubi)$' \
+            || { echo "Base image is not UBI 9!"; exit 1; }
.github/workflows/test-external-providers.yml (79 lines changed)

@@ -5,89 +5,74 @@ on:
     branches: [ main ]
   pull_request:
     branches: [ main ]
+    paths:
+      - 'llama_stack/**'
+      - 'tests/integration/**'
+      - 'uv.lock'
+      - 'pyproject.toml'
+      - 'requirements.txt'
+      - '.github/workflows/test-external-providers.yml' # This workflow

 jobs:
   test-external-providers:
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        image-type: [venv]
+        # We don't do container yet, it's tricky to install a package from the host into the
+        # container and point 'uv pip install' to the correct path...
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

       - name: Install uv
-        uses: astral-sh/setup-uv@v5
+        uses: astral-sh/setup-uv@v6
         with:
          python-version: "3.10"

-      - name: Install Ollama
-        run: |
-          curl -fsSL https://ollama.com/install.sh | sh
-
-      - name: Pull Ollama image
-        run: |
-          ollama pull llama3.2:3b-instruct-fp16
-
-      - name: Start Ollama in background
-        run: |
-          nohup ollama run llama3.2:3b-instruct-fp16 --keepalive=30m > ollama.log 2>&1 &
-
       - name: Set Up Environment and Install Dependencies
         run: |
           uv sync --extra dev --extra test
           uv pip install -e .

-      - name: Install Ollama custom provider
+      - name: Apply image type to config file
+        run: |
+          yq -i '.image_type = "${{ matrix.image-type }}"' tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
+          cat tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
+
+      - name: Setup directory for Ollama custom provider
         run: |
           mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
           cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
-          uv pip install tests/external-provider/llama-stack-provider-ollama

       - name: Create provider configuration
         run: |
           mkdir -p /tmp/providers.d/remote/inference
           cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml

-      - name: Wait for Ollama to start
+      - name: Build distro from config file
         run: |
-          echo "Waiting for Ollama..."
-          for i in {1..30}; do
-            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
-              echo "Ollama is running!"
-              exit 0
-            fi
-            sleep 1
-          done
-          echo "Ollama failed to start"
-          ollama ps
-          ollama.log
-          exit 1
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml

       - name: Start Llama Stack server in background
+        if: ${{ matrix.image-type }} == 'venv'
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
-          source .venv/bin/activate
-          nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &
+          source ci-test/bin/activate
+          uv run pip list
+          nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &

       - name: Wait for Llama Stack server to be ready
         run: |
-          echo "Waiting for Llama Stack server..."
           for i in {1..30}; do
-            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
-              echo "Llama Stack server is up!"
-              if grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
-                echo "Llama Stack server is using custom Ollama provider"
-                exit 0
-              else
-                echo "Llama Stack server is not using custom Ollama provider"
-                exit 1
-              fi
+            if ! grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+              echo "Waiting for Llama Stack server to load the provider..."
+              sleep 1
+            else
+              echo "Provider loaded"
+              exit 0
             fi
-            sleep 1
           done
-          echo "Llama Stack server failed to start"
-          cat server.log
+          echo "Provider failed to load"
           exit 1
-
-      - name: run inference tests
-        run: |
-          uv run pytest -v tests/integration/inference/test_text_inference.py --stack-config="http://localhost:8321" --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
.github/workflows/unit-tests.yml (5 lines changed)

@@ -6,7 +6,6 @@ on:
   pull_request:
     branches: [ main ]
     paths:
-      - 'distributions/**'
       - 'llama_stack/**'
      - 'tests/unit/**'
       - 'uv.lock'

@@ -34,11 +33,11 @@ jobs:
     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

     - name: Set up Python ${{ matrix.python }}
-      uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+      uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
       with:
         python-version: ${{ matrix.python }}

-    - uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+    - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
       with:
         python-version: ${{ matrix.python }}
         enable-cache: false
.github/workflows/update-readthedocs.yml (4 lines changed)

@@ -36,12 +36,12 @@ jobs:
       uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

     - name: Set up Python
-      uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+      uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
       with:
         python-version: '3.11'

     - name: Install the latest version of uv
-      uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+      uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0

     - name: Sync with uv
       run: uv sync --extra docs
.pre-commit-config.yaml (12 lines added)

@@ -15,6 +15,18 @@ repos:
       args: ['--maxkb=1000']
     - id: end-of-file-fixer
       exclude: '^(.*\.svg)$'
+    - id: no-commit-to-branch
+    - id: check-yaml
+      args: ["--unsafe"]
+    - id: detect-private-key
+    - id: requirements-txt-fixer
+    - id: mixed-line-ending
+      args: [--fix=lf] # Forces to replace line ending by LF (line feed)
+    - id: check-executables-have-shebangs
+    - id: check-json
+    - id: check-shebang-scripts-are-executable
+    - id: check-symlinks
+    - id: check-toml

   - repo: https://github.com/Lucas-C/pre-commit-hooks
     rev: v1.5.4
CHANGELOG.md (28 lines added)

@@ -1,5 +1,33 @@
 # Changelog

+# v0.2.3
+Published on: 2025-04-25T22:46:21Z
+
+## Highlights
+
+* OpenAI compatible inference endpoints and client-SDK support. `client.chat.completions.create()` now works.
+* significant improvements and functionality added to the nVIDIA distribution
+* many improvements to the test verification suite.
+* new inference providers: Ramalama, IBM WatsonX
+* many improvements to the Playground UI
+
+---
+
+# v0.2.2
+Published on: 2025-04-13T01:19:49Z
+
+## Main changes
+
+- Bring Your Own Provider (@leseb) - use out-of-tree provider code to execute the distribution server
+- OpenAI compatible inference API in progress (@bbrowning)
+- Provider verifications (@ehhuang)
+- Many updates and fixes to playground
+- Several llama4 related fixes
+
+---
+
 # v0.2.1
 Published on: 2025-04-05T23:13:00Z
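The v0.2.3 highlight about OpenAI-compatible endpoints means a stock OpenAI client can talk to a running Llama Stack server. A hedged sketch (the base URL assumes a local server on port 8321 exposing the `/v1/openai/v1` prefix that appears elsewhere in this diff; the model name is only an example):

```python
from openai import OpenAI

# Point the official OpenAI client at the Llama Stack OpenAI-compatible prefix.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(completion.choices[0].message.content)
```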
CONTRIBUTING.md

@@ -141,11 +141,18 @@ uv sync

 ## Coding Style

-* Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter, same goes for docstrings.
-* Prefer comments to clarify surprising behavior and/or relationships between parts of the code rather than explain what the next line of code does.
-* Catching exceptions, prefer using a specific exception type rather than a broad catch-all like `Exception`.
+* Comments should provide meaningful insights into the code. Avoid filler comments that simply
+  describe the next step, as they create unnecessary clutter, same goes for docstrings.
+* Prefer comments to clarify surprising behavior and/or relationships between parts of the code
+  rather than explain what the next line of code does.
+* Catching exceptions, prefer using a specific exception type rather than a broad catch-all like
+  `Exception`.
 * Error messages should be prefixed with "Failed to ..."
-* 4 spaces for indentation rather than tabs
+* 4 spaces for indentation rather than tab
+* When using `# noqa` to suppress a style or linter warning, include a comment explaining the
+  justification for bypassing the check.
+* When using `# type: ignore` to suppress a mypy warning, include a comment explaining the
+  justification for bypassing the check.

 ## Common Tasks
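The two new bullets ask that every suppression carry its own justification. A small, purely illustrative Python sketch of that style (the code itself is hypothetical, not from the repository):

```python
import re

import yaml  # type: ignore[import-untyped]  # justification: PyYAML has no bundled type stubs in this environment

# Kept on one line because splitting the pattern makes it harder to review.
URL_RE = re.compile(r"https?://[\w.-]+(?::\d+)?(?:/[\w./%-]*)?(?:\?[\w=&%-]*)?")  # noqa: E501  # justification: single-line regex is easier to audit


def load_config(path: str) -> dict:
    """Read a YAML config file into a plain dict."""
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


def first_url(text: str) -> str | None:
    """Return the first URL found in text, if any."""
    match = URL_RE.search(text)
    return match.group(0) if match else None
```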
README.md

@@ -70,6 +70,13 @@ As more providers start supporting Llama 4, you can use them in Llama Stack as w

 </details>

+### 🚀 One-Line Installer 🚀
+
+To try Llama Stack locally, run:
+
+```bash
+curl -LsSf https://github.com/meta-llama/llama-stack/raw/main/install.sh | sh
+```
+
 ### Overview

@@ -119,6 +126,7 @@ Here is a list of the various API providers and available distributions that can
 | OpenAI | Hosted | | ✅ | | | |
 | Anthropic | Hosted | | ✅ | | | |
 | Gemini | Hosted | | ✅ | | | |
+| watsonx | Hosted | | ✅ | | | |

 ### Distributions

@@ -128,7 +136,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
 | **Distribution** | **Llama Stack Docker** | Start This Distribution |
 |:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
 | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
-| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
 | SambaNova | [llamastack/distribution-sambanova](https://hub.docker.com/repository/docker/llamastack/distribution-sambanova/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/sambanova.html) |
 | Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
 | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
docs/_static/css/my_theme.css (6 lines added)

@@ -27,3 +27,9 @@ pre {
   white-space: pre-wrap !important;
   word-break: break-all;
 }
+
+[data-theme="dark"] .mermaid {
+  background-color: #f4f4f6 !important;
+  border-radius: 6px;
+  padding: 0.5em;
+}
docs/_static/js/detect_theme.js (29 lines changed)

@@ -1,9 +1,32 @@
 document.addEventListener("DOMContentLoaded", function () {
   const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches;
   const htmlElement = document.documentElement;
-  if (prefersDark) {
-    htmlElement.setAttribute("data-theme", "dark");
+
+  // Check if theme is saved in localStorage
+  const savedTheme = localStorage.getItem("sphinx-rtd-theme");
+
+  if (savedTheme) {
+    // Use the saved theme preference
+    htmlElement.setAttribute("data-theme", savedTheme);
+    document.body.classList.toggle("dark", savedTheme === "dark");
   } else {
-    htmlElement.setAttribute("data-theme", "light");
+    // Fall back to system preference
+    const theme = prefersDark ? "dark" : "light";
+    htmlElement.setAttribute("data-theme", theme);
+    document.body.classList.toggle("dark", theme === "dark");
+    // Save initial preference
+    localStorage.setItem("sphinx-rtd-theme", theme);
   }
+
+  // Listen for theme changes from the existing toggle
+  const observer = new MutationObserver(function(mutations) {
+    mutations.forEach(function(mutation) {
+      if (mutation.attributeName === "data-theme") {
+        const currentTheme = htmlElement.getAttribute("data-theme");
+        localStorage.setItem("sphinx-rtd-theme", currentTheme);
+      }
+    });
+  });
+
+  observer.observe(htmlElement, { attributes: true });
 });

docs/_static/llama-stack-spec.html (537 lines changed, vendored)

@@ -497,6 +497,54 @@ adds the POST /v1/openai/v1/responses path:

"/v1/openai/v1/responses": {
  "post": {
    "responses": {
      "200": {
        "description": "Runtime representation of an annotated type.",
        "content": {
          "application/json": { "schema": { "$ref": "#/components/schemas/OpenAIResponseObject" } },
          "text/event-stream": { "schema": { "$ref": "#/components/schemas/OpenAIResponseObjectStream" } }
        }
      },
      "400": { "$ref": "#/components/responses/BadRequest400" },
      "429": { "$ref": "#/components/responses/TooManyRequests429" },
      "500": { "$ref": "#/components/responses/InternalServerError500" },
      "default": { "$ref": "#/components/responses/DefaultError" }
    },
    "tags": ["Agents"],
    "description": "Create a new OpenAI response.",
    "parameters": [],
    "requestBody": {
      "content": { "application/json": { "schema": { "$ref": "#/components/schemas/CreateOpenaiResponseRequest" } } },
      "required": true
    }
  }
},

@@ -1278,6 +1326,49 @@ adds the GET /v1/openai/v1/responses/{id} path:

"/v1/openai/v1/responses/{id}": {
  "get": {
    "responses": {
      "200": {
        "description": "An OpenAIResponseObject.",
        "content": { "application/json": { "schema": { "$ref": "#/components/schemas/OpenAIResponseObject" } } }
      },
      "400": { "$ref": "#/components/responses/BadRequest400" },
      "429": { "$ref": "#/components/responses/TooManyRequests429" },
      "500": { "$ref": "#/components/responses/InternalServerError500" },
      "default": { "$ref": "#/components/responses/DefaultError" }
    },
    "tags": ["Agents"],
    "description": "Retrieve an OpenAI response by its ID.",
    "parameters": [
      {
        "name": "id",
        "in": "path",
        "description": "The ID of the OpenAI response to retrieve.",
        "required": true,
        "schema": { "type": "string" }
      }
    ]
  }
},

@@ -5221,17 +5312,25 @@ and @@ -5239,7 +5338,8 @@ document AgentConfig: "model" gains the description "The model identifier to use for the agent", "instructions" gains "The system instructions for the agent", a new optional "name" property (string, "Optional name for the agent, used in telemetry and identification") is added, "enable_session_persistence" gains "Optional flag indicating whether session data has to be persisted", "response_format" gains "Optional response format configuration", and the schema itself gains the description "Configuration for an agent.".

@@ -6183,6 +6283,430 @@ adds the OpenAI Responses schemas:

"OpenAIResponseInputMessage": {
  "type": "object",
  "properties": {
    "content": {
      "oneOf": [
        { "type": "string" },
        { "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseInputMessageContent" } }
      ]
    },
    "role": {
      "oneOf": [
        { "type": "string", "const": "system" },
        { "type": "string", "const": "developer" },
        { "type": "string", "const": "user" },
        { "type": "string", "const": "assistant" }
      ]
    },
    "type": { "type": "string", "const": "message", "default": "message" }
  },
  "additionalProperties": false,
  "required": ["content", "role"],
  "title": "OpenAIResponseInputMessage"
},
"OpenAIResponseInputMessageContent": {
  "oneOf": [
    { "$ref": "#/components/schemas/OpenAIResponseInputMessageContentText" },
    { "$ref": "#/components/schemas/OpenAIResponseInputMessageContentImage" }
  ],
  "discriminator": {
    "propertyName": "type",
    "mapping": {
      "input_text": "#/components/schemas/OpenAIResponseInputMessageContentText",
      "input_image": "#/components/schemas/OpenAIResponseInputMessageContentImage"
    }
  }
},
"OpenAIResponseInputMessageContentImage": {
  "type": "object",
  "properties": {
    "detail": {
      "oneOf": [
        { "type": "string", "const": "low" },
        { "type": "string", "const": "high" },
        { "type": "string", "const": "auto" }
      ],
      "default": "auto"
    },
    "type": { "type": "string", "const": "input_image", "default": "input_image" },
    "image_url": { "type": "string" }
  },
  "additionalProperties": false,
  "required": ["detail", "type"],
  "title": "OpenAIResponseInputMessageContentImage"
},
"OpenAIResponseInputMessageContentText": {
  "type": "object",
  "properties": {
    "text": { "type": "string" },
    "type": { "type": "string", "const": "input_text", "default": "input_text" }
  },
  "additionalProperties": false,
  "required": ["text", "type"],
  "title": "OpenAIResponseInputMessageContentText"
},
"OpenAIResponseInputTool": {
  "type": "object",
  "properties": {
    "type": {
      "oneOf": [
        { "type": "string", "const": "web_search" },
        { "type": "string", "const": "web_search_preview_2025_03_11" }
      ],
      "default": "web_search"
    },
    "search_context_size": { "type": "string", "default": "medium" }
  },
  "additionalProperties": false,
  "required": ["type"],
  "title": "OpenAIResponseInputToolWebSearch"
},
"CreateOpenaiResponseRequest": {
  "type": "object",
  "properties": {
    "input": {
      "oneOf": [
        { "type": "string" },
        { "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseInputMessage" } }
      ],
      "description": "Input message(s) to create the response."
    },
    "model": { "type": "string", "description": "The underlying LLM used for completions." },
    "previous_response_id": {
      "type": "string",
      "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
    },
    "store": { "type": "boolean" },
    "stream": { "type": "boolean" },
    "temperature": { "type": "number" },
    "tools": { "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseInputTool" } }
  },
  "additionalProperties": false,
  "required": ["input", "model"],
  "title": "CreateOpenaiResponseRequest"
},
"OpenAIResponseError": {
  "type": "object",
  "properties": {
    "code": { "type": "string" },
    "message": { "type": "string" }
  },
  "additionalProperties": false,
  "required": ["code", "message"],
  "title": "OpenAIResponseError"
},
"OpenAIResponseObject": {
  "type": "object",
  "properties": {
    "created_at": { "type": "integer" },
    "error": { "$ref": "#/components/schemas/OpenAIResponseError" },
    "id": { "type": "string" },
    "model": { "type": "string" },
    "object": { "type": "string", "const": "response", "default": "response" },
    "output": { "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseOutput" } },
    "parallel_tool_calls": { "type": "boolean", "default": false },
    "previous_response_id": { "type": "string" },
    "status": { "type": "string" },
    "temperature": { "type": "number" },
    "top_p": { "type": "number" },
    "truncation": { "type": "string" },
    "user": { "type": "string" }
  },
  "additionalProperties": false,
  "required": ["created_at", "id", "model", "object", "output", "parallel_tool_calls", "status"],
  "title": "OpenAIResponseObject"
},
"OpenAIResponseOutput": {
  "oneOf": [
    { "$ref": "#/components/schemas/OpenAIResponseOutputMessage" },
    { "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall" }
  ],
  "discriminator": {
    "propertyName": "type",
    "mapping": {
      "message": "#/components/schemas/OpenAIResponseOutputMessage",
      "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
    }
  }
},
"OpenAIResponseOutputMessage": {
  "type": "object",
  "properties": {
    "id": { "type": "string" },
    "content": { "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseOutputMessageContent" } },
    "role": { "type": "string", "const": "assistant", "default": "assistant" },
    "status": { "type": "string" },
    "type": { "type": "string", "const": "message", "default": "message" }
  },
  "additionalProperties": false,
  "required": ["id", "content", "role", "status", "type"],
  "title": "OpenAIResponseOutputMessage"
},
"OpenAIResponseOutputMessageContent": {
  "type": "object",
  "properties": {
    "text": { "type": "string" },
    "type": { "type": "string", "const": "output_text", "default": "output_text" }
  },
  "additionalProperties": false,
  "required": ["text", "type"],
  "title": "OpenAIResponseOutputMessageContentOutputText"
},
"OpenAIResponseOutputMessageWebSearchToolCall": {
  "type": "object",
  "properties": {
    "id": { "type": "string" },
    "status": { "type": "string" },
    "type": { "type": "string", "const": "web_search_call", "default": "web_search_call" }
  },
  "additionalProperties": false,
  "required": ["id", "status", "type"],
  "title": "OpenAIResponseOutputMessageWebSearchToolCall"
},
"OpenAIResponseObjectStream": {
  "oneOf": [
    { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated" },
    { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" }
  ],
  "discriminator": {
    "propertyName": "type",
    "mapping": {
      "response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated",
      "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
    }
  }
},
"OpenAIResponseObjectStreamResponseCompleted": {
  "type": "object",
  "properties": {
    "response": { "$ref": "#/components/schemas/OpenAIResponseObject" },
    "type": { "type": "string", "const": "response.completed", "default": "response.completed" }
  },
  "additionalProperties": false,
  "required": ["response", "type"],
  "title": "OpenAIResponseObjectStreamResponseCompleted"
},
"OpenAIResponseObjectStreamResponseCreated": {
  "type": "object",
  "properties": {
    "response": { "$ref": "#/components/schemas/OpenAIResponseObject" },
    "type": { "type": "string", "const": "response.created", "default": "response.created" }
  },
  "additionalProperties": false,
  "required": ["response", "type"],
  "title": "OpenAIResponseObjectStreamResponseCreated"
},

@@ -8891,8 +9415,7 @@ OpenAIAssistantMessageParam: "content" is dropped from the "required" list, leaving only "role".

docs/_static/llama-stack-spec.yaml (364 lines changed, vendored)

The YAML spec picks up the same additions. Key hunks:

@@ -330,6 +330,39 @@ paths: adds the POST /v1/openai/v1/responses path:

  /v1/openai/v1/responses:
    post:
      responses:
        '200':
          description: >-
            Runtime representation of an annotated type.
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/OpenAIResponseObject'
            text/event-stream:
              schema:
                $ref: '#/components/schemas/OpenAIResponseObjectStream'
        '400':
          $ref: '#/components/responses/BadRequest400'
        '429':
          $ref: '#/components/responses/TooManyRequests429'
        '500':
          $ref: '#/components/responses/InternalServerError500'
        default:
          $ref: '#/components/responses/DefaultError'
      tags:
        - Agents
      description: Create a new OpenAI response.
      parameters: []
      requestBody:
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/CreateOpenaiResponseRequest'
        required: true

@@ -875,6 +908,36 @@ paths: adds the GET /v1/openai/v1/responses/{id} path:

  /v1/openai/v1/responses/{id}:
    get:
      responses:
        '200':
          description: An OpenAIResponseObject.
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/OpenAIResponseObject'
        '400':
          $ref: '#/components/responses/BadRequest400'
        '429':
          $ref: '#/components/responses/TooManyRequests429'
        '500':
          $ref: '#/components/responses/InternalServerError500'
        default:
          $ref: '#/components/responses/DefaultError'
      tags:
        - Agents
      description: Retrieve an OpenAI response by its ID.
      parameters:
        - name: id
          in: path
          description: >-
            The ID of the OpenAI response to retrieve.
          required: true
          schema:
            type: string

@@ -3686,18 +3749,29 @@ components: the same AgentConfig description additions as in the JSON spec (model, instructions, the new optional name property, enable_session_persistence, response_format, and the schema-level description "Configuration for an agent.").

@@ -4318,6 +4392,295 @@ components: adds the OpenAI Responses schemas in YAML form:

    OpenAIResponseInputMessage:
      type: object
      properties:
        content:
          oneOf:
            - type: string
            - type: array
              items:
                $ref: '#/components/schemas/OpenAIResponseInputMessageContent'
        role:
          oneOf:
            - type: string
              const: system
            - type: string
              const: developer
            - type: string
              const: user
            - type: string
              const: assistant
        type:
          type: string
          const: message
          default: message
      additionalProperties: false
      required:
        - content
        - role
      title: OpenAIResponseInputMessage
    OpenAIResponseInputMessageContent:
      oneOf:
        - $ref: '#/components/schemas/OpenAIResponseInputMessageContentText'
        - $ref: '#/components/schemas/OpenAIResponseInputMessageContentImage'
      discriminator:
        propertyName: type
        mapping:
          input_text: '#/components/schemas/OpenAIResponseInputMessageContentText'
          input_image: '#/components/schemas/OpenAIResponseInputMessageContentImage'
    OpenAIResponseInputMessageContentImage:
      type: object
      properties:
        detail:
          oneOf:
            - type: string
              const: low
            - type: string
              const: high
            - type: string
              const: auto
          default: auto
        type:
          type: string
          const: input_image
          default: input_image
        image_url:
          type: string
      additionalProperties: false
      required:
        - detail
        - type
      title: OpenAIResponseInputMessageContentImage
    OpenAIResponseInputMessageContentText:
      type: object
      properties:
        text:
          type: string
        type:
          type: string
          const: input_text
          default: input_text
      additionalProperties: false
      required:
        - text
        - type
      title: OpenAIResponseInputMessageContentText
    OpenAIResponseInputTool:
      type: object
      properties:
        type:
          oneOf:
            - type: string
              const: web_search
            - type: string
              const: web_search_preview_2025_03_11
          default: web_search
        search_context_size:
          type: string
          default: medium
      additionalProperties: false
      required:
        - type
      title: OpenAIResponseInputToolWebSearch
    CreateOpenaiResponseRequest:
      type: object
      properties:
        input:
          oneOf:
            - type: string
            - type: array
              items:
                $ref: '#/components/schemas/OpenAIResponseInputMessage'
          description: Input message(s) to create the response.
        model:
          type: string
          description: The underlying LLM used for completions.
        previous_response_id:
          type: string
          description: >-
            (Optional) if specified, the new response will be a continuation of the
            previous response. This can be used to easily fork-off new responses from
            existing responses.
        store:
          type: boolean
        stream:
          type: boolean
        temperature:
          type: number
        tools:
          type: array
          items:
            $ref: '#/components/schemas/OpenAIResponseInputTool'
      additionalProperties: false
      required:
        - input
        - model
      title: CreateOpenaiResponseRequest
|
||||||
|
OpenAIResponseError:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
code:
|
||||||
|
type: string
|
||||||
|
message:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- code
|
||||||
|
- message
|
||||||
|
title: OpenAIResponseError
|
||||||
|
OpenAIResponseObject:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
created_at:
|
||||||
|
type: integer
|
||||||
|
error:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseError'
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: response
|
||||||
|
default: response
|
||||||
|
output:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseOutput'
|
||||||
|
parallel_tool_calls:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
previous_response_id:
|
||||||
|
type: string
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
temperature:
|
||||||
|
type: number
|
||||||
|
top_p:
|
||||||
|
type: number
|
||||||
|
truncation:
|
||||||
|
type: string
|
||||||
|
user:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- created_at
|
||||||
|
- id
|
||||||
|
- model
|
||||||
|
- object
|
||||||
|
- output
|
||||||
|
- parallel_tool_calls
|
||||||
|
- status
|
||||||
|
title: OpenAIResponseObject
|
||||||
|
OpenAIResponseOutput:
|
||||||
|
oneOf:
|
||||||
|
- $ref: '#/components/schemas/OpenAIResponseOutputMessage'
|
||||||
|
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
|
||||||
|
discriminator:
|
||||||
|
propertyName: type
|
||||||
|
mapping:
|
||||||
|
message: '#/components/schemas/OpenAIResponseOutputMessage'
|
||||||
|
web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
|
||||||
|
OpenAIResponseOutputMessage:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
content:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseOutputMessageContent'
|
||||||
|
role:
|
||||||
|
type: string
|
||||||
|
const: assistant
|
||||||
|
default: assistant
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: message
|
||||||
|
default: message
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- content
|
||||||
|
- role
|
||||||
|
- status
|
||||||
|
- type
|
||||||
|
title: OpenAIResponseOutputMessage
|
||||||
|
OpenAIResponseOutputMessageContent:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
text:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: output_text
|
||||||
|
default: output_text
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- text
|
||||||
|
- type
|
||||||
|
title: >-
|
||||||
|
OpenAIResponseOutputMessageContentOutputText
|
||||||
|
"OpenAIResponseOutputMessageWebSearchToolCall":
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: web_search_call
|
||||||
|
default: web_search_call
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- status
|
||||||
|
- type
|
||||||
|
title: >-
|
||||||
|
OpenAIResponseOutputMessageWebSearchToolCall
|
||||||
|
OpenAIResponseObjectStream:
|
||||||
|
oneOf:
|
||||||
|
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||||
|
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
||||||
|
discriminator:
|
||||||
|
propertyName: type
|
||||||
|
mapping:
|
||||||
|
response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||||
|
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
||||||
|
"OpenAIResponseObjectStreamResponseCompleted":
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
response:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseObject'
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: response.completed
|
||||||
|
default: response.completed
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- response
|
||||||
|
- type
|
||||||
|
title: >-
|
||||||
|
OpenAIResponseObjectStreamResponseCompleted
|
||||||
|
"OpenAIResponseObjectStreamResponseCreated":
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
response:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseObject'
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: response.created
|
||||||
|
default: response.created
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- response
|
||||||
|
- type
|
||||||
|
title: >-
|
||||||
|
OpenAIResponseObjectStreamResponseCreated
|
||||||
CreateUploadSessionRequest:
|
CreateUploadSessionRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -6097,7 +6460,6 @@ components:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- role
|
- role
|
||||||
- content
|
|
||||||
title: OpenAIAssistantMessageParam
|
title: OpenAIAssistantMessageParam
|
||||||
description: >-
|
description: >-
|
||||||
A message containing the model's (assistant) response in an OpenAI-compatible
|
A message containing the model's (assistant) response in an OpenAI-compatible
|
||||||
|
|
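Taken together, the schemas added above describe the new OpenAI-compatible Responses surface. The following is a minimal sketch of exercising it over HTTP. Assumptions: a Llama Stack server on `localhost:8321`, and that the create operation is served at `POST /v1/openai/v1/responses` (only the retrieve path appears in this hunk); the model id is illustrative.

```python
# Sketch only: payload mirrors CreateOpenaiResponseRequest (`input` and `model`
# are the required fields); the create path below is an assumption.
import requests

BASE_URL = "http://localhost:8321"

payload = {
    "model": "meta-llama/Llama-3.2-3B-Instruct",  # assumed model id
    "input": [
        {
            "type": "message",
            "role": "user",
            "content": "Summarize the latest Llama Stack release notes.",
        }
    ],
    "tools": [{"type": "web_search", "search_context_size": "medium"}],
    "stream": False,  # set True to receive OpenAIResponseObjectStream events instead
}

created = requests.post(f"{BASE_URL}/v1/openai/v1/responses", json=payload).json()
print(created["id"], created["status"])

# Retrieve the same response via the documented GET /v1/openai/v1/responses/{id}.
fetched = requests.get(f"{BASE_URL}/v1/openai/v1/responses/{created['id']}").json()
for item in fetched["output"]:
    if item["type"] == "message":
        # OpenAIResponseOutputMessage content items carry output_text entries.
        print(item["content"][0]["text"])
```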
docs/getting_started_llama_api.ipynb (new file, 907 lines)
File diff suppressed because one or more lines are too long
@ -6,6 +6,7 @@
 import hashlib
 import ipaddress
+import types
 import typing
 from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union
@ -179,7 +180,7 @@ class ContentBuilder:
         "Creates the content subtree for a request or response."

         def is_iterator_type(t):
-            return "StreamChunk" in str(t)
+            return "StreamChunk" in str(t) or "OpenAIResponseObjectStream" in str(t)

         def get_media_type(t):
             if is_generic_list(t):
@ -189,7 +190,7 @@ class ContentBuilder:
             else:
                 return "application/json"

-        if typing.get_origin(payload_type) is typing.Union:
+        if typing.get_origin(payload_type) in (typing.Union, types.UnionType):
             media_types = []
             item_types = []
             for x in typing.get_args(payload_type):
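The `types.UnionType` addition matters because unions written with PEP 604 syntax (`X | Y`) report a different origin than `typing.Union[X, Y]`. A short illustration of the check the patched generator now performs (plain Python, no extra dependencies assumed):

```python
import types
import typing

old_style = typing.Union[int, str]
new_style = int | str  # PEP 604, Python 3.10+

# The two spellings describe the same members but can expose different origins.
print(typing.get_origin(old_style))  # typing.Union
print(typing.get_origin(new_style))  # types.UnionType on 3.10-3.13

# Checking membership against both origins, as the generator now does,
# is true for either spelling:
print(typing.get_origin(new_style) in (typing.Union, types.UnionType))  # True
print(typing.get_args(old_style) == typing.get_args(new_style) == (int, str))  # True
```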
@ -1,16 +1,16 @@
-sphinx==8.1.3
-myst-parser
 linkify
+myst-parser
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-sphinx-rtd-theme>=1.0.0
-sphinx_autobuild
+sphinx==8.1.3
 sphinx-copybutton
 sphinx-design
 sphinx-pdj-theme
-sphinx_rtd_dark_mode
+sphinx-rtd-theme>=1.0.0
 sphinx-tabs
+sphinx_autobuild
+sphinx_rtd_dark_mode
+sphinxcontrib-mermaid
 sphinxcontrib-openapi
 sphinxcontrib-redoc
-sphinxcontrib-mermaid
 sphinxcontrib-video
 tomli
@ -68,7 +68,8 @@ chunks_response = client.vector_io.query(
 ### Using the RAG Tool

 A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
-and automatically chunks them into smaller pieces.
+and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
+[appendix](#more-ragdocument-examples).

 ```python
 from llama_stack_client import RAGDocument
@ -178,3 +179,38 @@ for vector_db_id in client.vector_dbs.list():
     print(f"Unregistering vector database: {vector_db_id.identifier}")
     client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
 ```
+
+### Appendix
+
+#### More RAGDocument Examples
+```python
+from llama_stack_client import RAGDocument
+import base64
+import requests
+
+RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"})
+RAGDocument(document_id="num-1", content="plain text")
+RAGDocument(
+    document_id="num-2",
+    content={
+        "type": "text",
+        "text": "plain text input",
+    },  # for inputs that should be treated as text explicitly
+)
+RAGDocument(
+    document_id="num-3",
+    content={
+        "type": "image",
+        "image": {"url": {"uri": "https://mywebsite.com/image.jpg"}},
+    },
+)
+B64_ENCODED_IMAGE = base64.b64encode(
+    requests.get(
+        "https://raw.githubusercontent.com/meta-llama/llama-stack/refs/heads/main/docs/_static/llama-stack.png"
+    ).content
+)
+RAGDocument(
+    document_id="num-4",
+    content={"type": "image", "image": {"data": B64_ENCODED_IMAGE}},
+)
+```
+For more strongly typed interaction use the typed dicts found [here](https://github.com/meta-llama/llama-stack-client-python/blob/38cd91c9e396f2be0bec1ee96a19771582ba6f17/src/llama_stack_client/types/shared_params/document.py).
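As a brief follow-on to the appendix above, this is a sketch of actually ingesting such documents through the RAG Tool. Assumptions: a Llama Stack server at `localhost:8321` and an already registered vector DB id (`"my_documents"` here is a placeholder); the insert call mirrors the RAG Tool usage described earlier in this guide.

```python
from llama_stack_client import LlamaStackClient, RAGDocument

client = LlamaStackClient(base_url="http://localhost:8321")

documents = [
    RAGDocument(document_id="num-1", content="plain text"),
    RAGDocument(
        document_id="num-2",
        content={"type": "text", "text": "plain text input"},
    ),
]

# Chunk and index the documents into the named vector DB (assumed to exist).
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id="my_documents",
    chunk_size_in_tokens=512,
)
```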
@ -41,30 +41,9 @@ client.toolgroups.register(

 The tool requires an API key which can be provided either in the configuration or through the request header `X-LlamaStack-Provider-Data`. The format of the header is `{"<provider_name>_api_key": <your api key>}`.

+> **NOTE:** When using Tavily Search and Bing Search, the inference output will still display "Brave Search." This is because Llama models have been trained with Brave Search as a built-in tool. Tavily and Bing are just being used in lieu of Brave Search.


-#### Code Interpreter
-
-The Code Interpreter allows execution of Python code within a controlled environment.
-
-```python
-# Register Code Interpreter tool group
-client.toolgroups.register(
-    toolgroup_id="builtin::code_interpreter", provider_id="code_interpreter"
-)
-```
-
-Features:
-- Secure execution environment using `bwrap` sandboxing
-- Matplotlib support for generating plots
-- Disabled dangerous system operations
-- Configurable execution timeouts
-
-> ⚠️ Important: The code interpreter tool can operate in a controlled environment locally or on Podman containers. To ensure proper functionality in containerized environments:
-> - The container requires privileged access (e.g., --privileged).
-> - Users without sufficient permissions may encounter permission errors. (`bwrap: Can't mount devpts on /newroot/dev/pts: Permission denied`)
-> - 🔒 Security Warning: Privileged mode grants elevated access and bypasses security restrictions. Use only in local, isolated, or controlled environments.
-
 #### WolframAlpha

 The WolframAlpha tool provides access to computational knowledge through the WolframAlpha API.
@ -102,7 +81,7 @@ Features:
 - Context retrieval with token limits


-> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and rag, that are provided by tavily-search, code-interpreter and rag providers.
+> **Note:** By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers.

 ## Model Context Protocol (MCP) Tools

@ -214,3 +193,69 @@ response = agent.create_turn(
     session_id=session_id,
 )
 ```
+
+## Simple Example 2: Using an Agent with the Web Search Tool
+1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
+2. [Optional] Provide the API key directly to the Llama Stack server
+```bash
+export TAVILY_SEARCH_API_KEY="your key"
+```
+```bash
+--env TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY}
+```
+3. Run the following script.
+```python
+from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client.types.agent_create_params import AgentConfig
+from llama_stack_client.lib.agents.event_logger import EventLogger
+from llama_stack_client import LlamaStackClient
+
+client = LlamaStackClient(
+    base_url="http://localhost:8321",
+    provider_data={
+        "tavily_search_api_key": "your_TAVILY_SEARCH_API_KEY"
+    },  # Set this from the client side. No need to provide it if it has already been configured on the Llama Stack server.
+)
+
+agent = Agent(
+    client,
+    model="meta-llama/Llama-3.2-3B-Instruct",
+    instructions=(
+        "You are a web search assistant, must use websearch tool to look up the most current and precise information available. "
+    ),
+    tools=["builtin::websearch"],
+)
+
+session_id = agent.create_session("websearch-session")
+
+response = agent.create_turn(
+    messages=[
+        {"role": "user", "content": "How did the USA perform in the last Olympics?"}
+    ],
+    session_id=session_id,
+)
+for log in EventLogger().log(response):
+    log.print()
+```
+
+## Simple Example 3: Using an Agent with the WolframAlpha Tool
+1. Start by registering for a WolframAlpha API key at [WolframAlpha Developer Portal](https://developer.wolframalpha.com/access).
+2. Provide the API key either when starting the Llama Stack server:
+```bash
+--env WOLFRAM_ALPHA_API_KEY=${WOLFRAM_ALPHA_API_KEY}
+```
+or from the client side:
+```python
+client = LlamaStackClient(
+    base_url="http://localhost:8321",
+    provider_data={"wolfram_alpha_api_key": wolfram_api_key},
+)
+```
+3. Configure the tools in the Agent by setting `tools=["builtin::wolfram_alpha"]`.
+4. Example user query:
+```python
+response = agent.create_turn(
+    messages=[{"role": "user", "content": "Solve x^2 + 2x + 1 = 0 using WolframAlpha"}],
+    session_id=session_id,
+)
+```
@ -109,8 +109,6 @@ llama stack build --list-templates
 +------------------------------+-----------------------------------------------------------------------------+
 | nvidia                       | Use NVIDIA NIM for running LLM inference                                    |
 +------------------------------+-----------------------------------------------------------------------------+
-| meta-reference-quantized-gpu | Use Meta Reference with fp8, int4 quantization for running LLM inference   |
-+------------------------------+-----------------------------------------------------------------------------+
 | cerebras                     | Use Cerebras for running LLM inference                                      |
 +------------------------------+-----------------------------------------------------------------------------+
 | ollama                       | Use (an external) Ollama server for running LLM inference                   |
@ -176,7 +174,11 @@ distribution_spec:
     safety: inline::llama-guard
     agents: inline::meta-reference
     telemetry: inline::meta-reference
+image_name: ollama
 image_type: conda

+# If some providers are external, you can specify the path to the implementation
+external_providers_dir: /etc/llama-stack/providers.d
 ```

 ```
@ -184,6 +186,57 @@ llama stack build --config llama_stack/templates/ollama/build.yaml
 ```
 :::

+:::{tab-item} Building with External Providers
+
+Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently or use community-provided providers.
+
+To build a distribution with external providers, you need to:
+
+1. Configure the `external_providers_dir` in your build configuration file:
+
+```yaml
+# Example my-external-stack.yaml with external providers
+version: '2'
+distribution_spec:
+  description: Custom distro for CI tests
+  providers:
+    inference:
+      - remote::custom_ollama
+# Add more providers as needed
+image_type: container
+image_name: ci-test
+# Path to external provider implementations
+external_providers_dir: /etc/llama-stack/providers.d
+```
+
+Here's an example for a custom Ollama provider:
+
+```yaml
+adapter:
+  adapter_type: custom_ollama
+  pip_packages:
+    - ollama
+    - aiohttp
+    - llama-stack-provider-ollama # This is the provider package
+  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
+  module: llama_stack_ollama_provider
+api_dependencies: []
+optional_api_dependencies: []
+```
+
+The `pip_packages` section lists the Python packages required by the provider, as well as the
+provider package itself. The package must be available on PyPI or can be provided from a local
+directory or a git repository (git must be installed on the build environment).
+
+2. Build your distribution using the config file:
+
+```
+llama stack build --config my-external-stack.yaml
+```
+
+For more information on external providers, including directory structure, provider types, and implementation requirements, see the [External Providers documentation](../providers/external.md).
+:::
+
 :::{tab-item} Building Container

 ```{admonition} Podman Alternative
@ -53,6 +53,13 @@ models:
     provider_id: ollama
     provider_model_id: null
 shields: []
+server:
+  port: 8321
+  auth:
+    provider_type: "kubernetes"
+    config:
+      api_server_url: "https://kubernetes.default.svc"
+      ca_cert_path: "/path/to/ca.crt"
 ```

 Let's break this down into the different sections. The first section specifies the set of APIs that the stack server will serve:
@ -102,6 +109,105 @@ A Model is an instance of a "Resource" (see [Concepts](../concepts/index)) and i

 What's with the `provider_model_id` field? This is an identifier for the model inside the provider's model catalog. Contrast it with `model_id` which is the identifier for the same model for Llama Stack's purposes. For example, you may want to name "llama3.2:vision-11b" as "image_captioning_model" when you use it in your Stack interactions. When omitted, the server will set `provider_model_id` to be the same as `model_id`.
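A small sketch of the `model_id` / `provider_model_id` distinction described above, assuming a Llama Stack server at `localhost:8321` with an `ollama` provider; the alias `image_captioning_model` is just the illustrative name from the paragraph, and the register call assumes the standard `llama-stack-client` models API.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

client.models.register(
    model_id="image_captioning_model",      # the name you use in Stack interactions
    provider_id="ollama",                   # which provider serves the model
    provider_model_id="llama3.2:vision-11b",  # the provider's own catalog name
)

# If provider_model_id were omitted, the server would set it equal to model_id.
```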
+
+## Server Configuration
+
+The `server` section configures the HTTP server that serves the Llama Stack APIs:
+
+```yaml
+server:
+  port: 8321  # Port to listen on (default: 8321)
+  tls_certfile: "/path/to/cert.pem"  # Optional: Path to TLS certificate for HTTPS
+  tls_keyfile: "/path/to/key.pem"  # Optional: Path to TLS key for HTTPS
+  auth:  # Optional: Authentication configuration
+    provider_type: "kubernetes"  # Type of auth provider
+    config:  # Provider-specific configuration
+      api_server_url: "https://kubernetes.default.svc"
+      ca_cert_path: "/path/to/ca.crt"  # Optional: Path to CA certificate
+```
+
+### Authentication Configuration
+
+The `auth` section configures authentication for the server. When configured, all API requests must include a valid Bearer token in the Authorization header:
+
+```
+Authorization: Bearer <token>
+```
+
+The server supports multiple authentication providers:
+
+#### Kubernetes Provider
+
+The Kubernetes cluster must be configured to use a service account for authentication.
+
+```bash
+kubectl create namespace llama-stack
+kubectl create serviceaccount llama-stack-auth -n llama-stack
+kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
+kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
+```
+
+Validates tokens against the Kubernetes API server:
+```yaml
+server:
+  auth:
+    provider_type: "kubernetes"
+    config:
+      api_server_url: "https://kubernetes.default.svc"  # URL of the Kubernetes API server
+      ca_cert_path: "/path/to/ca.crt"  # Optional: Path to CA certificate
+```
+
+The provider extracts user information from the JWT token:
+- Username from the `sub` claim becomes a role
+- Kubernetes groups become teams
+
+You can easily validate a request by running:
+
+```bash
+curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
+```
+
+#### Custom Provider
+Validates tokens against a custom authentication endpoint:
+```yaml
+server:
+  auth:
+    provider_type: "custom"
+    config:
+      endpoint: "https://auth.example.com/validate"  # URL of the auth endpoint
+```
+
+The custom endpoint receives a POST request with:
+```json
+{
+  "api_key": "<token>",
+  "request": {
+    "path": "/api/v1/endpoint",
+    "headers": {
+      "content-type": "application/json",
+      "user-agent": "curl/7.64.1"
+    },
+    "params": {
+      "key": ["value"]
+    }
+  }
+}
+```
+
+And must respond with:
+```json
+{
+  "access_attributes": {
+    "roles": ["admin", "user"],
+    "teams": ["ml-team", "nlp-team"],
+    "projects": ["llama-3", "project-x"],
+    "namespaces": ["research"]
+  },
+  "message": "Authentication successful"
+}
+```
+
+If no access attributes are returned, the token is used as a namespace.
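For illustration, a minimal sketch of what such a validation endpoint could look like. FastAPI is an assumption here (any HTTP framework works), and the token store is a stand-in; the request and response shapes follow the JSON bodies shown above.

```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Demo-only token store mapping bearer tokens to access attributes.
VALID_TOKENS = {"secret-token-1": {"roles": ["admin"], "teams": ["ml-team"]}}


class AuthRequest(BaseModel):
    api_key: str   # the bearer token forwarded by the stack server
    request: dict  # path, headers, and params of the original request


@app.post("/validate")
def validate(body: AuthRequest):
    attrs = VALID_TOKENS.get(body.api_key)
    if attrs is None:
        # A non-2xx status tells the stack server to reject the request.
        raise HTTPException(status_code=401, detail="invalid token")
    return {"access_attributes": attrs, "message": "Authentication successful"}
```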

 ## Extending to handle Safety

 Configuring Safety can be a little involved so it is instructive to go through an example.
@ -24,7 +24,7 @@ The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlama
 Add the following dependency in your `build.gradle.kts` file:
 ```
 dependencies {
- implementation("com.llama.llamastack:llama-stack-client-kotlin:0.1.4.2")
+ implementation("com.llama.llamastack:llama-stack-client-kotlin:0.2.2")
 }
 ```
 This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/`
@ -37,11 +37,7 @@ For local inferencing, it is required to include the ExecuTorch library into you

 Include the ExecuTorch library by:
 1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
-2. Move the script to the top level of your Android app where the app directory resides:
-<p align="center">
-<img src="https://github.com/meta-llama/llama-stack-client-kotlin/blob/latest-release/doc/img/example_android_app_directory.png" style="width:300px">
-</p>
+2. Move the script to the top level of your Android app where the `app` directory resides.
 3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate.
 4. Add the `executorch.aar` dependency in your `build.gradle.kts` file:
 ```
@ -52,6 +48,8 @@ dependencies {
 }
 ```

+See other dependencies for the local RAG in Android app [README](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#quick-start).
+
 ## Llama Stack APIs in Your Android App
 Breaking down the demo app, this section will show the core pieces that are used to initialize and run inference with Llama Stack using the Kotlin library.

@ -60,7 +58,7 @@ Start a Llama Stack server on localhost. Here is an example of how you can do th
 ```
 conda create -n stack-fireworks python=3.10
 conda activate stack-fireworks
-pip install --no-cache llama-stack==0.1.4
+pip install --no-cache llama-stack==0.2.2
 llama stack build --template fireworks --image-type conda
 export FIREWORKS_API_KEY=<SOME_KEY>
 llama stack run fireworks --port 5050
@ -1,88 +0,0 @@
-<!-- This file was auto-generated by distro_codegen.py, please edit source -->
-# NVIDIA Distribution
-
-The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.
-
-| API | Provider(s) |
-|-----|-------------|
-| agents | `inline::meta-reference` |
-| datasetio | `inline::localfs` |
-| eval | `inline::meta-reference` |
-| inference | `remote::nvidia` |
-| post_training | `remote::nvidia` |
-| safety | `remote::nvidia` |
-| scoring | `inline::basic` |
-| telemetry | `inline::meta-reference` |
-| tool_runtime | `inline::rag-runtime` |
-| vector_io | `inline::faiss` |
-
-### Environment Variables
-
-The following environment variables can be configured:
-
-- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
-- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
-- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
-- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
-- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
-- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
-- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
-- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
-- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
-- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
-
-### Models
-
-The following models are available by default:
-
-- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
-- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
-- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
-- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
-- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
-- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
-- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
-- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
-- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
-- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
-- `nvidia/nv-embedqa-e5-v5 `
-- `nvidia/nv-embedqa-mistral-7b-v2 `
-- `snowflake/arctic-embed-l `
-
-### Prerequisite: API Keys
-
-Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).
-
-## Running Llama Stack with NVIDIA
-
-You can do this via Conda (build code) or Docker which has a pre-built image.
-
-### Via Docker
-
-This method allows you to get started quickly without having to build the distribution code.
-
-```bash
-LLAMA_STACK_PORT=8321
-docker run \
-  -it \
-  --pull always \
-  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-  -v ./run.yaml:/root/my-run.yaml \
-  llamastack/distribution-nvidia \
-  --yaml-config /root/my-run.yaml \
-  --port $LLAMA_STACK_PORT \
-  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
-```
-
-### Via Conda
-
-```bash
-llama stack build --template nvidia --image-type conda
-llama stack run ./run.yaml \
-  --port 8321 \
-  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
-  --env INFERENCE_MODEL=$INFERENCE_MODEL
-```
docs/source/distributions/remote_hosted_distro/watsonx.md (new file, 88 lines)
@ -0,0 +1,88 @@
+---
+orphan: true
+---
+<!-- This file was auto-generated by distro_codegen.py, please edit source -->
+# watsonx Distribution
+
+```{toctree}
+:maxdepth: 2
+:hidden:
+
+self
+```
+
+The `llamastack/distribution-watsonx` distribution consists of the following provider configurations.
+
+| API | Provider(s) |
+|-----|-------------|
+| agents | `inline::meta-reference` |
+| datasetio | `remote::huggingface`, `inline::localfs` |
+| eval | `inline::meta-reference` |
+| inference | `remote::watsonx` |
+| safety | `inline::llama-guard` |
+| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
+| telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| vector_io | `inline::faiss` |
+
+### Environment Variables
+
+The following environment variables can be configured:
+
+- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `WATSONX_API_KEY`: watsonx API Key (default: ``)
+- `WATSONX_PROJECT_ID`: watsonx Project ID (default: ``)
+
+### Models
+
+The following models are available by default:
+
+- `meta-llama/llama-3-3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
+- `meta-llama/llama-2-13b-chat (aliases: meta-llama/Llama-2-13b)`
+- `meta-llama/llama-3-1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
+- `meta-llama/llama-3-1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `meta-llama/llama-3-2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
+- `meta-llama/llama-3-2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
+- `meta-llama/llama-3-2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `meta-llama/llama-3-2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
+- `meta-llama/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
+
+### Prerequisite: API Keys
+
+Make sure you have access to a watsonx API Key. You can get one by referring [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).
+
+## Running Llama Stack with watsonx
+
+You can do this via Conda (build code), venv or Docker which has a pre-built image.
+
+### Via Docker
+
+This method allows you to get started quickly without having to build the distribution code.
+
+```bash
+LLAMA_STACK_PORT=5001
+docker run \
+  -it \
+  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+  -v ./run.yaml:/root/my-run.yaml \
+  llamastack/distribution-watsonx \
+  --yaml-config /root/my-run.yaml \
+  --port $LLAMA_STACK_PORT \
+  --env WATSONX_API_KEY=$WATSONX_API_KEY \
+  --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
+  --env WATSONX_BASE_URL=$WATSONX_BASE_URL
+```
+
+### Via Conda
+
+```bash
+llama stack build --template watsonx --image-type conda
+llama stack run ./run.yaml \
+  --port $LLAMA_STACK_PORT \
+  --env WATSONX_API_KEY=$WATSONX_API_KEY \
+  --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
+```
@ -19,7 +19,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
 | safety | `remote::bedrock` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@ -12,7 +12,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@ -22,7 +22,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::rag-runtime`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@ -22,7 +22,7 @@ The `llamastack/distribution-groq` distribution consists of the following provid
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
 | vector_io | `inline::faiss` |
@ -22,7 +22,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@ -81,6 +81,7 @@ LLAMA_STACK_PORT=8321
 docker run \
   -it \
   --pull always \
+  --gpu all \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-meta-reference-gpu \
@ -94,6 +95,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
 docker run \
   -it \
   --pull always \
+  --gpu all \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-meta-reference-gpu \
@ -1,123 +0,0 @@
----
-orphan: true
----
-<!-- This file was auto-generated by distro_codegen.py, please edit source -->
-# Meta Reference Quantized Distribution
-
-```{toctree}
-:maxdepth: 2
-:hidden:
-
-self
-```
-
-The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations:
-
-| API | Provider(s) |
-|-----|-------------|
-| agents | `inline::meta-reference` |
-| datasetio | `remote::huggingface`, `inline::localfs` |
-| eval | `inline::meta-reference` |
-| inference | `inline::meta-reference-quantized` |
-| safety | `inline::llama-guard` |
-| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
-| telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
-| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
-
-The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
-
-Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
-
-### Environment Variables
-
-The following environment variables can be configured:
-
-- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
-- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
-- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
-
-## Prerequisite: Downloading Models
-
-Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
-
-```
-$ llama model list --downloaded
-┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
-┃ Model                                   ┃ Size     ┃ Modified Time       ┃
-┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
-│ Llama3.2-1B-Instruct:int4-qlora-eo8     │ 1.53 GB  │ 2025-02-26 11:22:28 │
-├─────────────────────────────────────────┼──────────┼─────────────────────┤
-│ Llama3.2-1B                             │ 2.31 GB  │ 2025-02-18 21:48:52 │
-├─────────────────────────────────────────┼──────────┼─────────────────────┤
-│ Prompt-Guard-86M                        │ 0.02 GB  │ 2025-02-26 11:29:28 │
-├─────────────────────────────────────────┼──────────┼─────────────────────┤
-│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB  │ 2025-02-26 11:37:41 │
-├─────────────────────────────────────────┼──────────┼─────────────────────┤
-│ Llama3.2-3B                             │ 5.99 GB  │ 2025-02-18 21:51:26 │
-├─────────────────────────────────────────┼──────────┼─────────────────────┤
-│ Llama3.1-8B                             │ 14.97 GB │ 2025-02-16 10:36:37 │
-├─────────────────────────────────────────┼──────────┼─────────────────────┤
-│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB  │ 2025-02-26 11:35:02 │
-├─────────────────────────────────────────┼──────────┼─────────────────────┤
-│ Llama-Guard-3-1B                        │ 2.80 GB  │ 2025-02-26 11:20:46 │
-├─────────────────────────────────────────┼──────────┼─────────────────────┤
-│ Llama-Guard-3-1B:int4                   │ 0.43 GB  │ 2025-02-26 11:33:33 │
-└─────────────────────────────────────────┴──────────┴─────────────────────┘
-```
-
-## Running the Distribution
-
-You can do this via Conda (build code) or Docker which has a pre-built image.
-
-### Via Docker
-
-This method allows you to get started quickly without having to build the distribution code.
-
-```bash
-LLAMA_STACK_PORT=8321
-docker run \
-  -it \
-  --pull always \
-  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-  -v ~/.llama:/root/.llama \
-  llamastack/distribution-meta-reference-quantized-gpu \
-  --port $LLAMA_STACK_PORT \
-  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
-```
-
-If you are using Llama Stack Safety / Shield APIs, use:
-
-```bash
-docker run \
-  -it \
-  --pull always \
-  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-  -v ~/.llama:/root/.llama \
-  llamastack/distribution-meta-reference-quantized-gpu \
-  --port $LLAMA_STACK_PORT \
-  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
-  --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
-```
-
-### Via Conda
-
-Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
-
-```bash
-llama stack build --template meta-reference-quantized-gpu --image-type conda
-llama stack run distributions/meta-reference-quantized-gpu/run.yaml \
-  --port $LLAMA_STACK_PORT \
-  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
-```
-
-If you are using Llama Stack Safety / Shield APIs, use:
-
-```bash
-llama stack run distributions/meta-reference-quantized-gpu/run-with-safety.yaml \
-  --port $LLAMA_STACK_PORT \
-  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
-  --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
-```
@ -6,8 +6,8 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
||||||
| API | Provider(s) |
|
| API | Provider(s) |
|
||||||
|-----|-------------|
|
|-----|-------------|
|
||||||
| agents | `inline::meta-reference` |
|
| agents | `inline::meta-reference` |
|
||||||
| datasetio | `inline::localfs` |
|
| datasetio | `inline::localfs`, `remote::nvidia` |
|
||||||
| eval | `inline::meta-reference` |
|
| eval | `remote::nvidia` |
|
||||||
| inference | `remote::nvidia` |
|
| inference | `remote::nvidia` |
|
||||||
| post_training | `remote::nvidia` |
|
| post_training | `remote::nvidia` |
|
||||||
| safety | `remote::nvidia` |
|
| safety | `remote::nvidia` |
|
||||||
|
@ -22,13 +22,13 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
||||||
The following environment variables can be configured:
|
The following environment variables can be configured:
|
||||||
|
|
||||||
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
||||||
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
|
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
|
||||||
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
||||||
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
|
|
||||||
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
||||||
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
||||||
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
||||||
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
|
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
|
||||||
|
- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
|
||||||
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
|
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
|
||||||
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
|
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
|
||||||
|
|
||||||
|
@ -45,20 +45,91 @@ The following models are available by default:

- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2`
- `nvidia/nv-embedqa-e5-v5`
- `nvidia/nv-embedqa-mistral-7b-v2`
- `snowflake/arctic-embed-l`

## Prerequisites

### NVIDIA API Keys

Make sure you have access to an NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.

### Deploy NeMo Microservices Platform

The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.

## Supported Services

Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.

### Inference: NVIDIA NIM

NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:

1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (requires an API key)
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.

The deployed platform includes the NIM Proxy microservice, which is the service that provides access to your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.

### Datasetio API: NeMo Data Store

The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposes APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with the Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.

See the [NVIDIA Datasetio docs](/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.

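Because the Data Store speaks the Hugging Face Hub protocol, you can talk to it directly with `HfApi`. Here is a minimal sketch (not taken from the provider docs): it assumes `NVIDIA_DATASETS_URL` points at your Data Store deployment, and the namespace, dataset name, and local file path are placeholders.

```python
# Minimal sketch: upload a dataset file to NeMo Data Store via the Hugging Face Hub client.
# Assumes NVIDIA_DATASETS_URL points at your deployment; names below are placeholders.
import os

from huggingface_hub import HfApi

hf_api = HfApi(
    endpoint=os.environ["NVIDIA_DATASETS_URL"],
    token="",  # the Data Store deployment may not require a real Hugging Face token
)

repo_id = "default/sample-dataset"  # hypothetical namespace/name pair -- use your own
hf_api.create_repo(repo_id, repo_type="dataset", exist_ok=True)
hf_api.upload_file(
    path_or_fileobj="./train.jsonl",        # local file assumed to exist
    path_in_repo="training/train.jsonl",
    repo_id=repo_id,
    repo_type="dataset",
)
```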
### Eval API: NeMo Evaluator

The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.

See the [NVIDIA Eval docs](/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.

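As a rough illustration of that mapping, registering a Benchmark from the Python client might look like the sketch below. The benchmark, dataset, and scoring-function identifiers are placeholders, and the call assumes `llama_stack_client` mirrors the server-side `benchmarks.register` API.

```python
# Minimal sketch: registering a Benchmark, which creates an Evaluation Config in NeMo Evaluator.
# Identifiers are placeholders; use datasets and scoring functions you have actually registered.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

client.benchmarks.register(
    benchmark_id="my-benchmark",            # hypothetical benchmark id
    dataset_id="my-eval-dataset",           # hypothetical dataset registered via the datasetio API
    scoring_functions=["basic::equality"],  # placeholder scoring function id
)
```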
### Post-Training API: NeMo Customizer

The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.

See the [NVIDIA Post-Training docs](/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.

### Safety API: NeMo Guardrails

The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.

See the NVIDIA Safety docs for supported features and example usage.

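For illustration, here is a hedged sketch of invoking the Safety API against a registered shield; the shield id and message are placeholders, and the call assumes the standard `safety.run_shield` client method.

```python
# Minimal sketch: run a registered shield against a user message via the Safety API.
# "my-shield-id" is a hypothetical shield id -- use the one you registered with the stack.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.safety.run_shield(
    shield_id="my-shield-id",
    messages=[{"role": "user", "content": "Tell me how to do something unsafe."}],
    params={},
)
print(response.violation)  # None if the message passes the guardrail checks
```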
## Deploying models

In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.

Note: For improved inference speeds, we need to use NIM with the `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.

```sh
# URL to NeMo NIM Proxy service
export NEMO_URL="http://nemo.test"

curl --location "$NEMO_URL/v1/deployment/model-deployments" \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
    "name": "llama-3.2-1b-instruct",
    "namespace": "meta",
    "config": {
      "model": "meta/llama-3.2-1b-instruct",
      "nim_deployment": {
        "image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
        "image_tag": "1.8.3",
        "pvc_size": "25Gi",
        "gpu": 1,
        "additional_envs": {
          "NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
        }
      }
    }
  }'
```
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.

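As a hedged sketch, you can also query the same deployment resource while you wait to check on its status; the exact response fields depend on your NeMo Microservices version, so inspect the returned JSON rather than relying on a specific key.

```python
# Minimal sketch: check on the NIM deployment created above.
# Assumes NEMO_URL is set as in the curl example; field names in the response may vary by version.
import os

import requests

nemo_url = os.environ["NEMO_URL"]
resp = requests.get(
    f"{nemo_url}/v1/deployment/model-deployments/meta/llama-3.2-1b-instruct",
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # look for the deployment/status information in the payload
```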
You can also remove a deployed NIM to free up GPU resources, if needed.
```sh
export NEMO_URL="http://nemo.test"

curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
```

## Running Llama Stack with NVIDIA

You can do this via Conda or venv (build code), or Docker which has a pre-built image.

### Via Docker

@ -80,9 +151,23 @@ docker run \

### Via Conda

```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```

### Via venv

If you've set up your local development environment, you can also build the image using your local virtual environment.

```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type venv
llama stack run ./run.yaml \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```

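Whichever method you use to start the stack, a quick health check confirms the server is up. A minimal sketch against the default port 8321 (the same endpoint the installer script later in this change polls):

```python
# Minimal sketch: verify a locally running Llama Stack server is healthy on the default port.
import requests

resp = requests.get("http://localhost:8321/v1/health", timeout=10)
resp.raise_for_status()
print(resp.text)  # expect an OK/healthy status payload
```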
@ -22,7 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following providers:

| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@ -22,7 +22,7 @@ The `llamastack/distribution-passthrough` distribution consists of the following providers:

| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@ -21,7 +21,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following providers:

| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@ -41,10 +41,10 @@ The following environment variables can be configured:

## Setting up vLLM server

In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
that we only use GPUs here for demonstration purposes. Note that if you run into issues, you can add `--env VLLM_DEBUG_LOG_API_SERVER_RESPONSE=true` (available in vLLM v0.8.3 and above) to the `docker run` command to enable logging of API server responses for debugging.

### Setting up vLLM server on AMD GPU

@ -162,6 +162,55 @@ docker run \
    --port $SAFETY_PORT
```

### Setting up vLLM server on Intel GPU

Refer to [vLLM Documentation for XPU](https://docs.vllm.ai/en/v0.8.2/getting_started/installation/gpu.html?device=xpu) to get a vLLM endpoint. In addition to the vLLM-side setup, which walks through installing vLLM from source or building the vLLM Docker container yourself, Intel provides a prebuilt vLLM container for systems with Intel GPUs supported by the PyTorch XPU backend:
- [intel/vllm](https://hub.docker.com/r/intel/vllm)

Here is a sample script to start a vLLM server locally via Docker using the Intel-provided container:

```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
export ZE_AFFINITY_MASK=0

docker run \
  --pull always \
  --device /dev/dri \
  -v /dev/dri/by-path:/dev/dri/by-path \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
  --env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
  -p $INFERENCE_PORT:$INFERENCE_PORT \
  --ipc=host \
  intel/vllm:xpu \
  --gpu-memory-utilization 0.7 \
  --model $INFERENCE_MODEL \
  --port $INFERENCE_PORT
```

If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:

```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export ZE_AFFINITY_MASK=1

docker run \
  --pull always \
  --device /dev/dri \
  -v /dev/dri/by-path:/dev/dri/by-path \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
  --env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
  -p $SAFETY_PORT:$SAFETY_PORT \
  --ipc=host \
  intel/vllm:xpu \
  --gpu-memory-utilization 0.7 \
  --model $SAFETY_MODEL \
  --port $SAFETY_PORT
```

## Running Llama Stack

Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.

@ -19,7 +19,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following providers:

| inference | `remote::sambanova`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@ -23,7 +23,7 @@ The `llamastack/distribution-tgi` distribution consists of the following providers:

| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@ -22,7 +22,7 @@ The `llamastack/distribution-together` distribution consists of the following providers:

| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@ -50,9 +50,11 @@ Llama Stack supports two types of external providers:

Here's a list of known external providers that you can use with Llama Stack:

| Name | Description | API | Type | Repository |
|------|-------------|-----|------|------------|
| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |

### Remote Provider Specification
@ -389,5 +389,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -256,5 +256,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -301,5 +301,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -200,5 +200,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -355,5 +355,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -398,5 +398,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -132,5 +132,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -188,5 +188,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -86,11 +86,11 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next-steps)

llama stack build --template ollama --image-type conda
```
**Expected Output:**
```bash
...
Build Successful!
You can find the newly-built template here: ~/.llama/distributions/ollama/ollama-run.yaml
You can run the new Llama Stack Distro via: llama stack run ~/.llama/distributions/ollama/ollama-run.yaml --image-type conda
```

3. **Set the ENV variables by exporting them to the terminal**:

145
install.sh
Executable file
@ -0,0 +1,145 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -Eeuo pipefail

PORT=8321
OLLAMA_PORT=11434
MODEL_ALIAS="llama3.2:3b"
SERVER_IMAGE="llamastack/distribution-ollama:0.2.2"
WAIT_TIMEOUT=300

log(){ printf "\e[1;32m%s\e[0m\n" "$*"; }
die(){ printf "\e[1;31m❌ %s\e[0m\n" "$*" >&2; exit 1; }

wait_for_service() {
  local url="$1"
  local pattern="$2"
  local timeout="$3"
  local name="$4"
  local start ts
  log "⏳ Waiting for ${name}…"
  start=$(date +%s)
  while true; do
    if curl --retry 5 --retry-delay 1 --retry-max-time "$timeout" --retry-all-errors --silent --fail "$url" 2>/dev/null | grep -q "$pattern"; then
      break
    fi
    ts=$(date +%s)
    if (( ts - start >= timeout )); then
      return 1
    fi
    printf '.'
    sleep 1
  done
  return 0
}

if command -v docker &> /dev/null; then
  ENGINE="docker"
elif command -v podman &> /dev/null; then
  ENGINE="podman"
else
  die "Docker or Podman is required. Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation"
fi

# Explicitly set the platform for the host architecture
HOST_ARCH="$(uname -m)"
if [ "$HOST_ARCH" = "arm64" ]; then
  if [ "$ENGINE" = "docker" ]; then
    PLATFORM_OPTS=( --platform linux/amd64 )
  else
    PLATFORM_OPTS=( --os linux --arch amd64 )
  fi
else
  PLATFORM_OPTS=()
fi

# macOS + Podman: ensure VM is running before we try to launch containers
# If you need GPU passthrough under Podman on macOS, init the VM with libkrun:
#   CONTAINERS_MACHINE_PROVIDER=libkrun podman machine init
if [ "$ENGINE" = "podman" ] && [ "$(uname -s)" = "Darwin" ]; then
  if ! podman info &>/dev/null; then
    log "⌛️ Initializing Podman VM…"
    podman machine init &>/dev/null || true
    podman machine start &>/dev/null || true

    log "⌛️ Waiting for Podman API…"
    until podman info &>/dev/null; do
      sleep 1
    done
    log "✅ Podman VM is up"
  fi
fi

# Clean up any leftovers from earlier runs
for name in ollama-server llama-stack; do
  ids=$($ENGINE ps -aq --filter "name=^${name}$")
  if [ -n "$ids" ]; then
    log "⚠️ Found existing container(s) for '${name}', removing…"
    $ENGINE rm -f "$ids" > /dev/null 2>&1
  fi
done

###############################################################################
# 0. Create a shared network
###############################################################################
if ! $ENGINE network inspect llama-net >/dev/null 2>&1; then
  log "🌐 Creating network…"
  $ENGINE network create llama-net >/dev/null 2>&1
fi

###############################################################################
# 1. Ollama
###############################################################################
log "🦙 Starting Ollama…"
$ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \
  --network llama-net \
  -p "${OLLAMA_PORT}:${OLLAMA_PORT}" \
  ollama/ollama > /dev/null 2>&1

if ! wait_for_service "http://localhost:${OLLAMA_PORT}/" "Ollama" "$WAIT_TIMEOUT" "Ollama daemon"; then
  log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
  $ENGINE logs --tail 200 ollama-server
  die "Ollama startup failed"
fi

log "📦 Ensuring model is pulled: ${MODEL_ALIAS}…"
if ! $ENGINE exec ollama-server ollama pull "${MODEL_ALIAS}" > /dev/null 2>&1; then
  log "❌ Failed to pull model ${MODEL_ALIAS}; dumping container logs:"
  $ENGINE logs --tail 200 ollama-server
  die "Model pull failed"
fi

###############################################################################
# 2. Llama‑Stack
###############################################################################
cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
  --network llama-net \
  -p "${PORT}:${PORT}" \
  "${SERVER_IMAGE}" --port "${PORT}" \
  --env INFERENCE_MODEL="${MODEL_ALIAS}" \
  --env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" )

log "🦙 Starting Llama‑Stack…"
$ENGINE "${cmd[@]}" > /dev/null 2>&1

if ! wait_for_service "http://127.0.0.1:${PORT}/v1/health" "OK" "$WAIT_TIMEOUT" "Llama-Stack API"; then
  log "❌ Llama-Stack did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
  $ENGINE logs --tail 200 llama-stack
  die "Llama-Stack startup failed"
fi

###############################################################################
# Done
###############################################################################
log ""
log "🎉 Llama‑Stack is ready!"
log "👉 API endpoint: http://localhost:${PORT}"
log "📖 Documentation: https://llama-stack.readthedocs.io/en/latest/references/index.html"
log "💻 To access the llama‑stack CLI, exec into the container:"
log "   $ENGINE exec -ti llama-stack bash"
log ""
@ -4,20 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel, ConfigDict, Field

@ -38,6 +28,13 @@ from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

from .openai_responses import (
    OpenAIResponseInputMessage,
    OpenAIResponseInputTool,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
)


class Attachment(BaseModel):
    """An attachment to an agent turn.

@ -72,8 +69,8 @@ class StepCommon(BaseModel):

    turn_id: str
    step_id: str
    started_at: datetime | None = None
    completed_at: datetime | None = None


class StepType(Enum):

@ -113,8 +110,8 @@ class ToolExecutionStep(StepCommon):
    """

    step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
    tool_calls: list[ToolCall]
    tool_responses: list[ToolResponse]


@json_schema_type

@ -125,7 +122,7 @@ class ShieldCallStep(StepCommon):
    """

    step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
    violation: SafetyViolation | None


@json_schema_type

@ -143,12 +140,7 @@ class MemoryRetrievalStep(StepCommon):


Step = Annotated[
    InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
    Field(discriminator="step_type"),
]

@ -159,18 +151,13 @@ class Turn(BaseModel):

    turn_id: str
    session_id: str
    input_messages: list[UserMessage | ToolResponseMessage]
    steps: list[Step]
    output_message: CompletionMessage
    output_attachments: list[Attachment] | None = Field(default_factory=list)

    started_at: datetime
    completed_at: datetime | None = None


@json_schema_type

@ -179,34 +166,31 @@ class Session(BaseModel):

    session_id: str
    session_name: str
    turns: list[Turn]
    started_at: datetime


class AgentToolGroupWithArgs(BaseModel):
    name: str
    args: dict[str, Any]


AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")


class AgentConfigCommon(BaseModel):
    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

    input_shields: list[str] | None = Field(default_factory=list)
    output_shields: list[str] | None = Field(default_factory=list)
    toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
    client_tools: list[ToolDef] | None = Field(default_factory=list)
    tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
    tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
    tool_config: ToolConfig | None = Field(default=None)

    max_infer_iters: int | None = 10

    def model_post_init(self, __context):
        if self.tool_config:

@ -225,10 +209,20 @@ class AgentConfigCommon(BaseModel):

@json_schema_type
class AgentConfig(AgentConfigCommon):
    """Configuration for an agent.

    :param model: The model identifier to use for the agent
    :param instructions: The system instructions for the agent
    :param name: Optional name for the agent, used in telemetry and identification
    :param enable_session_persistence: Optional flag indicating whether session data has to be persisted
    :param response_format: Optional response format configuration
    """

    model: str
    instructions: str
    name: str | None = None
    enable_session_persistence: bool | None = False
    response_format: ResponseFormat | None = None


@json_schema_type

@ -240,16 +234,16 @@ class Agent(BaseModel):

@json_schema_type
class ListAgentsResponse(BaseModel):
    data: list[Agent]


@json_schema_type
class ListAgentSessionsResponse(BaseModel):
    data: list[Session]


class AgentConfigOverridablePerTurn(AgentConfigCommon):
    instructions: str | None = None


class AgentTurnResponseEventType(Enum):

@ -267,7 +261,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
    event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
    step_type: StepType
    step_id: str
    metadata: dict[str, Any] | None = Field(default_factory=dict)


@json_schema_type

@ -310,14 +304,12 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):


AgentTurnResponseEventPayload = Annotated[
    AgentTurnResponseStepStartPayload
    | AgentTurnResponseStepProgressPayload
    | AgentTurnResponseStepCompletePayload
    | AgentTurnResponseTurnStartPayload
    | AgentTurnResponseTurnCompletePayload
    | AgentTurnResponseTurnAwaitingInputPayload,
    Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")

@ -346,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
    # TODO: figure out how we can simplify this and make why
    # ToolResponseMessage needs to be here (it is function call
    # execution from outside the system)
    messages: list[UserMessage | ToolResponseMessage]

    documents: list[Document] | None = None
    toolgroups: list[AgentToolGroup] | None = None

    stream: bool | None = False
    tool_config: ToolConfig | None = None


@json_schema_type

@ -365,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
    agent_id: str
    session_id: str
    turn_id: str
    tool_responses: list[ToolResponse]
    stream: bool | None = False


@json_schema_type

@ -412,17 +399,12 @@ class Agents(Protocol):
        self,
        agent_id: str,
        session_id: str,
        messages: list[UserMessage | ToolResponseMessage],
        stream: bool | None = False,
        documents: list[Document] | None = None,
        toolgroups: list[AgentToolGroup] | None = None,
        tool_config: ToolConfig | None = None,
    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
        """Create a new turn for an agent.

        :param agent_id: The ID of the agent to create the turn for.

@ -446,9 +428,9 @@ class Agents(Protocol):
        agent_id: str,
        session_id: str,
        turn_id: str,
        tool_responses: list[ToolResponse],
        stream: bool | None = False,
    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
        """Resume an agent turn with executed tool call responses.

        When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.

@ -521,7 +503,7 @@ class Agents(Protocol):
        self,
        session_id: str,
        agent_id: str,
        turn_ids: list[str] | None = None,
    ) -> Session:
        """Retrieve an agent session by its ID.

@ -583,3 +565,40 @@ class Agents(Protocol):
        :returns: A ListAgentSessionsResponse.
        """
        ...

    # We situate the OpenAI Responses API in the Agents API just like we did things
    # for Inference. The Responses API, in its intent, serves the same purpose as
    # the Agents API above -- it is essentially a lightweight "agentic loop" with
    # integrated tool calling.
    #
    # Both of these APIs are inherently stateful.

    @webmethod(route="/openai/v1/responses/{id}", method="GET")
    async def get_openai_response(
        self,
        id: str,
    ) -> OpenAIResponseObject:
        """Retrieve an OpenAI response by its ID.

        :param id: The ID of the OpenAI response to retrieve.
        :returns: An OpenAIResponseObject.
        """
        ...

    @webmethod(route="/openai/v1/responses", method="POST")
    async def create_openai_response(
        self,
        input: str | list[OpenAIResponseInputMessage],
        model: str,
        previous_response_id: str | None = None,
        store: bool | None = True,
        stream: bool | None = False,
        temperature: float | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
        """Create a new OpenAI response.

        :param input: Input message(s) to create the response.
        :param model: The underlying LLM used for completions.
        :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
        """
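For a sense of how this new endpoint is exercised, here is a hedged sketch that POSTs to the route declared above on a locally running stack. The base URL, model id, and any route prefixing applied by the server are assumptions, not part of this change.

```python
# Minimal sketch: call the new Responses endpoint on a running Llama Stack server.
# Assumes a stack is listening on localhost:8321 and that the model below is registered;
# the server may mount the route under a different prefix than shown here.
import requests

BASE_URL = "http://localhost:8321"

resp = requests.post(
    f"{BASE_URL}/openai/v1/responses",
    json={
        "model": "meta-llama/Llama-3.2-3B-Instruct",
        "input": "Write a one-line haiku about inference servers.",
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
# Each output item corresponds to an OpenAIResponseOutput (message or web_search_call).
for item in body["output"]:
    print(item)
```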
133
llama_stack/apis/agents/openai_responses.py
Normal file
@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, Literal

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type, register_schema


@json_schema_type
class OpenAIResponseError(BaseModel):
    code: str
    message: str


@json_schema_type
class OpenAIResponseOutputMessageContentOutputText(BaseModel):
    text: str
    type: Literal["output_text"] = "output_text"


OpenAIResponseOutputMessageContent = Annotated[
    OpenAIResponseOutputMessageContentOutputText,
    Field(discriminator="type"),
]
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")


@json_schema_type
class OpenAIResponseOutputMessage(BaseModel):
    id: str
    content: list[OpenAIResponseOutputMessageContent]
    role: Literal["assistant"] = "assistant"
    status: str
    type: Literal["message"] = "message"


@json_schema_type
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
    id: str
    status: str
    type: Literal["web_search_call"] = "web_search_call"


OpenAIResponseOutput = Annotated[
    OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
    Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")


@json_schema_type
class OpenAIResponseObject(BaseModel):
    created_at: int
    error: OpenAIResponseError | None = None
    id: str
    model: str
    object: Literal["response"] = "response"
    output: list[OpenAIResponseOutput]
    parallel_tool_calls: bool = False
    previous_response_id: str | None = None
    status: str
    temperature: float | None = None
    top_p: float | None = None
    truncation: str | None = None
    user: str | None = None


@json_schema_type
class OpenAIResponseObjectStreamResponseCreated(BaseModel):
    response: OpenAIResponseObject
    type: Literal["response.created"] = "response.created"


@json_schema_type
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
    response: OpenAIResponseObject
    type: Literal["response.completed"] = "response.completed"


OpenAIResponseObjectStream = Annotated[
    OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
    Field(discriminator="type"),
]
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")


@json_schema_type
class OpenAIResponseInputMessageContentText(BaseModel):
    text: str
    type: Literal["input_text"] = "input_text"


@json_schema_type
class OpenAIResponseInputMessageContentImage(BaseModel):
    detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
    type: Literal["input_image"] = "input_image"
    # TODO: handle file_id
    image_url: str | None = None


# TODO: handle file content types
OpenAIResponseInputMessageContent = Annotated[
    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
    Field(discriminator="type"),
]
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")


@json_schema_type
class OpenAIResponseInputMessage(BaseModel):
    content: str | list[OpenAIResponseInputMessageContent]
    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
    type: Literal["message"] | None = "message"


@json_schema_type
class OpenAIResponseInputToolWebSearch(BaseModel):
    type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
    # TODO: actually use search_context_size somewhere...
    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
    # TODO: add user_location


OpenAIResponseInputTool = Annotated[
    OpenAIResponseInputToolWebSearch,
    Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
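To make the shapes above concrete, here is a small illustrative sketch that constructs a Responses input message and a web-search tool definition with these models and dumps them to JSON; it is only a usage illustration, not part of the change itself.

```python
# Minimal sketch: construct Responses API input models defined above and inspect their JSON form.
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputMessage,
    OpenAIResponseInputMessageContentText,
    OpenAIResponseInputToolWebSearch,
)

message = OpenAIResponseInputMessage(
    role="user",
    content=[OpenAIResponseInputMessageContentText(text="What's new in Llama Stack?")],
)
tool = OpenAIResponseInputToolWebSearch(search_context_size="medium")

print(message.model_dump_json(indent=2))
print(tool.model_dump_json(indent=2))
```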
@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol, runtime_checkable

from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (

@ -34,22 +34,22 @@ class BatchInference(Protocol):
    async def completion(
        self,
        model: str,
        content_batch: list[InterleavedContent],
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        logprobs: LogProbConfig | None = None,
    ) -> Job: ...

    @webmethod(route="/batch-inference/chat-completion", method="POST")
    async def chat_completion(
        self,
        model: str,
        messages_batch: list[list[Message]],
        sampling_params: SamplingParams | None = None,
        # zero-shot tool definitions as input to the model
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = ToolChoice.auto,
        tool_prompt_format: ToolPromptFormat | None = None,
        response_format: ResponseFormat | None = None,
        logprobs: LogProbConfig | None = None,
    ) -> Job: ...
@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel, Field

@ -13,8 +13,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod

class CommonBenchmarkFields(BaseModel):
    dataset_id: str
    scoring_functions: list[str]
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Metadata for this evaluation task",
    )

@ -35,12 +35,12 @@ class Benchmark(CommonBenchmarkFields, Resource):

class BenchmarkInput(CommonBenchmarkFields, BaseModel):
    benchmark_id: str
    provider_id: str | None = None
    provider_benchmark_id: str | None = None


class ListBenchmarksResponse(BaseModel):
    data: list[Benchmark]


@runtime_checkable

@ -59,8 +59,8 @@ class Benchmarks(Protocol):
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None: ...
@ -5,7 +5,7 @@
# the root directory of this source tree.

from enum import Enum
from typing import Annotated, Literal

from pydantic import BaseModel, Field, model_validator

@ -26,9 +26,9 @@ class _URLOrData(BaseModel):
    :param data: base64 encoded image data as string
    """

    url: URL | None = None
    # data is a base64 encoded string, hint with contentEncoding=base64
    data: str | None = Field(contentEncoding="base64", default=None)

    @model_validator(mode="before")
    @classmethod

@ -64,13 +64,13 @@ class TextContentItem(BaseModel):

# other modalities can be added here
InterleavedContentItem = Annotated[
    ImageContentItem | TextContentItem,
    Field(discriminator="type"),
]
register_schema(InterleavedContentItem, name="InterleavedContentItem")

# accept a single "str" as a special case since it is common
InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
register_schema(InterleavedContent, name="InterleavedContent")

@ -100,13 +100,13 @@ class ToolCallDelta(BaseModel):
    # you either send an in-progress tool call so the client can stream a long
    # code generation or you send the final parsed tool call at the end of the
    # stream
    tool_call: str | ToolCall
    parse_status: ToolCallParseStatus


# streaming completions send a stream of ContentDeltas
ContentDelta = Annotated[
    TextDelta | ImageDelta | ToolCallDelta,
    Field(discriminator="type"),
]
register_schema(ContentDelta, name="ContentDelta")
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel

@@ -25,6 +25,6 @@ class RestAPIMethod(Enum):
 class RestAPIExecutionConfig(BaseModel):
     url: URL
     method: RestAPIMethod
-    params: Optional[Dict[str, Any]] = None
-    headers: Optional[Dict[str, Any]] = None
-    body: Optional[Dict[str, Any]] = None
+    params: dict[str, Any] | None = None
+    headers: dict[str, Any] | None = None
+    body: dict[str, Any] | None = None
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any

 from pydantic import BaseModel

@@ -19,5 +19,5 @@ class PaginatedResponse(BaseModel):
     :param has_more: Whether there are more items available after this set
     """

-    data: List[Dict[str, Any]]
+    data: list[dict[str, Any]]
     has_more: bool
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from datetime import datetime
-from typing import Optional

 from pydantic import BaseModel

@@ -27,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric] = None
+    training_metrics: PostTrainingMetric | None = None
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Literal, Union
+from typing import Annotated, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.schema_utils import json_schema_type, register_schema

@@ -73,18 +72,16 @@ class DialogType(BaseModel):


 ParamType = Annotated[
-    Union[
-        StringType,
-        NumberType,
-        BooleanType,
-        ArrayType,
-        ObjectType,
-        JsonType,
-        UnionType,
-        ChatCompletionInputType,
-        CompletionInputType,
-        AgentTurnInputType,
-    ],
+    StringType
+    | NumberType
+    | BooleanType
+    | ArrayType
+    | ObjectType
+    | JsonType
+    | UnionType
+    | ChatCompletionInputType
+    | CompletionInputType
+    | AgentTurnInputType,
     Field(discriminator="type"),
 ]
 register_schema(ParamType, name="ParamType")
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasets import Dataset
@@ -24,8 +24,8 @@ class DatasetIO(Protocol):
     async def iterrows(
         self,
         dataset_id: str,
-        start_index: Optional[int] = None,
-        limit: Optional[int] = None,
+        start_index: int | None = None,
+        limit: int | None = None,
     ) -> PaginatedResponse:
         """Get a paginated list of rows from a dataset.

@@ -44,4 +44,4 @@ class DatasetIO(Protocol):
         ...

     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field

@@ -81,11 +81,11 @@ class RowsDataSource(BaseModel):
     """

     type: Literal["rows"] = "rows"
-    rows: List[Dict[str, Any]]
+    rows: list[dict[str, Any]]


 DataSource = Annotated[
-    Union[URIDataSource, RowsDataSource],
+    URIDataSource | RowsDataSource,
     Field(discriminator="type"),
 ]
 register_schema(DataSource, name="DataSource")
@@ -98,7 +98,7 @@ class CommonDatasetFields(BaseModel):

     purpose: DatasetPurpose
     source: DataSource
-    metadata: Dict[str, Any] = Field(
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this dataset",
     )
@@ -122,7 +122,7 @@ class DatasetInput(CommonDatasetFields, BaseModel):


 class ListDatasetsResponse(BaseModel):
-    data: List[Dataset]
+    data: list[Dataset]


 class Datasets(Protocol):
@@ -131,8 +131,8 @@ class Datasets(Protocol):
         self,
         purpose: DatasetPurpose,
         source: DataSource,
-        metadata: Optional[Dict[str, Any]] = None,
-        dataset_id: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
     ) -> Dataset:
         """
         Register a new dataset.
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Optional

 from pydantic import BaseModel

@@ -54,4 +53,4 @@ class Error(BaseModel):
     status: int
     title: str
     detail: str
-    instance: Optional[str] = None
+    instance: str | None = None
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job
@@ -29,7 +28,7 @@ class ModelCandidate(BaseModel):
     type: Literal["model"] = "model"
     model: str
     sampling_params: SamplingParams
-    system_message: Optional[SystemMessage] = None
+    system_message: SystemMessage | None = None


 @json_schema_type
@@ -43,7 +42,7 @@ class AgentCandidate(BaseModel):
     config: AgentConfig


-EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
+EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
 register_schema(EvalCandidate, name="EvalCandidate")


@@ -57,11 +56,11 @@ class BenchmarkConfig(BaseModel):
     """

     eval_candidate: EvalCandidate
-    scoring_params: Dict[str, ScoringFnParams] = Field(
+    scoring_params: dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
-    num_examples: Optional[int] = Field(
+    num_examples: int | None = Field(
         description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
         default=None,
     )
@@ -76,9 +75,9 @@ class EvaluateResponse(BaseModel):
     :param scores: The scores from the evaluation.
     """

-    generations: List[Dict[str, Any]]
+    generations: list[dict[str, Any]]
     # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
+    scores: dict[str, ScoringResult]


 class Eval(Protocol):
@@ -101,8 +100,8 @@ class Eval(Protocol):
     async def evaluate_rows(
         self,
         benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         """Evaluate a list of rows on a benchmark.
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -42,7 +42,7 @@ class ListBucketResponse(BaseModel):
     :param data: List of FileResponse entries
     """

-    data: List[BucketResponse]
+    data: list[BucketResponse]


 @json_schema_type
@@ -74,7 +74,7 @@ class ListFileResponse(BaseModel):
     :param data: List of FileResponse entries
     """

-    data: List[FileResponse]
+    data: list[FileResponse]


 @runtime_checkable
@@ -102,7 +102,7 @@ class Files(Protocol):
     async def upload_content_to_session(
         self,
         upload_id: str,
-    ) -> Optional[FileResponse]:
+    ) -> FileResponse | None:
         """
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.
@@ -4,21 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    AsyncIterator,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field, field_validator
-from typing_extensions import Annotated, TypedDict
+from typing_extensions import TypedDict

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
@@ -47,8 +44,8 @@ class GreedySamplingStrategy(BaseModel):
 @json_schema_type
 class TopPSamplingStrategy(BaseModel):
     type: Literal["top_p"] = "top_p"
-    temperature: Optional[float] = Field(..., gt=0.0)
-    top_p: Optional[float] = 0.95
+    temperature: float | None = Field(..., gt=0.0)
+    top_p: float | None = 0.95


 @json_schema_type
@@ -58,7 +55,7 @@ class TopKSamplingStrategy(BaseModel):


 SamplingStrategy = Annotated[
-    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
     Field(discriminator="type"),
 ]
 register_schema(SamplingStrategy, name="SamplingStrategy")
@@ -79,9 +76,9 @@ class SamplingParams(BaseModel):

     strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)

-    max_tokens: Optional[int] = 0
-    repetition_penalty: Optional[float] = 1.0
-    stop: Optional[List[str]] = None
+    max_tokens: int | None = 0
+    repetition_penalty: float | None = 1.0
+    stop: list[str] | None = None


 class LogProbConfig(BaseModel):
@@ -90,7 +87,7 @@ class LogProbConfig(BaseModel):
     :param top_k: How many tokens (for each position) to return log probabilities for.
     """

-    top_k: Optional[int] = 0
+    top_k: int | None = 0


 class QuantizationType(Enum):
@@ -125,11 +122,11 @@ class Int4QuantizationConfig(BaseModel):
     """

     type: Literal["int4_mixed"] = "int4_mixed"
-    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
+    scheme: str | None = "int4_weight_int8_dynamic_activation"


 QuantizationConfig = Annotated[
-    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
+    Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
     Field(discriminator="type"),
 ]

@@ -145,7 +142,7 @@ class UserMessage(BaseModel):

     role: Literal["user"] = "user"
     content: InterleavedContent
-    context: Optional[InterleavedContent] = None
+    context: InterleavedContent | None = None


 @json_schema_type
@@ -190,16 +187,11 @@ class CompletionMessage(BaseModel):
     role: Literal["assistant"] = "assistant"
     content: InterleavedContent
     stop_reason: StopReason
-    tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
+    tool_calls: list[ToolCall] | None = Field(default_factory=list)


 Message = Annotated[
-    Union[
-        UserMessage,
-        SystemMessage,
-        ToolResponseMessage,
-        CompletionMessage,
-    ],
+    UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
     Field(discriminator="role"),
 ]
 register_schema(Message, name="Message")
@@ -208,9 +200,9 @@ register_schema(Message, name="Message")
 @json_schema_type
 class ToolResponse(BaseModel):
     call_id: str
-    tool_name: Union[BuiltinTool, str]
+    tool_name: BuiltinTool | str
     content: InterleavedContent
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: dict[str, Any] | None = None

     @field_validator("tool_name", mode="before")
     @classmethod
@@ -243,7 +235,7 @@ class TokenLogProbs(BaseModel):
     :param logprobs_by_token: Dictionary mapping tokens to their log probabilities
     """

-    logprobs_by_token: Dict[str, float]
+    logprobs_by_token: dict[str, float]


 class ChatCompletionResponseEventType(Enum):
@@ -271,8 +263,8 @@ class ChatCompletionResponseEvent(BaseModel):

     event_type: ChatCompletionResponseEventType
     delta: ContentDelta
-    logprobs: Optional[List[TokenLogProbs]] = None
-    stop_reason: Optional[StopReason] = None
+    logprobs: list[TokenLogProbs] | None = None
+    stop_reason: StopReason | None = None


 class ResponseFormatType(Enum):
@@ -295,7 +287,7 @@ class JsonSchemaResponseFormat(BaseModel):
     """

     type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
-    json_schema: Dict[str, Any]
+    json_schema: dict[str, Any]


 @json_schema_type
@@ -307,11 +299,11 @@ class GrammarResponseFormat(BaseModel):
     """

     type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
-    bnf: Dict[str, Any]
+    bnf: dict[str, Any]


 ResponseFormat = Annotated[
-    Union[JsonSchemaResponseFormat, GrammarResponseFormat],
+    JsonSchemaResponseFormat | GrammarResponseFormat,
     Field(discriminator="type"),
 ]
 register_schema(ResponseFormat, name="ResponseFormat")
@@ -321,10 +313,10 @@ register_schema(ResponseFormat, name="ResponseFormat")
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedContent
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None


 @json_schema_type
@@ -338,7 +330,7 @@ class CompletionResponse(MetricResponseMixin):

     content: str
     stop_reason: StopReason
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None


 @json_schema_type
@@ -351,8 +343,8 @@ class CompletionResponseStreamChunk(MetricResponseMixin):
     """

     delta: str
-    stop_reason: Optional[StopReason] = None
-    logprobs: Optional[List[TokenLogProbs]] = None
+    stop_reason: StopReason | None = None
+    logprobs: list[TokenLogProbs] | None = None


 class SystemMessageBehavior(Enum):
@@ -383,9 +375,9 @@ class ToolConfig(BaseModel):
         '{{function_definitions}}' to indicate where the function definitions should be inserted.
     """

-    tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
-    system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
+    tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None)
+    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)

     def model_post_init(self, __context: Any) -> None:
         if isinstance(self.tool_choice, str):
@@ -399,15 +391,15 @@ class ToolConfig(BaseModel):
 @json_schema_type
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Message]
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
+    messages: list[Message]
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
+    tools: list[ToolDefinition] | None = Field(default_factory=list)
+    tool_config: ToolConfig | None = Field(default_factory=ToolConfig)

-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None


 @json_schema_type
@@ -429,7 +421,7 @@ class ChatCompletionResponse(MetricResponseMixin):
     """

     completion_message: CompletionMessage
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None


 @json_schema_type
@@ -439,7 +431,7 @@ class EmbeddingsResponse(BaseModel):
     :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
     """

-    embeddings: List[List[float]]
+    embeddings: list[list[float]]


 @json_schema_type
@@ -451,7 +443,7 @@ class OpenAIChatCompletionContentPartTextParam(BaseModel):
 @json_schema_type
 class OpenAIImageURL(BaseModel):
     url: str
-    detail: Optional[str] = None
+    detail: str | None = None


 @json_schema_type
@@ -461,16 +453,13 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):


 OpenAIChatCompletionContentPartParam = Annotated[
-    Union[
-        OpenAIChatCompletionContentPartTextParam,
-        OpenAIChatCompletionContentPartImageParam,
-    ],
+    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")


-OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
+OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]


 @json_schema_type
@@ -484,7 +473,7 @@ class OpenAIUserMessageParam(BaseModel):

     role: Literal["user"] = "user"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None


 @json_schema_type
@@ -498,21 +487,21 @@ class OpenAISystemMessageParam(BaseModel):

     role: Literal["system"] = "system"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None


 @json_schema_type
 class OpenAIChatCompletionToolCallFunction(BaseModel):
-    name: Optional[str] = None
-    arguments: Optional[str] = None
+    name: str | None = None
+    arguments: str | None = None


 @json_schema_type
 class OpenAIChatCompletionToolCall(BaseModel):
-    index: Optional[int] = None
-    id: Optional[str] = None
+    index: int | None = None
+    id: str | None = None
     type: Literal["function"] = "function"
-    function: Optional[OpenAIChatCompletionToolCallFunction] = None
+    function: OpenAIChatCompletionToolCallFunction | None = None


 @json_schema_type
@@ -526,9 +515,9 @@ class OpenAIAssistantMessageParam(BaseModel):
     """

     role: Literal["assistant"] = "assistant"
-    content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
+    content: OpenAIChatCompletionMessageContent | None = None
+    name: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None


 @json_schema_type
@@ -556,17 +545,15 @@ class OpenAIDeveloperMessageParam(BaseModel):

     role: Literal["developer"] = "developer"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None


 OpenAIMessageParam = Annotated[
-    Union[
-        OpenAIUserMessageParam,
-        OpenAISystemMessageParam,
-        OpenAIAssistantMessageParam,
-        OpenAIToolMessageParam,
-        OpenAIDeveloperMessageParam,
-    ],
+    OpenAIUserMessageParam
+    | OpenAISystemMessageParam
+    | OpenAIAssistantMessageParam
+    | OpenAIToolMessageParam
+    | OpenAIDeveloperMessageParam,
     Field(discriminator="role"),
 ]
 register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@@ -580,14 +567,14 @@ class OpenAIResponseFormatText(BaseModel):
 @json_schema_type
 class OpenAIJSONSchema(TypedDict, total=False):
     name: str
-    description: Optional[str] = None
-    strict: Optional[bool] = None
+    description: str | None = None
+    strict: bool | None = None

     # Pydantic BaseModel cannot be used with a schema param, since it already
     # has one. And, we don't want to alias here because then have to handle
     # that alias when converting to OpenAI params. So, to support schema,
     # we use a TypedDict.
-    schema: Optional[Dict[str, Any]] = None
+    schema: dict[str, Any] | None = None


 @json_schema_type
@@ -602,11 +589,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):


 OpenAIResponseFormatParam = Annotated[
-    Union[
-        OpenAIResponseFormatText,
-        OpenAIResponseFormatJSONSchema,
-        OpenAIResponseFormatJSONObject,
-    ],
+    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -622,7 +605,7 @@ class OpenAITopLogProb(BaseModel):
     """

     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float


@@ -637,9 +620,9 @@ class OpenAITokenLogProb(BaseModel):
     """

     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float
-    top_logprobs: List[OpenAITopLogProb]
+    top_logprobs: list[OpenAITopLogProb]


 @json_schema_type
@@ -650,8 +633,8 @@ class OpenAIChoiceLogprobs(BaseModel):
     :param refusal: (Optional) The log probabilities for the tokens in the message
     """

-    content: Optional[List[OpenAITokenLogProb]] = None
-    refusal: Optional[List[OpenAITokenLogProb]] = None
+    content: list[OpenAITokenLogProb] | None = None
+    refusal: list[OpenAITokenLogProb] | None = None


 @json_schema_type
@@ -664,10 +647,10 @@ class OpenAIChoiceDelta(BaseModel):
     :param tool_calls: (Optional) The tool calls of the delta
     """

-    content: Optional[str] = None
-    refusal: Optional[str] = None
-    role: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+    content: str | None = None
+    refusal: str | None = None
+    role: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None


 @json_schema_type
@@ -683,7 +666,7 @@ class OpenAIChunkChoice(BaseModel):
     delta: OpenAIChoiceDelta
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None


 @json_schema_type
@@ -699,7 +682,7 @@ class OpenAIChoice(BaseModel):
     message: OpenAIMessageParam
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None


 @json_schema_type
@@ -714,7 +697,7 @@ class OpenAIChatCompletion(BaseModel):
     """

     id: str
-    choices: List[OpenAIChoice]
+    choices: list[OpenAIChoice]
     object: Literal["chat.completion"] = "chat.completion"
     created: int
     model: str
@@ -732,7 +715,7 @@ class OpenAIChatCompletionChunk(BaseModel):
     """

     id: str
-    choices: List[OpenAIChunkChoice]
+    choices: list[OpenAIChunkChoice]
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int
     model: str
@@ -748,10 +731,10 @@ class OpenAICompletionLogprobs(BaseModel):
     :top_logprobs: (Optional) The top log probabilities for the tokens
     """

-    text_offset: Optional[List[int]] = None
-    token_logprobs: Optional[List[float]] = None
-    tokens: Optional[List[str]] = None
-    top_logprobs: Optional[List[Dict[str, float]]] = None
+    text_offset: list[int] | None = None
+    token_logprobs: list[float] | None = None
+    tokens: list[str] | None = None
+    top_logprobs: list[dict[str, float]] | None = None


 @json_schema_type
@@ -767,7 +750,7 @@ class OpenAICompletionChoice(BaseModel):
     finish_reason: str
     text: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None


 @json_schema_type
@@ -782,7 +765,7 @@ class OpenAICompletion(BaseModel):
     """

     id: str
-    choices: List[OpenAICompletionChoice]
+    choices: list[OpenAICompletionChoice]
     created: int
     model: str
     object: Literal["text_completion"] = "text_completion"
|
@ -818,12 +801,12 @@ class EmbeddingTaskType(Enum):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BatchCompletionResponse(BaseModel):
|
class BatchCompletionResponse(BaseModel):
|
||||||
batch: List[CompletionResponse]
|
batch: list[CompletionResponse]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BatchChatCompletionResponse(BaseModel):
|
class BatchChatCompletionResponse(BaseModel):
|
||||||
batch: List[ChatCompletionResponse]
|
batch: list[ChatCompletionResponse]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -843,11 +826,11 @@ class Inference(Protocol):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
content: InterleavedContent,
|
content: InterleavedContent,
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
stream: Optional[bool] = False,
|
stream: bool | None = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||||
"""Generate a completion for the given content using the specified model.
|
"""Generate a completion for the given content using the specified model.
|
||||||
|
|
||||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||||
|
@ -865,10 +848,10 @@ class Inference(Protocol):
|
||||||
async def batch_completion(
|
async def batch_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
content_batch: List[InterleavedContent],
|
content_batch: list[InterleavedContent],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
) -> BatchCompletionResponse:
|
) -> BatchCompletionResponse:
|
||||||
raise NotImplementedError("Batch completion is not implemented")
|
raise NotImplementedError("Batch completion is not implemented")
|
||||||
|
|
||||||
|
@ -876,16 +859,16 @@ class Inference(Protocol):
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: list[Message],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
tool_prompt_format: ToolPromptFormat | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
stream: Optional[bool] = False,
|
stream: bool | None = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: ToolConfig | None = None,
|
||||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||||
"""Generate a chat completion for the given messages using the specified model.
|
"""Generate a chat completion for the given messages using the specified model.
|
||||||
|
|
||||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||||
|
@ -916,12 +899,12 @@ class Inference(Protocol):
|
||||||
async def batch_chat_completion(
|
async def batch_chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
messages_batch: List[List[Message]],
|
messages_batch: list[list[Message]],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: ToolConfig | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
) -> BatchChatCompletionResponse:
|
) -> BatchChatCompletionResponse:
|
||||||
raise NotImplementedError("Batch chat completion is not implemented")
|
raise NotImplementedError("Batch chat completion is not implemented")
|
||||||
|
|
||||||
|
@ -929,10 +912,10 @@ class Inference(Protocol):
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: list[str] | list[InterleavedContentItem],
|
||||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||||
output_dimension: Optional[int] = None,
|
output_dimension: int | None = None,
|
||||||
task_type: Optional[EmbeddingTaskType] = None,
|
task_type: EmbeddingTaskType | None = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
"""Generate embeddings for content pieces using the specified model.
|
"""Generate embeddings for content pieces using the specified model.
|
||||||
|
|
||||||
|
@ -950,25 +933,25 @@ class Inference(Protocol):
|
||||||
self,
|
self,
|
||||||
# Standard OpenAI completion parameters
|
# Standard OpenAI completion parameters
|
||||||
model: str,
|
model: str,
|
||||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
prompt: str | list[str] | list[int] | list[list[int]],
|
||||||
best_of: Optional[int] = None,
|
best_of: int | None = None,
|
||||||
echo: Optional[bool] = None,
|
echo: bool | None = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: float | None = None,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
logit_bias: dict[str, float] | None = None,
|
||||||
logprobs: Optional[bool] = None,
|
logprobs: bool | None = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: int | None = None,
|
||||||
n: Optional[int] = None,
|
n: int | None = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: float | None = None,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: str | list[str] | None = None,
|
||||||
stream: Optional[bool] = None,
|
stream: bool | None = None,
|
||||||
stream_options: Optional[Dict[str, Any]] = None,
|
stream_options: dict[str, Any] | None = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: float | None = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: float | None = None,
|
||||||
user: Optional[str] = None,
|
user: str | None = None,
|
||||||
# vLLM-specific parameters
|
# vLLM-specific parameters
|
||||||
guided_choice: Optional[List[str]] = None,
|
guided_choice: list[str] | None = None,
|
||||||
prompt_logprobs: Optional[int] = None,
|
prompt_logprobs: int | None = None,
|
||||||
) -> OpenAICompletion:
|
) -> OpenAICompletion:
|
||||||
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||||
|
|
||||||
|
@ -996,29 +979,29 @@ class Inference(Protocol):
|
||||||
async def openai_chat_completion(
|
async def openai_chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[OpenAIMessageParam],
|
messages: list[OpenAIMessageParam],
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: float | None = None,
|
||||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
function_call: str | dict[str, Any] | None = None,
|
||||||
functions: Optional[List[Dict[str, Any]]] = None,
|
functions: list[dict[str, Any]] | None = None,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
logit_bias: dict[str, float] | None = None,
|
||||||
logprobs: Optional[bool] = None,
|
logprobs: bool | None = None,
|
||||||
max_completion_tokens: Optional[int] = None,
|
max_completion_tokens: int | None = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: int | None = None,
|
||||||
n: Optional[int] = None,
|
n: int | None = None,
|
||||||
parallel_tool_calls: Optional[bool] = None,
|
parallel_tool_calls: bool | None = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: float | None = None,
|
||||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
response_format: OpenAIResponseFormatParam | None = None,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: str | list[str] | None = None,
|
||||||
stream: Optional[bool] = None,
|
stream: bool | None = None,
|
||||||
stream_options: Optional[Dict[str, Any]] = None,
|
stream_options: dict[str, Any] | None = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: float | None = None,
|
||||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
top_logprobs: Optional[int] = None,
|
top_logprobs: int | None = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: float | None = None,
|
||||||
user: Optional[str] = None,
|
user: str | None = None,
|
||||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||||
|
|
||||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
class RouteInfo(BaseModel):
|
class RouteInfo(BaseModel):
|
||||||
route: str
|
route: str
|
||||||
method: str
|
method: str
|
||||||
provider_types: List[str]
|
provider_types: list[str]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -30,7 +30,7 @@ class VersionInfo(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ListRoutesResponse(BaseModel):
|
class ListRoutesResponse(BaseModel):
|
||||||
data: List[RouteInfo]
|
data: list[RouteInfo]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class CommonModelFields(BaseModel):
|
class CommonModelFields(BaseModel):
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional metadata for this model",
|
description="Any additional metadata for this model",
|
||||||
)
|
)
|
||||||
|
@ -46,14 +46,14 @@ class Model(CommonModelFields, Resource):
|
||||||
|
|
||||||
class ModelInput(CommonModelFields):
|
class ModelInput(CommonModelFields):
|
||||||
model_id: str
|
model_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: str | None = None
|
||||||
provider_model_id: Optional[str] = None
|
provider_model_id: str | None = None
|
||||||
model_type: Optional[ModelType] = ModelType.llm
|
model_type: ModelType | None = ModelType.llm
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class ListModelsResponse(BaseModel):
|
class ListModelsResponse(BaseModel):
|
||||||
data: List[Model]
|
data: list[Model]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -73,7 +73,7 @@ class OpenAIModel(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class OpenAIListModelsResponse(BaseModel):
|
class OpenAIListModelsResponse(BaseModel):
|
||||||
data: List[OpenAIModel]
|
data: list[OpenAIModel]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -95,10 +95,10 @@ class Models(Protocol):
|
||||||
async def register_model(
|
async def register_model(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: str | None = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: str | None = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: ModelType | None = None,
|
||||||
) -> Model: ...
|
) -> Model: ...
|
||||||
|
|
||||||
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
||||||
|
|
|
@@ -6,10 +6,9 @@

 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.job_types import JobStatus
@@ -36,9 +35,9 @@ class DataConfig(BaseModel):
     batch_size: int
     shuffle: bool
     data_format: DatasetFormat
-    validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
-    train_on_input: Optional[bool] = False
+    validation_dataset_id: str | None = None
+    packed: bool | None = False
+    train_on_input: bool | None = False


 @json_schema_type
@@ -51,10 +50,10 @@ class OptimizerConfig(BaseModel):

 @json_schema_type
 class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
+    enable_activation_checkpointing: bool | None = False
+    enable_activation_offloading: bool | None = False
+    memory_efficient_fsdp_wrap: bool | None = False
+    fsdp_cpu_offload: bool | None = False


 @json_schema_type
@@ -62,23 +61,23 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int = 1
     gradient_accumulation_steps: int = 1
-    max_validation_steps: Optional[int] = 1
-    data_config: Optional[DataConfig] = None
-    optimizer_config: Optional[OptimizerConfig] = None
-    efficiency_config: Optional[EfficiencyConfig] = None
-    dtype: Optional[str] = "bf16"
+    max_validation_steps: int | None = 1
+    data_config: DataConfig | None = None
+    optimizer_config: OptimizerConfig | None = None
+    efficiency_config: EfficiencyConfig | None = None
+    dtype: str | None = "bf16"


 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
+    lora_attn_modules: list[str]
     apply_lora_to_mlp: bool
     apply_lora_to_output: bool
     rank: int
     alpha: int
-    use_dora: Optional[bool] = False
-    quantize_base: Optional[bool] = False
+    use_dora: bool | None = False
+    quantize_base: bool | None = False


 @json_schema_type
@ -88,7 +87,7 @@ class QATFinetuningConfig(BaseModel):
|
||||||
group_size: int
|
group_size: int
|
||||||
|
|
||||||
|
|
||||||
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
|
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
|
||||||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,7 +96,7 @@ class PostTrainingJobLogStream(BaseModel):
|
||||||
"""Stream of logs from a finetuning job."""
|
"""Stream of logs from a finetuning job."""
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
log_lines: List[str]
|
log_lines: list[str]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -131,8 +130,8 @@ class PostTrainingRLHFRequest(BaseModel):
|
||||||
training_config: TrainingConfig
|
training_config: TrainingConfig
|
||||||
|
|
||||||
# TODO: define these
|
# TODO: define these
|
||||||
hyperparam_search_config: Dict[str, Any]
|
hyperparam_search_config: dict[str, Any]
|
||||||
logger_config: Dict[str, Any]
|
logger_config: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class PostTrainingJob(BaseModel):
|
class PostTrainingJob(BaseModel):
|
||||||
|
@ -146,17 +145,17 @@ class PostTrainingJobStatusResponse(BaseModel):
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
status: JobStatus
|
status: JobStatus
|
||||||
|
|
||||||
scheduled_at: Optional[datetime] = None
|
scheduled_at: datetime | None = None
|
||||||
started_at: Optional[datetime] = None
|
started_at: datetime | None = None
|
||||||
completed_at: Optional[datetime] = None
|
completed_at: datetime | None = None
|
||||||
|
|
||||||
resources_allocated: Optional[Dict[str, Any]] = None
|
resources_allocated: dict[str, Any] | None = None
|
||||||
|
|
||||||
checkpoints: List[Checkpoint] = Field(default_factory=list)
|
checkpoints: list[Checkpoint] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ListPostTrainingJobsResponse(BaseModel):
|
class ListPostTrainingJobsResponse(BaseModel):
|
||||||
data: List[PostTrainingJob]
|
data: list[PostTrainingJob]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -164,7 +163,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
||||||
"""Artifacts of a finetuning job."""
|
"""Artifacts of a finetuning job."""
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
checkpoints: List[Checkpoint] = Field(default_factory=list)
|
checkpoints: list[Checkpoint] = Field(default_factory=list)
|
||||||
|
|
||||||
# TODO(ashwin): metrics, evals
|
# TODO(ashwin): metrics, evals
|
||||||
|
|
||||||
|
@ -175,14 +174,14 @@ class PostTraining(Protocol):
|
||||||
self,
|
self,
|
||||||
job_uuid: str,
|
job_uuid: str,
|
||||||
training_config: TrainingConfig,
|
training_config: TrainingConfig,
|
||||||
hyperparam_search_config: Dict[str, Any],
|
hyperparam_search_config: dict[str, Any],
|
||||||
logger_config: Dict[str, Any],
|
logger_config: dict[str, Any],
|
||||||
model: Optional[str] = Field(
|
model: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Model descriptor for training if not in provider config`",
|
description="Model descriptor for training if not in provider config`",
|
||||||
),
|
),
|
||||||
checkpoint_dir: Optional[str] = None,
|
checkpoint_dir: str | None = None,
|
||||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
algorithm_config: AlgorithmConfig | None = None,
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||||
|
@ -192,8 +191,8 @@ class PostTraining(Protocol):
|
||||||
finetuned_model: str,
|
finetuned_model: str,
|
||||||
algorithm_config: DPOAlignmentConfig,
|
algorithm_config: DPOAlignmentConfig,
|
||||||
training_config: TrainingConfig,
|
training_config: TrainingConfig,
|
||||||
hyperparam_search_config: Dict[str, Any],
|
hyperparam_search_config: dict[str, Any],
|
||||||
logger_config: Dict[str, Any],
|
logger_config: dict[str, Any],
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/jobs", method="GET")
|
@webmethod(route="/post-training/jobs", method="GET")
|
||||||
|
|
|
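The typing changes in this and the following files are mechanical: PEP 604 unions and PEP 585 built-in generics (Python 3.10+) replace typing.Optional, Union, Dict and List. A minimal sketch of the equivalence; the ExampleConfig model below is illustrative and not part of this commit:

from typing import Any

from pydantic import BaseModel, Field


class ExampleConfig(BaseModel):
    # Optional[str]  -> str | None
    # Dict[str, Any] -> dict[str, Any]
    # List[str]      -> list[str]
    # Union[A, B]    -> A | B
    provider_id: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    tags: list[str] = Field(default_factory=list)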
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -17,12 +17,12 @@ class ProviderInfo(BaseModel):
|
||||||
api: str
|
api: str
|
||||||
provider_id: str
|
provider_id: str
|
||||||
provider_type: str
|
provider_type: str
|
||||||
config: Dict[str, Any]
|
config: dict[str, Any]
|
||||||
health: HealthResponse
|
health: HealthResponse
|
||||||
|
|
||||||
|
|
||||||
class ListProvidersResponse(BaseModel):
|
class ListProvidersResponse(BaseModel):
|
||||||
data: List[ProviderInfo]
|
data: list[ProviderInfo]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
@ -27,16 +27,16 @@ class SafetyViolation(BaseModel):
|
||||||
violation_level: ViolationLevel
|
violation_level: ViolationLevel
|
||||||
|
|
||||||
# what message should you convey to the user
|
# what message should you convey to the user
|
||||||
user_message: Optional[str] = None
|
user_message: str | None = None
|
||||||
|
|
||||||
# additional metadata (including specific violation codes) more for
|
# additional metadata (including specific violation codes) more for
|
||||||
# debugging, telemetry
|
# debugging, telemetry
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RunShieldResponse(BaseModel):
|
class RunShieldResponse(BaseModel):
|
||||||
violation: Optional[SafetyViolation] = None
|
violation: SafetyViolation | None = None
|
||||||
|
|
||||||
|
|
||||||
class ShieldStore(Protocol):
|
class ShieldStore(Protocol):
|
||||||
|
@ -52,6 +52,6 @@ class Safety(Protocol):
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
messages: List[Message],
|
messages: list[Message],
|
||||||
params: Dict[str, Any] = None,
|
params: dict[str, Any] = None,
|
||||||
) -> RunShieldResponse: ...
|
) -> RunShieldResponse: ...
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
# mapping of metric to value
|
# mapping of metric to value
|
||||||
ScoringResultRow = Dict[str, Any]
|
ScoringResultRow = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -24,15 +24,15 @@ class ScoringResult(BaseModel):
|
||||||
:param aggregated_results: Map of metric name to aggregated value
|
:param aggregated_results: Map of metric name to aggregated value
|
||||||
"""
|
"""
|
||||||
|
|
||||||
score_rows: List[ScoringResultRow]
|
score_rows: list[ScoringResultRow]
|
||||||
# aggregated metrics to value
|
# aggregated metrics to value
|
||||||
aggregated_results: Dict[str, Any]
|
aggregated_results: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ScoreBatchResponse(BaseModel):
|
class ScoreBatchResponse(BaseModel):
|
||||||
dataset_id: Optional[str] = None
|
dataset_id: str | None = None
|
||||||
results: Dict[str, ScoringResult]
|
results: dict[str, ScoringResult]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -44,7 +44,7 @@ class ScoreResponse(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# each key in the dict is a scoring function name
|
# each key in the dict is a scoring function name
|
||||||
results: Dict[str, ScoringResult]
|
results: dict[str, ScoringResult]
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionStore(Protocol):
|
class ScoringFunctionStore(Protocol):
|
||||||
|
@ -59,15 +59,15 @@ class Scoring(Protocol):
|
||||||
async def score_batch(
|
async def score_batch(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
scoring_functions: dict[str, ScoringFnParams | None],
|
||||||
save_results_dataset: bool = False,
|
save_results_dataset: bool = False,
|
||||||
) -> ScoreBatchResponse: ...
|
) -> ScoreBatchResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/scoring/score", method="POST")
|
@webmethod(route="/scoring/score", method="POST")
|
||||||
async def score(
|
async def score(
|
||||||
self,
|
self,
|
||||||
input_rows: List[Dict[str, Any]],
|
input_rows: list[dict[str, Any]],
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
scoring_functions: dict[str, ScoringFnParams | None],
|
||||||
) -> ScoreResponse:
|
) -> ScoreResponse:
|
||||||
"""Score a list of rows.
|
"""Score a list of rows.
|
||||||
|
|
||||||
|
|
|
@@ -6,18 +6,14 @@

 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType
|
@ -46,12 +42,12 @@ class AggregationFunctionType(Enum):
|
||||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
|
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
|
||||||
judge_model: str
|
judge_model: str
|
||||||
prompt_template: Optional[str] = None
|
prompt_template: str | None = None
|
||||||
judge_score_regexes: Optional[List[str]] = Field(
|
judge_score_regexes: list[str] | None = Field(
|
||||||
description="Regexes to extract the answer from generated response",
|
description="Regexes to extract the answer from generated response",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||||
description="Aggregation functions to apply to the scores of each row",
|
description="Aggregation functions to apply to the scores of each row",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
@ -60,11 +56,11 @@ class LLMAsJudgeScoringFnParams(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RegexParserScoringFnParams(BaseModel):
|
class RegexParserScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
|
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
|
||||||
parsing_regexes: Optional[List[str]] = Field(
|
parsing_regexes: list[str] | None = Field(
|
||||||
description="Regex to extract the answer from generated response",
|
description="Regex to extract the answer from generated response",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||||
description="Aggregation functions to apply to the scores of each row",
|
description="Aggregation functions to apply to the scores of each row",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
@ -73,33 +69,29 @@ class RegexParserScoringFnParams(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BasicScoringFnParams(BaseModel):
|
class BasicScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
|
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
|
||||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||||
description="Aggregation functions to apply to the scores of each row",
|
description="Aggregation functions to apply to the scores of each row",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
 ScoringFnParams = Annotated[
-    Union[
-        LLMAsJudgeScoringFnParams,
-        RegexParserScoringFnParams,
-        BasicScoringFnParams,
-    ],
+    LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
     Field(discriminator="type"),
 ]
 register_schema(ScoringFnParams, name="ScoringFnParams")
|
|
||||||
|
|
||||||
class CommonScoringFnFields(BaseModel):
|
class CommonScoringFnFields(BaseModel):
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional metadata for this definition",
|
description="Any additional metadata for this definition",
|
||||||
)
|
)
|
||||||
return_type: ParamType = Field(
|
return_type: ParamType = Field(
|
||||||
description="The return type of the deterministic function",
|
description="The return type of the deterministic function",
|
||||||
)
|
)
|
||||||
params: Optional[ScoringFnParams] = Field(
|
params: ScoringFnParams | None = Field(
|
||||||
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
@ -120,12 +112,12 @@ class ScoringFn(CommonScoringFnFields, Resource):
|
||||||
|
|
||||||
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
||||||
scoring_fn_id: str
|
scoring_fn_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: str | None = None
|
||||||
provider_scoring_fn_id: Optional[str] = None
|
provider_scoring_fn_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListScoringFunctionsResponse(BaseModel):
|
class ListScoringFunctionsResponse(BaseModel):
|
||||||
data: List[ScoringFn]
|
data: list[ScoringFn]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -142,7 +134,7 @@ class ScoringFunctions(Protocol):
|
||||||
scoring_fn_id: str,
|
scoring_fn_id: str,
|
||||||
description: str,
|
description: str,
|
||||||
return_type: ParamType,
|
return_type: ParamType,
|
||||||
provider_scoring_fn_id: Optional[str] = None,
|
provider_scoring_fn_id: str | None = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: str | None = None,
|
||||||
params: Optional[ScoringFnParams] = None,
|
params: ScoringFnParams | None = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
|
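The registered unions above now use the pipe syntax directly inside Annotated. A small self-contained sketch of the same discriminated-union pattern, assuming pydantic v2; the two param classes here are simplified stand-ins for the real ScoringFnParams members:

from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class LLMAsJudgeParams(BaseModel):
    type: Literal["llm_as_judge"] = "llm_as_judge"
    judge_model: str


class RegexParserParams(BaseModel):
    type: Literal["regex_parser"] = "regex_parser"
    parsing_regexes: list[str] | None = None


# A | B inside Annotated behaves like Union[A, B] for the discriminator
Params = Annotated[LLMAsJudgeParams | RegexParserParams, Field(discriminator="type")]

parsed = TypeAdapter(Params).validate_python({"type": "llm_as_judge", "judge_model": "judge"})
print(type(parsed).__name__)  # LLMAsJudgeParams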
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class CommonShieldFields(BaseModel):
|
class CommonShieldFields(BaseModel):
|
||||||
params: Optional[Dict[str, Any]] = None
|
params: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -34,12 +34,12 @@ class Shield(CommonShieldFields, Resource):
|
||||||
|
|
||||||
class ShieldInput(CommonShieldFields):
|
class ShieldInput(CommonShieldFields):
|
||||||
shield_id: str
|
shield_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: str | None = None
|
||||||
provider_shield_id: Optional[str] = None
|
provider_shield_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListShieldsResponse(BaseModel):
|
class ListShieldsResponse(BaseModel):
|
||||||
data: List[Shield]
|
data: list[Shield]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -55,7 +55,7 @@ class Shields(Protocol):
|
||||||
async def register_shield(
|
async def register_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
provider_shield_id: Optional[str] = None,
|
provider_shield_id: str | None = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: str | None = None,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: dict[str, Any] | None = None,
|
||||||
) -> Shield: ...
|
) -> Shield: ...
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -28,24 +28,24 @@ class FilteringFunction(Enum):
|
||||||
class SyntheticDataGenerationRequest(BaseModel):
|
class SyntheticDataGenerationRequest(BaseModel):
|
||||||
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
|
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
|
||||||
|
|
||||||
dialogs: List[Message]
|
dialogs: list[Message]
|
||||||
filtering_function: FilteringFunction = FilteringFunction.none
|
filtering_function: FilteringFunction = FilteringFunction.none
|
||||||
model: Optional[str] = None
|
model: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class SyntheticDataGenerationResponse(BaseModel):
|
class SyntheticDataGenerationResponse(BaseModel):
|
||||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||||
|
|
||||||
synthetic_data: List[Dict[str, Any]]
|
synthetic_data: list[dict[str, Any]]
|
||||||
statistics: Optional[Dict[str, Any]] = None
|
statistics: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class SyntheticDataGeneration(Protocol):
|
class SyntheticDataGeneration(Protocol):
|
||||||
@webmethod(route="/synthetic-data-generation/generate")
|
@webmethod(route="/synthetic-data-generation/generate")
|
||||||
def synthetic_data_generate(
|
def synthetic_data_generate(
|
||||||
self,
|
self,
|
||||||
dialogs: List[Message],
|
dialogs: list[Message],
|
||||||
filtering_function: FilteringFunction = FilteringFunction.none,
|
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||||
model: Optional[str] = None,
|
model: str | None = None,
|
||||||
) -> Union[SyntheticDataGenerationResponse]: ...
|
) -> SyntheticDataGenerationResponse: ...
|
||||||
|
|
|
@@ -7,18 +7,14 @@
 from datetime import datetime
 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.models.llama.datatypes import Primitive
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
@ -37,11 +33,11 @@ class SpanStatus(Enum):
|
||||||
class Span(BaseModel):
|
class Span(BaseModel):
|
||||||
span_id: str
|
span_id: str
|
||||||
trace_id: str
|
trace_id: str
|
||||||
parent_span_id: Optional[str] = None
|
parent_span_id: str | None = None
|
||||||
name: str
|
name: str
|
||||||
start_time: datetime
|
start_time: datetime
|
||||||
end_time: Optional[datetime] = None
|
end_time: datetime | None = None
|
||||||
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
attributes: dict[str, Any] | None = Field(default_factory=dict)
|
||||||
|
|
||||||
def set_attribute(self, key: str, value: Any):
|
def set_attribute(self, key: str, value: Any):
|
||||||
if self.attributes is None:
|
if self.attributes is None:
|
||||||
|
@ -54,7 +50,7 @@ class Trace(BaseModel):
|
||||||
trace_id: str
|
trace_id: str
|
||||||
root_span_id: str
|
root_span_id: str
|
||||||
start_time: datetime
|
start_time: datetime
|
||||||
end_time: Optional[datetime] = None
|
end_time: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -78,7 +74,7 @@ class EventCommon(BaseModel):
|
||||||
trace_id: str
|
trace_id: str
|
||||||
span_id: str
|
span_id: str
|
||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict)
|
attributes: dict[str, Primitive] | None = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -92,15 +88,15 @@ class UnstructuredLogEvent(EventCommon):
|
||||||
class MetricEvent(EventCommon):
|
class MetricEvent(EventCommon):
|
||||||
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
|
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
|
||||||
metric: str # this would be an enum
|
metric: str # this would be an enum
|
||||||
value: Union[int, float]
|
value: int | float
|
||||||
unit: str
|
unit: str
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MetricInResponse(BaseModel):
|
class MetricInResponse(BaseModel):
|
||||||
metric: str
|
metric: str
|
||||||
value: Union[int, float]
|
value: int | float
|
||||||
unit: Optional[str] = None
|
unit: str | None = None
|
||||||
|
|
||||||
|
|
||||||
# This is a short term solution to allow inference API to return metrics
|
# This is a short term solution to allow inference API to return metrics
|
||||||
|
@ -124,7 +120,7 @@ class MetricInResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class MetricResponseMixin(BaseModel):
|
class MetricResponseMixin(BaseModel):
|
||||||
metrics: Optional[List[MetricInResponse]] = None
|
metrics: list[MetricInResponse] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -137,7 +133,7 @@ class StructuredLogType(Enum):
|
||||||
class SpanStartPayload(BaseModel):
|
class SpanStartPayload(BaseModel):
|
||||||
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
|
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
|
||||||
name: str
|
name: str
|
||||||
parent_span_id: Optional[str] = None
|
parent_span_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -147,10 +143,7 @@ class SpanEndPayload(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
 StructuredLogPayload = Annotated[
-    Union[
-        SpanStartPayload,
-        SpanEndPayload,
-    ],
+    SpanStartPayload | SpanEndPayload,
     Field(discriminator="type"),
 ]
 register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
@ -163,11 +156,7 @@ class StructuredLogEvent(EventCommon):
|
||||||
|
|
||||||
|
|
||||||
 Event = Annotated[
-    Union[
-        UnstructuredLogEvent,
-        MetricEvent,
-        StructuredLogEvent,
-    ],
+    UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
     Field(discriminator="type"),
 ]
 register_schema(Event, name="Event")
|
@ -184,7 +173,7 @@ class EvalTrace(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class SpanWithStatus(Span):
|
class SpanWithStatus(Span):
|
||||||
status: Optional[SpanStatus] = None
|
status: SpanStatus | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -203,15 +192,15 @@ class QueryCondition(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class QueryTracesResponse(BaseModel):
|
class QueryTracesResponse(BaseModel):
|
||||||
data: List[Trace]
|
data: list[Trace]
|
||||||
|
|
||||||
|
|
||||||
class QuerySpansResponse(BaseModel):
|
class QuerySpansResponse(BaseModel):
|
||||||
data: List[Span]
|
data: list[Span]
|
||||||
|
|
||||||
|
|
||||||
class QuerySpanTreeResponse(BaseModel):
|
class QuerySpanTreeResponse(BaseModel):
|
||||||
data: Dict[str, SpanWithStatus]
|
data: dict[str, SpanWithStatus]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -222,10 +211,10 @@ class Telemetry(Protocol):
|
||||||
@webmethod(route="/telemetry/traces", method="POST")
|
@webmethod(route="/telemetry/traces", method="POST")
|
||||||
async def query_traces(
|
async def query_traces(
|
||||||
self,
|
self,
|
||||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
attribute_filters: list[QueryCondition] | None = None,
|
||||||
limit: Optional[int] = 100,
|
limit: int | None = 100,
|
||||||
offset: Optional[int] = 0,
|
offset: int | None = 0,
|
||||||
order_by: Optional[List[str]] = None,
|
order_by: list[str] | None = None,
|
||||||
) -> QueryTracesResponse: ...
|
) -> QueryTracesResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
||||||
|
@ -238,23 +227,23 @@ class Telemetry(Protocol):
|
||||||
async def get_span_tree(
|
async def get_span_tree(
|
||||||
self,
|
self,
|
||||||
span_id: str,
|
span_id: str,
|
||||||
attributes_to_return: Optional[List[str]] = None,
|
attributes_to_return: list[str] | None = None,
|
||||||
max_depth: Optional[int] = None,
|
max_depth: int | None = None,
|
||||||
) -> QuerySpanTreeResponse: ...
|
) -> QuerySpanTreeResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/spans", method="POST")
|
@webmethod(route="/telemetry/spans", method="POST")
|
||||||
async def query_spans(
|
async def query_spans(
|
||||||
self,
|
self,
|
||||||
attribute_filters: List[QueryCondition],
|
attribute_filters: list[QueryCondition],
|
||||||
attributes_to_return: List[str],
|
attributes_to_return: list[str],
|
||||||
max_depth: Optional[int] = None,
|
max_depth: int | None = None,
|
||||||
) -> QuerySpansResponse: ...
|
) -> QuerySpansResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/spans/export", method="POST")
|
@webmethod(route="/telemetry/spans/export", method="POST")
|
||||||
async def save_spans_to_dataset(
|
async def save_spans_to_dataset(
|
||||||
self,
|
self,
|
||||||
attribute_filters: List[QueryCondition],
|
attribute_filters: list[QueryCondition],
|
||||||
attributes_to_save: List[str],
|
attributes_to_save: list[str],
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
max_depth: Optional[int] = None,
|
max_depth: int | None = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
|
@ -5,10 +5,10 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated, Protocol, runtime_checkable
|
from typing_extensions import Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
@ -29,13 +29,13 @@ class RAGDocument(BaseModel):
|
||||||
document_id: str
|
document_id: str
|
||||||
content: InterleavedContent | URL
|
content: InterleavedContent | URL
|
||||||
mime_type: str | None = None
|
mime_type: str | None = None
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RAGQueryResult(BaseModel):
|
class RAGQueryResult(BaseModel):
|
||||||
content: Optional[InterleavedContent] = None
|
content: InterleavedContent | None = None
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -59,10 +59,7 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
 RAGQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultRAGQueryGeneratorConfig,
-        LLMRAGQueryGeneratorConfig,
-    ],
+    DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
     Field(discriminator="type"),
 ]
 register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
|
@ -83,7 +80,7 @@ class RAGToolRuntime(Protocol):
|
||||||
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
|
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
|
||||||
async def insert(
|
async def insert(
|
||||||
self,
|
self,
|
||||||
documents: List[RAGDocument],
|
documents: list[RAGDocument],
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
chunk_size_in_tokens: int = 512,
|
chunk_size_in_tokens: int = 512,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -94,8 +91,8 @@ class RAGToolRuntime(Protocol):
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
content: InterleavedContent,
|
content: InterleavedContent,
|
||||||
vector_db_ids: List[str],
|
vector_db_ids: list[str],
|
||||||
query_config: Optional[RAGQueryConfig] = None,
|
query_config: RAGQueryConfig | None = None,
|
||||||
) -> RAGQueryResult:
|
) -> RAGQueryResult:
|
||||||
"""Query the RAG system for context; typically invoked by the agent"""
|
"""Query the RAG system for context; typically invoked by the agent"""
|
||||||
...
|
...
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Protocol, runtime_checkable
|
from typing_extensions import Protocol, runtime_checkable
|
||||||
|
@ -24,7 +24,7 @@ class ToolParameter(BaseModel):
|
||||||
parameter_type: str
|
parameter_type: str
|
||||||
description: str
|
description: str
|
||||||
required: bool = Field(default=True)
|
required: bool = Field(default=True)
|
||||||
default: Optional[Any] = None
|
default: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -40,39 +40,39 @@ class Tool(Resource):
|
||||||
toolgroup_id: str
|
toolgroup_id: str
|
||||||
tool_host: ToolHost
|
tool_host: ToolHost
|
||||||
description: str
|
description: str
|
||||||
parameters: List[ToolParameter]
|
parameters: list[ToolParameter]
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolDef(BaseModel):
|
class ToolDef(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
parameters: Optional[List[ToolParameter]] = None
|
parameters: list[ToolParameter] | None = None
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolGroupInput(BaseModel):
|
class ToolGroupInput(BaseModel):
|
||||||
toolgroup_id: str
|
toolgroup_id: str
|
||||||
provider_id: str
|
provider_id: str
|
||||||
args: Optional[Dict[str, Any]] = None
|
args: dict[str, Any] | None = None
|
||||||
mcp_endpoint: Optional[URL] = None
|
mcp_endpoint: URL | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolGroup(Resource):
|
class ToolGroup(Resource):
|
||||||
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
||||||
mcp_endpoint: Optional[URL] = None
|
mcp_endpoint: URL | None = None
|
||||||
args: Optional[Dict[str, Any]] = None
|
args: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolInvocationResult(BaseModel):
|
class ToolInvocationResult(BaseModel):
|
||||||
content: Optional[InterleavedContent] = None
|
content: InterleavedContent | None = None
|
||||||
error_message: Optional[str] = None
|
error_message: str | None = None
|
||||||
error_code: Optional[int] = None
|
error_code: int | None = None
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ToolStore(Protocol):
|
class ToolStore(Protocol):
|
||||||
|
@ -81,11 +81,11 @@ class ToolStore(Protocol):
|
||||||
|
|
||||||
|
|
||||||
class ListToolGroupsResponse(BaseModel):
|
class ListToolGroupsResponse(BaseModel):
|
||||||
data: List[ToolGroup]
|
data: list[ToolGroup]
|
||||||
|
|
||||||
|
|
||||||
class ListToolsResponse(BaseModel):
|
class ListToolsResponse(BaseModel):
|
||||||
data: List[Tool]
|
data: list[Tool]
|
||||||
|
|
||||||
|
|
||||||
class ListToolDefsResponse(BaseModel):
|
class ListToolDefsResponse(BaseModel):
|
||||||
|
@ -100,8 +100,8 @@ class ToolGroups(Protocol):
|
||||||
self,
|
self,
|
||||||
toolgroup_id: str,
|
toolgroup_id: str,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
mcp_endpoint: Optional[URL] = None,
|
mcp_endpoint: URL | None = None,
|
||||||
args: Optional[Dict[str, Any]] = None,
|
args: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register a tool group"""
|
"""Register a tool group"""
|
||||||
...
|
...
|
||||||
|
@ -118,7 +118,7 @@ class ToolGroups(Protocol):
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/tools", method="GET")
|
@webmethod(route="/tools", method="GET")
|
||||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
|
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||||
"""List tools with optional tool group"""
|
"""List tools with optional tool group"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -151,10 +151,10 @@ class ToolRuntime(Protocol):
|
||||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||||
) -> ListToolDefsResponse: ...
|
) -> ListToolDefsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||||
"""Run a tool with the given arguments"""
|
"""Run a tool with the given arguments"""
|
||||||
...
|
...
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -33,11 +33,11 @@ class VectorDBInput(BaseModel):
|
||||||
vector_db_id: str
|
vector_db_id: str
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
embedding_dimension: int
|
embedding_dimension: int
|
||||||
provider_vector_db_id: Optional[str] = None
|
provider_vector_db_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListVectorDBsResponse(BaseModel):
|
class ListVectorDBsResponse(BaseModel):
|
||||||
data: List[VectorDB]
|
data: list[VectorDB]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -57,9 +57,9 @@ class VectorDBs(Protocol):
|
||||||
self,
|
self,
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
embedding_model: str,
|
embedding_model: str,
|
||||||
embedding_dimension: Optional[int] = 384,
|
embedding_dimension: int | None = 384,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: str | None = None,
|
||||||
provider_vector_db_id: Optional[str] = None,
|
provider_vector_db_id: str | None = None,
|
||||||
) -> VectorDB: ...
|
) -> VectorDB: ...
|
||||||
|
|
||||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
@ -20,17 +20,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
class Chunk(BaseModel):
|
class Chunk(BaseModel):
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class QueryChunksResponse(BaseModel):
|
class QueryChunksResponse(BaseModel):
|
||||||
chunks: List[Chunk]
|
chunks: list[Chunk]
|
||||||
scores: List[float]
|
scores: list[float]
|
||||||
|
|
||||||
|
|
||||||
class VectorDBStore(Protocol):
|
class VectorDBStore(Protocol):
|
||||||
def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
|
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -44,8 +44,8 @@ class VectorIO(Protocol):
|
||||||
async def insert_chunks(
|
async def insert_chunks(
|
||||||
self,
|
self,
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
chunks: List[Chunk],
|
chunks: list[Chunk],
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: int | None = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/vector-io/query", method="POST")
|
@webmethod(route="/vector-io/query", method="POST")
|
||||||
|
@ -53,5 +53,5 @@ class VectorIO(Protocol):
|
||||||
self,
|
self,
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
query: InterleavedContent,
|
query: InterleavedContent,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse: ...
|
) -> QueryChunksResponse: ...
|
||||||
|
|
|
@ -13,7 +13,6 @@ from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
@ -102,7 +101,7 @@ class DownloadTask:
|
||||||
output_file: str
|
output_file: str
|
||||||
total_size: int = 0
|
total_size: int = 0
|
||||||
downloaded_size: int = 0
|
downloaded_size: int = 0
|
||||||
task_id: Optional[int] = None
|
task_id: int | None = None
|
||||||
retries: int = 0
|
retries: int = 0
|
||||||
max_retries: int = 3
|
max_retries: int = 3
|
||||||
|
|
||||||
|
@ -262,7 +261,7 @@ class ParallelDownloader:
|
||||||
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
|
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
|
||||||
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
|
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
|
||||||
|
|
||||||
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
|
def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
|
||||||
try:
|
try:
|
||||||
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
|
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
|
||||||
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
|
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
|
||||||
|
@ -282,7 +281,7 @@ class ParallelDownloader:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
|
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
|
||||||
|
|
||||||
async def download_all(self, tasks: List[DownloadTask]) -> None:
|
async def download_all(self, tasks: list[DownloadTask]) -> None:
|
||||||
if not tasks:
|
if not tasks:
|
||||||
raise ValueError("No download tasks provided")
|
raise ValueError("No download tasks provided")
|
||||||
|
|
||||||
|
@ -391,20 +390,20 @@ def _meta_download(
|
||||||
|
|
||||||
class ModelEntry(BaseModel):
|
class ModelEntry(BaseModel):
|
||||||
model_id: str
|
model_id: str
|
||||||
files: Dict[str, str]
|
files: dict[str, str]
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class Manifest(BaseModel):
|
class Manifest(BaseModel):
|
||||||
models: List[ModelEntry]
|
models: list[ModelEntry]
|
||||||
expires_on: datetime
|
expires_on: datetime
|
||||||
|
|
||||||
|
|
||||||
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
||||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
|
|
||||||
with open(manifest_file, "r") as f:
|
with open(manifest_file) as f:
|
||||||
d = json.load(f)
|
d = json.load(f)
|
||||||
manifest = Manifest(**d)
|
manifest = Manifest(**d)
|
||||||
|
|
||||||
|
@@ -460,15 +459,17 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model

    from .model.safety_models import (
-        prompt_guard_download_info,
-        prompt_guard_model_sku,
+        prompt_guard_download_info_map,
+        prompt_guard_model_sku_map,
    )

-    prompt_guard = prompt_guard_model_sku()
+    prompt_guard_model_sku_map = prompt_guard_model_sku_map()
+    prompt_guard_download_info_map = prompt_guard_download_info_map()
+
    for model_id in model_ids:
-        if model_id == prompt_guard.model_id:
-            model = prompt_guard
-            info = prompt_guard_download_info()
+        if model_id in prompt_guard_model_sku_map.keys():
+            model = prompt_guard_model_sku_map[model_id]
+            info = prompt_guard_download_info_map[model_id]
        else:
            model = resolve_model(model_id)
            if model is None:
|
|
|
@ -36,11 +36,11 @@ class ModelDescribe(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
|
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from .safety_models import prompt_guard_model_sku
|
from .safety_models import prompt_guard_model_sku_map
|
||||||
|
|
||||||
prompt_guard = prompt_guard_model_sku()
|
prompt_guard_model_map = prompt_guard_model_sku_map()
|
||||||
if args.model_id == prompt_guard.model_id:
|
if args.model_id in prompt_guard_model_map.keys():
|
||||||
model = prompt_guard
|
model = prompt_guard_model_map[args.model_id]
|
||||||
else:
|
else:
|
||||||
model = resolve_model(args.model_id)
|
model = resolve_model(args.model_id)
|
||||||
|
|
||||||
|
|
|
@ -84,7 +84,7 @@ class ModelList(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from .safety_models import prompt_guard_model_sku
|
from .safety_models import prompt_guard_model_skus
|
||||||
|
|
||||||
if args.downloaded:
|
if args.downloaded:
|
||||||
return _run_model_list_downloaded_cmd()
|
return _run_model_list_downloaded_cmd()
|
||||||
|
@ -96,7 +96,7 @@ class ModelList(Subcommand):
|
||||||
]
|
]
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
for model in all_registered_models() + [prompt_guard_model_sku()]:
|
for model in all_registered_models() + prompt_guard_model_skus():
|
||||||
if not args.show_all and not model.is_featured:
|
if not args.show_all and not model.is_featured:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
@@ -42,11 +42,12 @@ class ModelRemove(Subcommand):
        )

    def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
-        from .safety_models import prompt_guard_model_sku
+        from .safety_models import prompt_guard_model_sku_map

-        prompt_guard = prompt_guard_model_sku()
-        if args.model == prompt_guard.model_id:
-            model = prompt_guard
+        prompt_guard_model_map = prompt_guard_model_sku_map()
+
+        if args.model in prompt_guard_model_map.keys():
+            model = prompt_guard_model_map[args.model]
        else:
            model = resolve_model(args.model)
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
@@ -15,14 +15,14 @@ from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
 class PromptGuardModel(BaseModel):
    """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""

-    model_id: str = "Prompt-Guard-86M"
+    model_id: str
+    huggingface_repo: str
    description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
    is_featured: bool = False
-    huggingface_repo: str = "meta-llama/Prompt-Guard-86M"
-    max_seq_length: int = 2048
+    max_seq_length: int = 512
    is_instruct_model: bool = False
    quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
-    arch_args: Dict[str, Any] = Field(default_factory=dict)
+    arch_args: dict[str, Any] = Field(default_factory=dict)

    def descriptor(self) -> str:
        return self.model_id
@@ -30,18 +30,35 @@ class PromptGuardModel(BaseModel):
    model_config = ConfigDict(protected_namespaces=())


-def prompt_guard_model_sku():
-    return PromptGuardModel()
+def prompt_guard_model_skus():
+    return [
+        PromptGuardModel(model_id="Prompt-Guard-86M", huggingface_repo="meta-llama/Prompt-Guard-86M"),
+        PromptGuardModel(
+            model_id="Llama-Prompt-Guard-2-86M",
+            huggingface_repo="meta-llama/Llama-Prompt-Guard-2-86M",
+        ),
+        PromptGuardModel(
+            model_id="Llama-Prompt-Guard-2-22M",
+            huggingface_repo="meta-llama/Llama-Prompt-Guard-2-22M",
+        ),
+    ]


-def prompt_guard_download_info():
-    return LlamaDownloadInfo(
-        folder="Prompt-Guard",
-        files=[
-            "model.safetensors",
-            "special_tokens_map.json",
-            "tokenizer.json",
-            "tokenizer_config.json",
-        ],
-        pth_size=1,
-    )
+def prompt_guard_model_sku_map() -> dict[str, Any]:
+    return {model.model_id: model for model in prompt_guard_model_skus()}
+
+
+def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
+    return {
+        model.model_id: LlamaDownloadInfo(
+            folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,
+            files=[
+                "model.safetensors",
+                "special_tokens_map.json",
+                "tokenizer.json",
+                "tokenizer_config.json",
+            ],
+            pth_size=1,
+        )
+        for model in prompt_guard_model_skus()
+    }
|
|
|
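A rough usage sketch of the new map-based Prompt Guard helpers, mirroring how the download and remove commands consume them above; the import path llama_stack.cli.model.safety_models is assumed from the relative imports in this diff, and the model id is one of the three SKUs returned by prompt_guard_model_skus():

from llama_stack.cli.model.safety_models import (
    prompt_guard_download_info_map,
    prompt_guard_model_sku_map,
)

sku_map = prompt_guard_model_sku_map()
info_map = prompt_guard_download_info_map()

model_id = "Llama-Prompt-Guard-2-22M"
if model_id in sku_map:
    model = sku_map[model_id]
    info = info_map[model_id]
    print(model.huggingface_repo, info.folder)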
@@ -13,13 +13,12 @@ import sys
 import textwrap
 from functools import lru_cache
 from pathlib import Path
-from typing import Dict, Optional
 
 import yaml
 from prompt_toolkit import prompt
 from prompt_toolkit.completion import WordCompleter
 from prompt_toolkit.validation import Validator
-from termcolor import cprint
+from termcolor import colored, cprint
 
 from llama_stack.cli.stack.utils import ImageType
 from llama_stack.cli.table import print_table

@@ -46,14 +45,14 @@ from llama_stack.providers.datatypes import Api
 TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
 
 
-@lru_cache()
-def available_templates_specs() -> Dict[str, BuildConfig]:
+@lru_cache
+def available_templates_specs() -> dict[str, BuildConfig]:
     import yaml
 
     template_specs = {}
     for p in TEMPLATES_PATH.rglob("*build.yaml"):
         template_name = p.parent.name
-        with open(p, "r") as f:
+        with open(p) as f:
             build_config = BuildConfig(**yaml.safe_load(f))
             template_specs[template_name] = build_config
     return template_specs

@@ -89,6 +88,43 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 color="red",
             )
             sys.exit(1)
+    elif args.providers:
+        providers = dict()
+        for api_provider in args.providers.split(","):
+            if "=" not in api_provider:
+                cprint(
+                    "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
+                    color="red",
+                )
+                sys.exit(1)
+            api, provider = api_provider.split("=")
+            providers_for_api = get_provider_registry().get(Api(api), None)
+            if providers_for_api is None:
+                cprint(
+                    f"{api} is not a valid API.",
+                    color="red",
+                )
+                sys.exit(1)
+            if provider in providers_for_api:
+                providers.setdefault(api, []).append(provider)
+            else:
+                cprint(
+                    f"{provider} is not a valid provider for the {api} API.",
+                    color="red",
+                )
+                sys.exit(1)
+        distribution_spec = DistributionSpec(
+            providers=providers,
+            description=",".join(args.providers),
+        )
+        if not args.image_type:
+            cprint(
+                f"Please specify a image-type (container | conda | venv) for {args.template}",
+                color="red",
+            )
+            sys.exit(1)
+
+        build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec)
     elif not args.config and not args.template:
         name = prompt(
             "> Enter a name for your Llama Stack (e.g. my-local-stack): ",

@@ -99,12 +135,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         )
 
         image_type = prompt(
-            f"> Enter the image type you want your Llama Stack to be built as ({' or '.join(e.value for e in ImageType)}): ",
+            "> Enter the image type you want your Llama Stack to be built as (use <TAB> to see options): ",
+            completer=WordCompleter([e.value for e in ImageType]),
+            complete_while_typing=True,
             validator=Validator.from_callable(
                 lambda x: x in [e.value for e in ImageType],
-                error_message=f"Invalid image type, please enter {' or '.join(e.value for e in ImageType)}",
+                error_message="Invalid image type. Use <TAB> to see options",
             ),
-            default=ImageType.CONDA.value,
         )
 
         if image_type == ImageType.CONDA.value:

@@ -140,7 +177,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             if not available_providers:
                 continue
             api_provider = prompt(
-                "> Enter provider for API {}: ".format(api.value),
+                f"> Enter provider for API {api.value}: ",
                 completer=WordCompleter(available_providers),
                 complete_while_typing=True,
                 validator=Validator.from_callable(

@@ -163,7 +200,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
 
         build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
     else:
-        with open(args.config, "r") as f:
+        with open(args.config) as f:
             try:
                 build_config = BuildConfig(**yaml.safe_load(f))
             except Exception as e:

@@ -173,16 +210,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 )
                 sys.exit(1)
 
-    if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
-        cprint(
-            "Please specify --image-name when building a container from a config file",
-            color="red",
-        )
-        sys.exit(1)
-
     if args.print_deps_only:
         print(f"# Dependencies for {args.template or args.config or image_name}")
-        normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
+        normal_deps, special_deps = get_provider_dependencies(build_config)
         normal_deps += SERVER_DEPENDENCIES
         print(f"uv pip install {' '.join(normal_deps)}")
         for special_dep in special_deps:

@@ -198,10 +228,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             )
 
     except (Exception, RuntimeError) as exc:
+        import traceback
+
         cprint(
             f"Error building stack: {exc}",
             color="red",
         )
+        cprint("Stack trace:", color="red")
+        traceback.print_exc()
         sys.exit(1)
     if run_config is None:
         cprint(

@@ -233,9 +267,10 @@ def _generate_run_config(
         image_name=image_name,
         apis=apis,
         providers={},
+        external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
     )
     # build providers dict
-    provider_registry = get_provider_registry()
+    provider_registry = get_provider_registry(build_config)
     for api in apis:
         run_config.providers[api] = []
         provider_types = build_config.distribution_spec.providers[api]

@@ -249,8 +284,22 @@ def _generate_run_config(
             if p.deprecation_error:
                 raise InvalidProviderError(p.deprecation_error)
 
-            config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
-            if hasattr(config_type, "sample_run_config"):
+            try:
+                config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
+            except ModuleNotFoundError:
+                # HACK ALERT:
+                # This code executes after building is done, the import cannot work since the
+                # package is either available in the venv or container - not available on the host.
+                # TODO: use a "is_external" flag in ProviderSpec to check if the provider is
+                # external
+                cprint(
+                    f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
+                    color="yellow",
+                )
+                # Set config_type to None to avoid UnboundLocalError
+                config_type = None
+
+            if config_type is not None and hasattr(config_type, "sample_run_config"):
                 config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
             else:
                 config = {}

@@ -268,20 +317,25 @@ def _generate_run_config(
         to_write = json.loads(run_config.model_dump_json())
         f.write(yaml.dump(to_write, sort_keys=False))
 
-    # this path is only invoked when no template is provided
-    cprint(
-        f"You can now run your stack with `llama stack run {run_config_file}`",
-        color="green",
-    )
+    # Only print this message for non-container builds since it will be displayed before the
+    # container is built
+    # For non-container builds, the run.yaml is generated at the very end of the build process so it
+    # makes sense to display this message
+    if build_config.image_type != LlamaStackImageType.CONTAINER.value:
+        cprint(
+            f"You can now run your stack with `llama stack run {run_config_file}`",
+            color="green",
+        )
     return run_config_file
 
 
 def _run_stack_build_command_from_build_config(
     build_config: BuildConfig,
-    image_name: Optional[str] = None,
-    template_name: Optional[str] = None,
-    config_path: Optional[str] = None,
+    image_name: str | None = None,
+    template_name: str | None = None,
+    config_path: str | None = None,
 ) -> str:
+    image_name = image_name or build_config.image_name
     if build_config.image_type == LlamaStackImageType.CONTAINER.value:
         if template_name:
             image_name = f"distribution-{template_name}"

@@ -305,6 +359,13 @@ def _run_stack_build_command_from_build_config(
     build_file_path = build_dir / f"{image_name}-build.yaml"
 
     os.makedirs(build_dir, exist_ok=True)
+    run_config_file = None
+    # Generate the run.yaml so it can be included in the container image with the proper entrypoint
+    # Only do this if we're building a container image and we're not using a template
+    if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
+        cprint("Generating run.yaml file", color="green")
+        run_config_file = _generate_run_config(build_config, build_dir, image_name)
+
     with open(build_file_path, "w") as f:
         to_write = json.loads(build_config.model_dump_json())
         f.write(yaml.dump(to_write, sort_keys=False))

@@ -313,7 +374,8 @@ def _run_stack_build_command_from_build_config(
         build_config,
         build_file_path,
         image_name,
-        template_or_config=template_name or config_path,
+        template_or_config=template_name or config_path or str(build_file_path),
+        run_config=run_config_file,
     )
     if return_code != 0:
         raise RuntimeError(f"Failed to build image {image_name}")

@@ -326,6 +388,11 @@ def _run_stack_build_command_from_build_config(
             shutil.copy(path, run_config_file)
 
         cprint("Build Successful!", color="green")
+        cprint("You can find the newly-built template here: " + colored(template_path, "light_blue"))
+        cprint(
+            "You can run the new Llama Stack distro via: "
+            + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue")
+        )
         return template_path
     else:
         return _generate_run_config(build_config, build_dir, image_name)
@@ -75,6 +75,12 @@ the build. If not specified, currently active environment will be used if found.
             default=False,
             help="Run the stack after building using the same image type, name, and other applicable arguments",
         )
+        self.parser.add_argument(
+            "--providers",
+            type=str,
+            default=None,
+            help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
+        )
 
     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
         # always keep implementation completely silo-ed away from CLI so CLI
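To make the new flag concrete, here is a rough, assumed illustration (not part of the commit) of what the `--providers` parsing in run_stack_build_command produces for an invocation such as `llama stack build --providers inference=remote::ollama,safety=inline::llama-guard --image-type venv`; the provider types and flag value are examples, not taken from this diff.

# Assumed example only: provider types and the flag value are illustrative.
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec

providers = {
    "inference": ["remote::ollama"],
    "safety": ["inline::llama-guard"],
}
distribution_spec = DistributionSpec(
    providers=providers,
    description="inference=remote::ollama,safety=inline::llama-guard",
)
build_config = BuildConfig(image_type="venv", distribution_spec=distribution_spec)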
@@ -119,7 +119,7 @@ class StackRun(Subcommand):
 
         if not config_file.is_file():
             self.parser.error(
-                f"Config file must be a valid file path, '{config_file}’ is not a file: type={type(config_file)}"
+                f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
             )
 
         logger.info(f"Using run configuration: {config_file}")

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Iterable
+from collections.abc import Iterable
 
 from rich.console import Console
 from rich.table import Table

@@ -9,7 +9,6 @@ import hashlib
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional
 
 from rich.console import Console
 from rich.progress import Progress, SpinnerColumn, TextColumn

@@ -21,7 +20,7 @@ from llama_stack.cli.subcommand import Subcommand
 class VerificationResult:
     filename: str
     expected_hash: str
-    actual_hash: Optional[str]
+    actual_hash: str | None
     exists: bool
     matches: bool
 

@@ -60,9 +59,9 @@ def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
     return md5_hash.hexdigest()
 
 
-def load_checksums(checklist_path: Path) -> Dict[str, str]:
+def load_checksums(checklist_path: Path) -> dict[str, str]:
     checksums = {}
-    with open(checklist_path, "r") as f:
+    with open(checklist_path) as f:
         for line in f:
             if line.strip():
                 md5sum, filepath = line.strip().split(" ", 1)

@@ -72,7 +71,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
     return checksums
 
 
-def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
+def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
     results = []
 
     with Progress(

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, Optional
+from typing import Any
 
 from llama_stack.distribution.datatypes import AccessAttributes
 from llama_stack.log import get_logger

@@ -14,8 +14,8 @@ logger = get_logger(__name__, category="core")
 
 def check_access(
     obj_identifier: str,
-    obj_attributes: Optional[AccessAttributes],
-    user_attributes: Optional[Dict[str, Any]] = None,
+    obj_attributes: AccessAttributes | None,
+    user_attributes: dict[str, Any] | None = None,
 ) -> bool:
     """Check if the current user has access to the given object, based on access attributes.
 
@@ -7,16 +7,16 @@
 import importlib.resources
 import logging
 from pathlib import Path
-from typing import Dict, List
 
 from pydantic import BaseModel
 from termcolor import cprint
 
-from llama_stack.distribution.datatypes import BuildConfig, Provider
+from llama_stack.distribution.datatypes import BuildConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.exec import run_command
 from llama_stack.distribution.utils.image_types import LlamaStackImageType
 from llama_stack.providers.datatypes import Api
+from llama_stack.templates.template import DistributionTemplate
 
 log = logging.getLogger(__name__)
 

@@ -37,19 +37,23 @@ class ApiInput(BaseModel):
 
 
 def get_provider_dependencies(
-    config_providers: Dict[str, List[Provider]],
+    config: BuildConfig | DistributionTemplate,
 ) -> tuple[list[str], list[str]]:
     """Get normal and special dependencies from provider configuration."""
-    all_providers = get_provider_registry()
+    # Extract providers based on config type
+    if isinstance(config, DistributionTemplate):
+        providers = config.providers
+    elif isinstance(config, BuildConfig):
+        providers = config.distribution_spec.providers
     deps = []
-    for api_str, provider_or_providers in config_providers.items():
-        providers_for_api = all_providers[Api(api_str)]
+    registry = get_provider_registry(config)
+    for api_str, provider_or_providers in providers.items():
+        providers_for_api = registry[Api(api_str)]
 
         providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
 
         for provider in providers:
-            # Providers from BuildConfig and RunConfig are subtly different – not great
+            # Providers from BuildConfig and RunConfig are subtly different - not great
             provider_type = provider if isinstance(provider, str) else provider.provider_type
 
             if provider_type not in providers_for_api:

@@ -71,8 +75,8 @@ def get_provider_dependencies(
     return list(set(normal_deps)), list(set(special_deps))
 
 
-def print_pip_install_help(providers: Dict[str, List[Provider]]):
-    normal_deps, special_deps = get_provider_dependencies(providers)
+def print_pip_install_help(config: BuildConfig):
+    normal_deps, special_deps = get_provider_dependencies(config)
 
     cprint(
         f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",

@@ -88,10 +92,11 @@ def build_image(
     build_file_path: Path,
     image_name: str,
     template_or_config: str,
+    run_config: str | None = None,
 ):
     container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
 
-    normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
+    normal_deps, special_deps = get_provider_dependencies(build_config)
     normal_deps += SERVER_DEPENDENCIES
 
     if build_config.image_type == LlamaStackImageType.CONTAINER.value:

@@ -103,6 +108,11 @@ def build_image(
             container_base,
             " ".join(normal_deps),
         ]
+
+        # When building from a config file (not a template), include the run config path in the
+        # build arguments
+        if run_config is not None:
+            args.append(run_config)
     elif build_config.image_type == LlamaStackImageType.CONDA.value:
         script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
         args = [
@@ -19,12 +19,16 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
 # mounting is not supported by docker buildx, so we use COPY instead
 USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
 
+# Path to the run.yaml file in the container
+RUN_CONFIG_PATH=/app/run.yaml
+
+BUILD_CONTEXT_DIR=$(pwd)
 
 if [ "$#" -lt 4 ]; then
   # This only works for templates
-  echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
+  echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
   exit 1
 fi
 
 set -euo pipefail
 
 template_or_config="$1"

@@ -35,8 +39,27 @@ container_base="$1"
 shift
 pip_dependencies="$1"
 shift
-special_pip_deps="${1:-}"
 
+# Handle optional arguments
+run_config=""
+special_pip_deps=""
+
+# Check if there are more arguments
+# The logics is becoming cumbersom, we should refactor it if we can do better
+if [ $# -gt 0 ]; then
+  # Check if the argument ends with .yaml
+  if [[ "$1" == *.yaml ]]; then
+    run_config="$1"
+    shift
+    # If there's another argument after .yaml, it must be special_pip_deps
+    if [ $# -gt 0 ]; then
+      special_pip_deps="$1"
+    fi
+  else
+    # If it's not .yaml, it must be special_pip_deps
+    special_pip_deps="$1"
+  fi
+fi
 
 # Define color codes
 RED='\033[0;31m'

@@ -72,9 +95,13 @@ if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
 FROM $container_base
 WORKDIR /app
 
-RUN dnf -y update && dnf install -y iputils net-tools wget \
+# We install the Python 3.11 dev headers and build tools so that any
+# C‑extension wheels (e.g. polyleven, faiss‑cpu) can compile successfully.
+
+RUN dnf -y update && dnf install -y iputils git net-tools wget \
     vim-minimal python3.11 python3.11-pip python3.11-wheel \
-    python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
+    python3.11-setuptools python3.11-devel gcc make && \
+    ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
 
 ENV UV_SYSTEM_PYTHON=1
 RUN pip install uv

@@ -86,7 +113,7 @@ WORKDIR /app
 
 RUN apt-get update && apt-get install -y \
        iputils-ping net-tools iproute2 dnsutils telnet \
-       curl wget telnet \
+       curl wget telnet git\
        procps psmisc lsof \
        traceroute \
        bubblewrap \

@@ -115,6 +142,45 @@ EOF
 done
 fi
 
+# Function to get Python command
+get_python_cmd() {
+    if is_command_available python; then
+        echo "python"
+    elif is_command_available python3; then
+        echo "python3"
+    else
+        echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2
+        exit 1
+    fi
+}
+
+if [ -n "$run_config" ]; then
+  # Copy the run config to the build context since it's an absolute path
+  cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
+  add_to_container << EOF
+COPY run.yaml $RUN_CONFIG_PATH
+EOF
+
+  # Parse the run.yaml configuration to identify external provider directories
+  # If external providers are specified, copy their directory to the container
+  # and update the configuration to reference the new container path
+  python_cmd=$(get_python_cmd)
+  external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
+  if [ -n "$external_providers_dir" ]; then
+    echo "Copying external providers directory: $external_providers_dir"
+    add_to_container << EOF
+COPY $external_providers_dir /app/providers.d
+EOF
+    # Edit the run.yaml file to change the external_providers_dir to /app/providers.d
+    if [ "$(uname)" = "Darwin" ]; then
+      sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
+      rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
+    else
+      sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
+    fi
+  fi
+fi
+
 stack_mount="/app/llama-stack-source"
 client_mount="/app/llama-stack-client-source"
 

@@ -174,15 +240,16 @@ fi
 RUN pip uninstall -y uv
 EOF
 
-# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
-if [[ "$template_or_config" != *.yaml ]]; then
+# If a run config is provided, we use the --config flag
+if [[ -n "$run_config" ]]; then
+  add_to_container << EOF
+ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"]
+EOF
+# If a template is provided (not a yaml file), we use the --template flag
+elif [[ "$template_or_config" != *.yaml ]]; then
   add_to_container << EOF
 ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"]
 EOF
-else
-  add_to_container << EOF
-ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
-EOF
 fi
 
 # Add other require item commands genearic to all containers

@@ -254,9 +321,10 @@ $CONTAINER_BINARY build \
   "${CLI_ARGS[@]}" \
   -t "$image_tag" \
   -f "$TEMP_DIR/Containerfile" \
-  "."
+  "$BUILD_CONTEXT_DIR"
 
 # clean up tmp/configs
+rm -f "$BUILD_CONTEXT_DIR/run.yaml"
 set +x
 
 echo "Success!"
@@ -8,7 +8,7 @@ import inspect
 import json
 from collections.abc import AsyncIterator
 from enum import Enum
-from typing import Any, Type, Union, get_args, get_origin
+from typing import Any, Union, get_args, get_origin
 
 import httpx
 from pydantic import BaseModel, parse_obj_as

@@ -27,7 +27,7 @@ async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
     return impl
 
 
-def create_api_client_class(protocol) -> Type:
+def create_api_client_class(protocol) -> type:
     if protocol in _CLIENT_CLASSES:
         return _CLIENT_CLASSES[protocol]
 

@@ -1,3 +1,5 @@
+#!/usr/bin/env bash
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import logging
 import textwrap
-from typing import Any, Dict
+from typing import Any
 
 from llama_stack.distribution.datatypes import (
     LLAMA_STACK_RUN_CONFIG_VERSION,

@@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
 logger = logging.getLogger(__name__)
 
 
-def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
+def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
     provider_spec = registry[provider.provider_type]
     config_type = instantiate_class_type(provider_spec.config_class)
     try:

@@ -120,8 +120,8 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
 
 
 def upgrade_from_routing_table(
-    config_dict: Dict[str, Any],
-) -> Dict[str, Any]:
+    config_dict: dict[str, Any],
+) -> dict[str, Any]:
     def get_providers(entries):
         return [
             Provider(

@@ -163,7 +163,7 @@ def upgrade_from_routing_table(
     return config_dict
 
 
-def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
+def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
     version = config_dict.get("version", None)
     if version == LLAMA_STACK_RUN_CONFIG_VERSION:
         return StackRunConfig(**config_dict)
@@ -4,7 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Annotated, Any, Dict, List, Optional, Union
+from enum import Enum
+from typing import Annotated, Any
 
 from pydantic import BaseModel, Field
 

@@ -29,7 +30,7 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
 LLAMA_STACK_RUN_CONFIG_VERSION = "2"
 
 
-RoutingKey = Union[str, List[str]]
+RoutingKey = str | list[str]
 
 
 class AccessAttributes(BaseModel):

@@ -46,17 +47,17 @@ class AccessAttributes(BaseModel):
     """
 
     # Standard attribute categories - the minimal set we need now
-    roles: Optional[List[str]] = Field(
+    roles: list[str] | None = Field(
         default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
     )
 
-    teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
+    teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
 
-    projects: Optional[List[str]] = Field(
+    projects: list[str] | None = Field(
         default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
     )
 
-    namespaces: Optional[List[str]] = Field(
+    namespaces: list[str] | None = Field(
         default=None, description="Namespace-based access control for resource isolation"
     )
 

@@ -105,7 +106,7 @@ class ResourceWithACL(Resource):
     # ^ User must have access to the customer-insights project AND have confidential namespace
     """
 
-    access_attributes: Optional[AccessAttributes] = None
+    access_attributes: AccessAttributes | None = None
 
 
 # Use the extended Resource for all routable objects

@@ -141,41 +142,21 @@ class ToolGroupWithACL(ToolGroup, ResourceWithACL):
     pass
 
 
-RoutableObject = Union[
-    Model,
-    Shield,
-    VectorDB,
-    Dataset,
-    ScoringFn,
-    Benchmark,
-    Tool,
-    ToolGroup,
-]
+RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
 
 
 RoutableObjectWithProvider = Annotated[
-    Union[
-        ModelWithACL,
-        ShieldWithACL,
-        VectorDBWithACL,
-        DatasetWithACL,
-        ScoringFnWithACL,
-        BenchmarkWithACL,
-        ToolWithACL,
-        ToolGroupWithACL,
-    ],
+    ModelWithACL
+    | ShieldWithACL
+    | VectorDBWithACL
+    | DatasetWithACL
+    | ScoringFnWithACL
+    | BenchmarkWithACL
+    | ToolWithACL
+    | ToolGroupWithACL,
     Field(discriminator="type"),
 ]
 
-RoutedProtocol = Union[
-    Inference,
-    Safety,
-    VectorIO,
-    DatasetIO,
-    Scoring,
-    Eval,
-    ToolRuntime,
-]
+RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime
 
 
 # Example: /inference, /safety

@@ -183,15 +164,15 @@ class AutoRoutedProviderSpec(ProviderSpec):
     provider_type: str = "router"
     config_class: str = ""
 
-    container_image: Optional[str] = None
+    container_image: str | None = None
     routing_table_api: Api
     module: str
-    provider_data_validator: Optional[str] = Field(
+    provider_data_validator: str | None = Field(
         default=None,
     )
 
     @property
-    def pip_packages(self) -> List[str]:
+    def pip_packages(self) -> list[str]:
         raise AssertionError("Should not be called on AutoRoutedProviderSpec")
 
 

@@ -199,20 +180,20 @@ class AutoRoutedProviderSpec(ProviderSpec):
 class RoutingTableProviderSpec(ProviderSpec):
     provider_type: str = "routing_table"
     config_class: str = ""
-    container_image: Optional[str] = None
+    container_image: str | None = None
 
     router_api: Api
     module: str
-    pip_packages: List[str] = Field(default_factory=list)
+    pip_packages: list[str] = Field(default_factory=list)
 
 
 class DistributionSpec(BaseModel):
-    description: Optional[str] = Field(
+    description: str | None = Field(
         default="",
         description="Description of the distribution",
     )
-    container_image: Optional[str] = None
-    providers: Dict[str, Union[str, List[str]]] = Field(
+    container_image: str | None = None
+    providers: dict[str, str | list[str]] = Field(
         default_factory=dict,
         description="""
 Provider Types for each of the APIs provided by this distribution. If you

@@ -224,21 +205,32 @@ in the runtime configuration to help route to the correct provider.""",
 class Provider(BaseModel):
     provider_id: str
     provider_type: str
-    config: Dict[str, Any]
+    config: dict[str, Any]
 
 
 class LoggingConfig(BaseModel):
-    category_levels: Dict[str, str] = Field(
-        default_factory=Dict,
+    category_levels: dict[str, str] = Field(
+        default_factory=dict,
         description="""
 Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
     )
 
 
+class AuthProviderType(str, Enum):
+    """Supported authentication provider types."""
+
+    KUBERNETES = "kubernetes"
+    CUSTOM = "custom"
+
+
 class AuthenticationConfig(BaseModel):
-    endpoint: str = Field(
+    provider_type: AuthProviderType = Field(
         ...,
-        description="Endpoint URL to validate authentication tokens",
+        description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
+    )
+    config: dict[str, str] = Field(
+        ...,
+        description="Provider-specific configuration",
     )
 
 
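A minimal sketch (assumed, not from the commit) of how the reworked AuthenticationConfig might be instantiated; the keys inside `config` are provider-specific, and the key/value shown here is only an assumption.

# Assumed example: the config keys/values below are placeholders, not defined by this diff.
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType

auth = AuthenticationConfig(
    provider_type=AuthProviderType.KUBERNETES,
    config={"api_server_url": "https://kubernetes.default.svc"},  # placeholder key/value
)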
@@ -249,15 +241,15 @@ class ServerConfig(BaseModel):
         ge=1024,
         le=65535,
     )
-    tls_certfile: Optional[str] = Field(
+    tls_certfile: str | None = Field(
         default=None,
         description="Path to TLS certificate file for HTTPS",
     )
-    tls_keyfile: Optional[str] = Field(
+    tls_keyfile: str | None = Field(
         default=None,
         description="Path to TLS key file for HTTPS",
     )
-    auth: Optional[AuthenticationConfig] = Field(
+    auth: AuthenticationConfig | None = Field(
         default=None,
         description="Authentication configuration for the server",
     )

@@ -273,23 +265,23 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
 this could be just a hash
 """,
     )
-    container_image: Optional[str] = Field(
+    container_image: str | None = Field(
         default=None,
         description="Reference to the container image if this package refers to a container",
     )
-    apis: List[str] = Field(
+    apis: list[str] = Field(
         default_factory=list,
         description="""
 The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
     )
 
-    providers: Dict[str, List[Provider]] = Field(
+    providers: dict[str, list[Provider]] = Field(
         description="""
 One or more providers to use for each API. The same provider_type (e.g., meta-reference)
 can be instantiated multiple times (with different configs) if necessary.
 """,
     )
-    metadata_store: Optional[KVStoreConfig] = Field(
+    metadata_store: KVStoreConfig | None = Field(
         default=None,
         description="""
 Configuration for the persistence store used by the distribution registry. If not specified,

@@ -297,22 +289,22 @@ a default SQLite store will be used.""",
     )
 
     # registry of "resources" in the distribution
-    models: List[ModelInput] = Field(default_factory=list)
-    shields: List[ShieldInput] = Field(default_factory=list)
-    vector_dbs: List[VectorDBInput] = Field(default_factory=list)
-    datasets: List[DatasetInput] = Field(default_factory=list)
-    scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
-    benchmarks: List[BenchmarkInput] = Field(default_factory=list)
-    tool_groups: List[ToolGroupInput] = Field(default_factory=list)
+    models: list[ModelInput] = Field(default_factory=list)
+    shields: list[ShieldInput] = Field(default_factory=list)
+    vector_dbs: list[VectorDBInput] = Field(default_factory=list)
+    datasets: list[DatasetInput] = Field(default_factory=list)
+    scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
+    benchmarks: list[BenchmarkInput] = Field(default_factory=list)
+    tool_groups: list[ToolGroupInput] = Field(default_factory=list)
 
-    logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
+    logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
 
     server: ServerConfig = Field(
         default_factory=ServerConfig,
         description="Configuration for the HTTP(S) server",
     )
 
-    external_providers_dir: Optional[str] = Field(
+    external_providers_dir: str | None = Field(
         default=None,
         description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
     )

@@ -326,3 +318,12 @@ class BuildConfig(BaseModel):
         default="conda",
         description="Type of package to build (conda | container | venv)",
     )
+    image_name: str | None = Field(
+        default=None,
+        description="Name of the distribution to build",
+    )
+    external_providers_dir: str | None = Field(
+        default=None,
+        description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
+        "pip_packages MUST contain the provider package name.",
+    )
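For orientation, a hedged sketch (not in the commit) of a BuildConfig that uses the two newly added fields; the provider type and directory path are assumptions chosen for illustration.

# Assumed example: the provider type and path are illustrative placeholders.
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec

build_config = BuildConfig(
    distribution_spec=DistributionSpec(providers={"inference": ["remote::ollama"]}),
    image_type="venv",
    image_name="my-external-stack",
    external_providers_dir="~/.llama/providers.d",
)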
@@ -7,12 +7,11 @@
 import glob
 import importlib
 import os
-from typing import Any, Dict, List
+from typing import Any
 
 import yaml
 from pydantic import BaseModel
 
-from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     AdapterSpec,

@@ -25,7 +24,7 @@ from llama_stack.providers.datatypes import (
 logger = get_logger(name=__name__, category="core")
 
 
-def stack_apis() -> List[Api]:
+def stack_apis() -> list[Api]:
     return list(Api)
 
 

@@ -34,7 +33,7 @@ class AutoRoutedApiInfo(BaseModel):
     router_api: Api
 
 
-def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
+def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
     return [
         AutoRoutedApiInfo(
             routing_table_api=Api.models,

@@ -67,12 +66,12 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
     ]
 
 
-def providable_apis() -> List[Api]:
+def providable_apis() -> list[Api]:
     routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
     return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
 
 
-def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
+def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
     adapter = AdapterSpec(**spec_data["adapter"])
     spec = remote_provider_spec(
         api=api,

@@ -82,7 +81,7 @@ def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderS
     return spec
 
 
-def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
+def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
     spec = InlineProviderSpec(
         api=api,
         provider_type=f"inline::{provider_name}",

@@ -97,7 +96,9 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam
     return spec
 
 
-def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
+def get_provider_registry(
+    config=None,
+) -> dict[Api, dict[str, ProviderSpec]]:
     """Get the provider registry, optionally including external providers.
 
     This function loads both built-in providers and external providers from YAML files.

@@ -122,7 +123,7 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
         llama-guard.yaml
 
     Args:
-        config: Optional StackRunConfig containing the external providers directory path
+        config: Optional object containing the external providers directory path
 
     Returns:
         A dictionary mapping APIs to their available providers

@@ -132,7 +133,7 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
         ValueError: If any provider spec is invalid
     """
 
-    ret: Dict[Api, Dict[str, ProviderSpec]] = {}
+    ret: dict[Api, dict[str, ProviderSpec]] = {}
     for api in providable_apis():
         name = api.name.lower()
         logger.debug(f"Importing module {name}")

@@ -142,7 +143,8 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
         except ImportError as e:
             logger.warning(f"Failed to import module {name}: {e}")
 
-    if config and config.external_providers_dir:
+    # Check if config has the external_providers_dir attribute
+    if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
         external_providers_dir = os.path.abspath(config.external_providers_dir)
         if not os.path.exists(external_providers_dir):
             raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
@@ -12,7 +12,7 @@ import os
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Optional, TypeVar, Union, get_args, get_origin
+from typing import Any, TypeVar, Union, get_args, get_origin
 
 import httpx
 import yaml

@@ -119,8 +119,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self,
         config_path_or_template_name: str,
         skip_logger_removal: bool = False,
-        custom_provider_registry: Optional[ProviderRegistry] = None,
-        provider_data: Optional[dict[str, Any]] = None,
+        custom_provider_registry: ProviderRegistry | None = None,
+        provider_data: dict[str, Any] | None = None,
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(

@@ -181,8 +181,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
     def __init__(
         self,
         config_path_or_template_name: str,
-        custom_provider_registry: Optional[ProviderRegistry] = None,
-        provider_data: Optional[dict[str, Any]] = None,
+        custom_provider_registry: ProviderRegistry | None = None,
+        provider_data: dict[str, Any] | None = None,
     ):
         super().__init__()
         # when using the library client, we should not log to console since many

@@ -371,7 +371,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()
 
-    def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
+    def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
         if not body:
             return {}
 

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import asyncio
-from typing import Any, Dict
+from typing import Any
 
 from pydantic import BaseModel
 

@@ -73,14 +73,14 @@ class ProviderImpl(Providers):
 
         raise ValueError(f"Provider {provider_id} not found")
 
-    async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
+    async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
         """Get health status for all providers.
 
         Returns:
             Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
             Each API maps to a dictionary of provider IDs to their health responses.
         """
-        providers_health: Dict[str, Dict[str, HealthResponse]] = {}
+        providers_health: dict[str, dict[str, HealthResponse]] = {}
         timeout = 1.0
 
         async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
Some files were not shown because too many files have changed in this diff.