Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)

Merge branch 'main' into HuggingfacePostTrainingConfig-branch

Commit d0d737680f: 193 changed files with 7108 additions and 881 deletions.

.github/dependabot.yml (12 changes, vendored)

@@ -9,6 +9,7 @@ updates:
       day: "saturday"
     commit-message:
       prefix: chore(github-deps)

   - package-ecosystem: "uv"
     directory: "/"
     schedule:

@@ -19,3 +20,14 @@ updates:
       - python
     commit-message:
       prefix: chore(python-deps)
+
+  - package-ecosystem: npm
+    directory: "/llama_stack/ui"
+    schedule:
+      interval: "weekly"
+      day: "saturday"
+    labels:
+      - type/dependencies
+      - javascript
+    commit-message:
+      prefix: chore(ui-deps)

.github/workflows/changelog.yml (2 changes, vendored)

@@ -17,7 +17,7 @@ jobs:
       pull-requests: write # for peter-evans/create-pull-request to create a PR
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
         with:
           ref: main
           fetch-depth: 0

.github/workflows/install-script-ci.yml (4 changes, vendored)

@@ -16,14 +16,14 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
+      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # 5.0.0
      - name: Run ShellCheck on install.sh
        run: shellcheck scripts/install.sh
   smoke-test-on-dev:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

.github/workflows/integration-auth-tests.yml (4 changes, vendored)

@@ -18,7 +18,7 @@ on:
       - '.github/workflows/integration-auth-tests.yml' # This workflow

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -31,7 +31,7 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

@@ -16,7 +16,7 @@ on:
       - '.github/workflows/integration-sql-store-tests.yml' # This workflow

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -44,7 +44,7 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

.github/workflows/integration-tests.yml (2 changes, vendored)

@@ -65,7 +65,7 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Setup test environment
         uses: ./.github/actions/setup-test-environment

@@ -33,7 +33,7 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

.github/workflows/pre-commit.yml (15 changes, vendored)

@@ -8,7 +8,7 @@ on:
     branches: [main]

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -20,7 +20,7 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
         with:
           # For dependabot PRs, we need to checkout with a token that can push changes
           token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}

@@ -36,6 +36,17 @@ jobs:
             **/requirements*.txt
             .pre-commit-config.yaml

+      - name: Set up Node.js
+        uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
+        with:
+          node-version: '20'
+          cache: 'npm'
+          cache-dependency-path: 'llama_stack/ui/'
+
+      - name: Install npm dependencies
+        run: npm ci
+        working-directory: llama_stack/ui
+
       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         continue-on-error: true
         env:

.github/workflows/providers-build.yml (20 changes, vendored)

@@ -26,7 +26,7 @@ on:
       - 'pyproject.toml'

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -36,7 +36,7 @@ jobs:
       distros: ${{ steps.set-matrix.outputs.distros }}
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Generate Distribution List
         id: set-matrix

@@ -55,7 +55,7 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

@@ -79,7 +79,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

@@ -92,7 +92,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

@@ -106,6 +106,10 @@ jobs:
       - name: Inspect the container image entrypoint
         run: |
           IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          if [ -z "$IMAGE_ID" ]; then
+            echo "No image found"
+            exit 1
+          fi
           entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
           echo "Entrypoint: $entrypoint"
           if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then

@@ -117,7 +121,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

@@ -140,6 +144,10 @@ jobs:
       - name: Inspect UBI9 image
         run: |
           IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          if [ -z "$IMAGE_ID" ]; then
+            echo "No image found"
+            exit 1
+          fi
           entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
           echo "Entrypoint: $entrypoint"
           if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then

.github/workflows/python-build-test.yml (4 changes, vendored)

@@ -21,10 +21,10 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install uv
-        uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3
+        uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0
         with:
           python-version: ${{ matrix.python-version }}
           activate-environment: true

@@ -46,7 +46,7 @@ jobs:
           echo "::endgroup::"

       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
         with:
           fetch-depth: 0

.github/workflows/semantic-pr.yml (2 changes, vendored)

@@ -22,6 +22,6 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Check PR Title's semantic conformance
-        uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3
+        uses: amannn/action-semantic-pull-request@7f33ba792281b034f64e96f4c0b5496782dd3b37 # v6.1.0
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

@@ -27,7 +27,7 @@ jobs:
       # container and point 'uv pip install' to the correct path...
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

.github/workflows/test-external.yml (2 changes, vendored)

@@ -27,7 +27,7 @@ jobs:
       # container and point 'uv pip install' to the correct path...
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

.github/workflows/ui-unit-tests.yml (6 changes, vendored)

@@ -13,7 +13,7 @@ on:
   workflow_dispatch:

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -26,10 +26,10 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Setup Node.js
-        uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
+        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
         with:
           node-version: ${{ matrix.node-version }}
           cache: 'npm'

.github/workflows/unit-tests.yml (4 changes, vendored)

@@ -18,7 +18,7 @@ on:
   workflow_dispatch:

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -32,7 +32,7 @@ jobs:
         - "3.13"
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

.github/workflows/update-readthedocs.yml (4 changes, vendored)

@@ -27,7 +27,7 @@ on:
       - '.github/workflows/update-readthedocs.yml'

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -37,7 +37,7 @@ jobs:
       TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install dependencies
         uses: ./.github/actions/setup-runner

@@ -146,20 +146,32 @@ repos:
         pass_filenames: false
         require_serial: true
         files: ^.github/workflows/.*$
-      - id: ui-prettier
-        name: Format UI code with Prettier
-        entry: bash -c 'cd llama_stack/ui && npm run format'
+      - id: ui-linter
+        name: Format & Lint UI
+        entry: bash ./scripts/run-ui-linter.sh
         language: system
         files: ^llama_stack/ui/.*\.(ts|tsx)$
         pass_filenames: false
         require_serial: true
-      - id: ui-eslint
-        name: Lint UI code with ESLint
-        entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
-        language: system
-        files: ^llama_stack/ui/.*\.(ts|tsx)$
-        pass_filenames: false
-        require_serial: true
+
+      - id: check-log-usage
+        name: Ensure 'llama_stack.log' usage for logging
+        entry: bash
+        language: system
+        types: [python]
+        pass_filenames: true
+        args:
+          - -c
+          - |
+            matches=$(grep -EnH '^[^#]*\b(import\s+logging|from\s+logging\b)' "$@" | grep -v -e '#\s*allow-direct-logging' || true)
+            if [ -n "$matches" ]; then
+              # GitHub Actions annotation format
+              while IFS=: read -r file line_num rest; do
+                echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging"
+              done <<< "$matches"
+              exit 1
+            fi
+            exit 0

 ci:
     autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

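The `check-log-usage` hook added above steers contributors toward the repository's logging wrapper instead of the standard `logging` module. A minimal sketch of the pattern the hook expects (the category value is illustrative):

```python
# Pattern enforced by the check-log-usage pre-commit hook: use the project's
# logging wrapper rather than importing `logging` directly.
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core")  # category value is illustrative

logger.info("starting up")

# If the standard library module is genuinely required, add the opt-out marker
# on the import line, e.g.:
#   import logging  # allow-direct-logging
```
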
docs/_static/llama-stack-spec.html (138 changes, vendored)

@@ -4605,6 +4605,49 @@
             }
         },
+        "/v1/inference/rerank": {
+            "post": {
+                "responses": {
+                    "200": {
+                        "description": "RerankResponse with indices sorted by relevance score (descending).",
+                        "content": {
+                            "application/json": {
+                                "schema": { "$ref": "#/components/schemas/RerankResponse" }
+                            }
+                        }
+                    },
+                    "400": { "$ref": "#/components/responses/BadRequest400" },
+                    "429": { "$ref": "#/components/responses/TooManyRequests429" },
+                    "500": { "$ref": "#/components/responses/InternalServerError500" },
+                    "default": { "$ref": "#/components/responses/DefaultError" }
+                },
+                "tags": ["Inference"],
+                "description": "Rerank a list of documents based on their relevance to a query.",
+                "parameters": [],
+                "requestBody": {
+                    "content": {
+                        "application/json": {
+                            "schema": { "$ref": "#/components/schemas/RerankRequest" }
+                        }
+                    },
+                    "required": true
+                }
+            }
+        },
         "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
             "post": {
                 "responses": {

@@ -16024,12 +16067,16 @@
                 "value": {
                     "type": "number",
                     "description": "The numeric value of the metric at this timestamp"
+                },
+                "unit": {
+                    "type": "string"
                 }
             },
             "additionalProperties": false,
             "required": [
                 "timestamp",
-                "value"
+                "value",
+                "unit"
             ],
             "title": "MetricDataPoint",
             "description": "A single data point in a metric time series."

@@ -16587,6 +16634,95 @@
             ],
             "title": "RegisterVectorDbRequest"
         },
+        "RerankRequest": {
+            "type": "object",
+            "properties": {
+                "model": {
+                    "type": "string",
+                    "description": "The identifier of the reranking model to use."
+                },
+                "query": {
+                    "oneOf": [
+                        { "type": "string" },
+                        { "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" },
+                        { "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" }
+                    ],
+                    "description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
+                },
+                "items": {
+                    "type": "array",
+                    "items": {
+                        "oneOf": [
+                            { "type": "string" },
+                            { "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" },
+                            { "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" }
+                        ]
+                    },
+                    "description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length."
+                },
+                "max_num_results": {
+                    "type": "integer",
+                    "description": "(Optional) Maximum number of results to return. Default: returns all."
+                }
+            },
+            "additionalProperties": false,
+            "required": ["model", "query", "items"],
+            "title": "RerankRequest"
+        },
+        "RerankData": {
+            "type": "object",
+            "properties": {
+                "index": {
+                    "type": "integer",
+                    "description": "The original index of the document in the input list"
+                },
+                "relevance_score": {
+                    "type": "number",
+                    "description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance."
+                }
+            },
+            "additionalProperties": false,
+            "required": ["index", "relevance_score"],
+            "title": "RerankData",
+            "description": "A single rerank result from a reranking response."
+        },
+        "RerankResponse": {
+            "type": "object",
+            "properties": {
+                "data": {
+                    "type": "array",
+                    "items": { "$ref": "#/components/schemas/RerankData" },
+                    "description": "List of rerank result objects, sorted by relevance score (descending)"
+                }
+            },
+            "additionalProperties": false,
+            "required": ["data"],
+            "title": "RerankResponse",
+            "description": "Response from a reranking request."
+        },
         "ResumeAgentTurnRequest": {
             "type": "object",
             "properties": {

docs/_static/llama-stack-spec.yaml (104 changes, vendored)

@@ -3264,6 +3264,37 @@ paths:
           schema:
             $ref: '#/components/schemas/QueryTracesRequest'
       required: true
+  /v1/inference/rerank:
+    post:
+      responses:
+        '200':
+          description: >-
+            RerankResponse with indices sorted by relevance score (descending).
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/RerankResponse'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Inference
+      description: >-
+        Rerank a list of documents based on their relevance to a query.
+      parameters: []
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/RerankRequest'
+        required: true
   /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
     post:
       responses:

@@ -11923,10 +11954,13 @@ components:
           type: number
           description: >-
             The numeric value of the metric at this timestamp
+        unit:
+          type: string
       additionalProperties: false
       required:
         - timestamp
         - value
+        - unit
       title: MetricDataPoint
       description: >-
         A single data point in a metric time series.

@@ -12337,6 +12371,76 @@ components:
         - vector_db_id
         - embedding_model
       title: RegisterVectorDbRequest
+    RerankRequest:
+      type: object
+      properties:
+        model:
+          type: string
+          description: >-
+            The identifier of the reranking model to use.
+        query:
+          oneOf:
+            - type: string
+            - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
+            - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+          description: >-
+            The search query to rank items against. Can be a string, text content
+            part, or image content part. The input must not exceed the model's max
+            input token length.
+        items:
+          type: array
+          items:
+            oneOf:
+              - type: string
+              - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
+              - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+          description: >-
+            List of items to rerank. Each item can be a string, text content part,
+            or image content part. Each input must not exceed the model's max input
+            token length.
+        max_num_results:
+          type: integer
+          description: >-
+            (Optional) Maximum number of results to return. Default: returns all.
+      additionalProperties: false
+      required:
+        - model
+        - query
+        - items
+      title: RerankRequest
+    RerankData:
+      type: object
+      properties:
+        index:
+          type: integer
+          description: >-
+            The original index of the document in the input list
+        relevance_score:
+          type: number
+          description: >-
+            The relevance score from the model output. Values are inverted when applicable
+            so that higher scores indicate greater relevance.
+      additionalProperties: false
+      required:
+        - index
+        - relevance_score
+      title: RerankData
+      description: >-
+        A single rerank result from a reranking response.
+    RerankResponse:
+      type: object
+      properties:
+        data:
+          type: array
+          items:
+            $ref: '#/components/schemas/RerankData'
+          description: >-
+            List of rerank result objects, sorted by relevance score (descending)
+      additionalProperties: false
+      required:
+        - data
+      title: RerankResponse
+      description: Response from a reranking request.
    ResumeAgentTurnRequest:
      type: object
      properties:

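A client-side sketch of the `/v1/inference/rerank` endpoint described in the spec changes above. It assumes a Llama Stack server listening on the default port 8321; "my-reranker" is a placeholder model identifier, not a real registered model.

```python
# Sketch of calling the new rerank endpoint over HTTP. The request and response
# fields follow the RerankRequest/RerankResponse schemas above.
import requests

payload = {
    "model": "my-reranker",
    "query": "What is the capital of France?",
    "items": [
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
        "The Eiffel Tower is in Paris.",
    ],
    "max_num_results": 2,  # optional; omit to rank and return every item
}

resp = requests.post("http://localhost:8321/v1/inference/rerank", json=payload)
resp.raise_for_status()

# RerankResponse.data is sorted by relevance_score (descending); each entry's
# `index` points back into the original `items` list.
for item in resp.json()["data"]:
    print(item["index"], item["relevance_score"])
```
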
@@ -225,8 +225,32 @@ server:
   port: 8321 # Port to listen on (default: 8321)
   tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
   tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
+  cors: true # Optional: Enable CORS (dev mode) or full config object
 ```
+
+### CORS Configuration
+
+CORS (Cross-Origin Resource Sharing) can be configured in two ways:
+
+**Local development** (allows localhost origins only):
+```yaml
+server:
+  cors: true
+```
+
+**Explicit configuration** (custom origins and settings):
+```yaml
+server:
+  cors:
+    allow_origins: ["https://myapp.com", "https://app.example.com"]
+    allow_methods: ["GET", "POST", "PUT", "DELETE"]
+    allow_headers: ["Content-Type", "Authorization"]
+    allow_credentials: true
+    max_age: 3600
+```
+
+When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.

 ### Authentication Configuration

 > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.

@@ -618,6 +642,54 @@ Content-Type: application/json
 }
 ```

+### CORS Configuration
+
+Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
+
+#### Quick Setup
+
+For development, use the simple boolean flag:
+
+```yaml
+server:
+  cors: true  # Auto-enables localhost with any port
+```
+
+This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
+
+#### Custom Configuration
+
+For specific origins and full control:
+
+```yaml
+server:
+  cors:
+    allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
+    allow_credentials: true
+    allow_methods: ["GET", "POST", "PUT", "DELETE"]
+    allow_headers: ["Content-Type", "Authorization"]
+    allow_origin_regex: "https://.*\\.example\\.com"  # Optional regex pattern
+    expose_headers: ["X-Total-Count"]
+    max_age: 86400
+```
+
+#### Configuration Options
+
+| Field | Description | Default |
+| -------------------- | ---------------------------------------------- | ------- |
+| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` |
+| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` |
+| `allow_methods` | Allowed HTTP methods. | `["*"]` |
+| `allow_headers` | Allowed headers. | `["*"]` |
+| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` |
+| `expose_headers` | Headers exposed to browser. | `[]` |
+| `max_age` | Preflight cache time (seconds). | `600` |
+
+**Security Notes**:
+- `allow_credentials: true` requires explicit origins (no wildcards)
+- `cors: true` enables localhost access only (secure for development)
+- For public APIs, always specify exact allowed origins

|
## Extending to handle Safety
|
||||||
|
|
||||||
Configuring Safety can be a little involved so it is instructive to go through an example.
|
Configuring Safety can be a little involved so it is instructive to go through an example.
|
||||||
|
|
|
@@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
     provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
-client.initialize()
 ```

 This will parse your config and set up any inline implementations and remote clients needed for your implementation.

@@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
 ```python
 client = LlamaStackAsLibraryClient(config_path)
-client.initialize()
 ```

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #

@@ -2,12 +2,15 @@
 ## Overview

-Protocol for batch processing API operations.
-
 The Batches API enables efficient processing of multiple requests in a single operation,
 particularly useful for processing large datasets, batch evaluation workflows, and
 cost-effective inference at scale.

+The API is designed to allow use of openai client libraries for seamless integration.
+
+This API provides the following extensions:
+ - idempotent batch creation
+
 Note: This API is currently under active development and may undergo changes.

 This section contains documentation for all available providers for the **batches** API.

@@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files** API.
 :maxdepth: 1

 inline_localfs
+remote_s3
 ```

docs/source/providers/files/remote_s3.md (new file, 33 lines)

# remote::s3

## Description

AWS S3-based file storage provider for scalable cloud file management with metadata persistence.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) |
| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |

## Sample Configuration

```yaml
bucket_name: ${env.S3_BUCKET_NAME}
region: ${env.AWS_REGION:=us-east-1}
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=}
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=}
endpoint_url: ${env.S3_ENDPOINT_URL:=}
auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db
```

@@ -9,7 +9,9 @@ This section contains documentation for all available providers for the **post_training** API.
 ```{toctree}
 :maxdepth: 1

-inline_huggingface
-inline_torchtune
+inline_huggingface-cpu
+inline_huggingface-gpu
+inline_torchtune-cpu
+inline_torchtune-gpu
 remote_nvidia
 ```

@@ -0,0 +1,41 @@ (new file)

# inline::huggingface-cpu

## Description

HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | `<\|user\|>{input}<\|assistant\|>{output}` | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
| `dpo_output_dir` | `<class 'str'>` | No | | |

## Sample Configuration

```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/dummy/dpo_output
```

@@ -0,0 +1,41 @@ (new file)

# inline::huggingface-gpu

## Description

HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | `<\|user\|>{input}<\|assistant\|>{output}` | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
| `dpo_output_dir` | `<class 'str'>` | No | | |

## Sample Configuration

```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/dummy/dpo_output
```

docs/source/providers/post_training/inline_torchtune-cpu.md (new file, 20 lines)

# inline::torchtune-cpu

## Description

TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |

## Sample Configuration

```yaml
checkpoint_format: meta
```

docs/source/providers/post_training/inline_torchtune-gpu.md (new file, 20 lines)

# inline::torchtune-gpu

## Description

TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |

## Sample Configuration

```yaml
checkpoint_format: meta
```

@@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel):

 @runtime_checkable
 class Batches(Protocol):
-    """Protocol for batch processing API operations.
+    """
     The Batches API enables efficient processing of multiple requests in a single operation,
     particularly useful for processing large datasets, batch evaluation workflows, and
     cost-effective inference at scale.

+    The API is designed to allow use of openai client libraries for seamless integration.
+
+    This API provides the following extensions:
+     - idempotent batch creation
+
     Note: This API is currently under active development and may undergo changes.
     """

@@ -45,6 +49,7 @@ class Batches(Protocol):
         endpoint: str,
         completion_window: Literal["24h"],
         metadata: dict[str, str] | None = None,
+        idempotency_key: str | None = None,
     ) -> BatchObject:
         """Create a new batch for processing multiple API requests.

@@ -52,6 +57,7 @@ class Batches(Protocol):
         :param endpoint: The endpoint to be used for all requests in the batch.
         :param completion_window: The time window within which the batch should be processed.
         :param metadata: Optional metadata for the batch.
+        :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
         :returns: The created batch object.
         """
         ...

@@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
     embeddings: list[list[float]]


+@json_schema_type
+class RerankData(BaseModel):
+    """A single rerank result from a reranking response.
+
+    :param index: The original index of the document in the input list
+    :param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
+    """
+
+    index: int
+    relevance_score: float
+
+
+@json_schema_type
+class RerankResponse(BaseModel):
+    """Response from a reranking request.
+
+    :param data: List of rerank result objects, sorted by relevance score (descending)
+    """
+
+    data: list[RerankData]
+
+
 @json_schema_type
 class OpenAIChatCompletionContentPartTextParam(BaseModel):
     """Text content part for OpenAI-compatible chat completion messages.

@@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol):
         :returns: A BatchCompletionResponse with the full completions.
         """
         raise NotImplementedError("Batch completion is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete

     @webmethod(route="/inference/chat-completion", method="POST")
     async def chat_completion(

@@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol):
         :returns: A BatchChatCompletionResponse with the full completions.
         """
         raise NotImplementedError("Batch chat completion is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete

     @webmethod(route="/inference/embeddings", method="POST")
     async def embeddings(

@@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol):
         """
         ...

+    @webmethod(route="/inference/rerank", method="POST", experimental=True)
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        """Rerank a list of documents based on their relevance to a query.
+
+        :param model: The identifier of the reranking model to use.
+        :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
+        :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
+        :param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
+        :returns: RerankResponse with indices sorted by relevance score (descending).
+        """
+        raise NotImplementedError("Reranking is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete

     @webmethod(route="/openai/v1/completions", method="POST")
     async def openai_completion(
         self,

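For illustration only, a minimal sketch of what a provider-side implementation of the `rerank()` contract above might look like, restricted to string items for brevity. The word-overlap score and the import path are assumptions; a real provider would call an actual reranking model.

```python
# Hypothetical provider-side implementation of the rerank() contract above.
# The scoring function is a placeholder; a real provider would call a model.
from llama_stack.apis.inference import RerankData, RerankResponse  # assumed import path


async def rerank_sketch(
    model: str,
    query: str,
    items: list[str],
    max_num_results: int | None = None,
) -> RerankResponse:
    def score(item: str) -> float:
        # Placeholder relevance: fraction of query words that appear in the item.
        words = query.lower().split()
        return sum(word in item.lower() for word in words) / max(len(words), 1)

    data = [RerankData(index=i, relevance_score=score(item)) for i, item in enumerate(items)]
    # The response contract requires descending order by relevance score.
    data.sort(key=lambda d: d.relevance_score, reverse=True)
    if max_num_results is not None:
        data = data[:max_num_results]
    return RerankResponse(data=data)
```
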
@@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel):
     timestamp: int
     value: float
+    unit: str


 @json_schema_type

@@ -518,7 +519,7 @@ class Telemetry(Protocol):
         metric_name: str,
         start_time: int,
         end_time: int | None = None,
-        granularity: str | None = "1d",
+        granularity: str | None = None,
         query_type: MetricQueryType = MetricQueryType.RANGE,
         label_matchers: list[MetricLabelMatcher] | None = None,
     ) -> QueryMetricsResponse:

@@ -15,7 +15,7 @@ from llama_stack.log import get_logger

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

-logger = get_logger(name=__name__, category="server")
+logger = get_logger(name=__name__, category="cli")


 class StackRun(Subcommand):

@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import importlib.resources
-import logging
 import sys

 from pydantic import BaseModel

@@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis
 from llama_stack.core.utils.exec import run_command
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.distributions.template import DistributionTemplate
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")

 # These are the dependencies needed by the distribution server.
 # `llama-stack` is automatically installed by the installation script.

@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
 import textwrap
 from typing import Any

@@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
 from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.prompt_for_config import prompt_for_config
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, ProviderSpec

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="core")


 def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:

@@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
     period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")


+class CORSConfig(BaseModel):
+    allow_origins: list[str] = Field(default_factory=list)
+    allow_origin_regex: str | None = Field(default=None)
+    allow_methods: list[str] = Field(default=["OPTIONS"])
+    allow_headers: list[str] = Field(default_factory=list)
+    allow_credentials: bool = Field(default=False)
+    expose_headers: list[str] = Field(default_factory=list)
+    max_age: int = Field(default=600, ge=0)
+
+    @model_validator(mode="after")
+    def validate_credentials_config(self) -> Self:
+        if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
+            raise ValueError("Cannot use wildcard origins with credentials enabled")
+        return self
+
+
+def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
+    if cors_config is False or cors_config is None:
+        return None
+
+    if cors_config is True:
+        # dev mode: allow localhost on any port
+        return CORSConfig(
+            allow_origins=[],
+            allow_origin_regex=r"https?://localhost:\d+",
+            allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+            allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
+        )
+
+    if isinstance(cors_config, CORSConfig):
+        return cors_config
+
+    raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
+
+
 class ServerConfig(BaseModel):
     port: int = Field(
         default=8321,

@@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
         default=None,
         description="Per client quota request configuration",
     )
+    cors: bool | CORSConfig | None = Field(
+        default=None,
+        description="CORS configuration for cross-origin requests. Can be:\n"
+        "- true: Enable localhost CORS for development\n"
+        "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
+    )


 class StackRunConfig(BaseModel):
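Note: the new CORS helpers above can be exercised directly. A minimal sketch, assuming the llama_stack.core.datatypes module path used elsewhere in this change:

    from llama_stack.core.datatypes import CORSConfig, process_cors_config

    # `cors: true` in the server config maps to a localhost-only dev configuration.
    dev = process_cors_config(True)
    print(dev.allow_origin_regex)  # r"https?://localhost:\d+"

    # An explicit CORSConfig is passed through unchanged; wildcard origins combined
    # with credentials would be rejected by the model validator.
    explicit = process_cors_config(
        CORSConfig(allow_origins=["https://example.com"], allow_credentials=True)
    )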
@@ -7,7 +7,7 @@
 import asyncio
 import inspect
 import json
-import logging
+import logging  # allow-direct-logging
 import os
 import sys
 from concurrent.futures import ThreadPoolExecutor

@@ -48,6 +48,7 @@ from llama_stack.core.stack import (
 from llama_stack.core.utils.config import redact_sensitive_fields
 from llama_stack.core.utils.context import preserve_contexts_async_generator
 from llama_stack.core.utils.exec import in_notebook
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry.tracing import (
     CURRENT_TRACE_CONTEXT,
     end_trace,

@@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
     start_trace,
 )

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="core")

 T = TypeVar("T")

@@ -145,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_distro_name, custom_provider_registry, provider_data
+            config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
-        self.skip_logger_removal = skip_logger_removal
         self.provider_data = provider_data

         self.loop = asyncio.new_event_loop()

-    def initialize(self):
-        if in_notebook():
-            import nest_asyncio
-
-            nest_asyncio.apply()
-        if not self.skip_logger_removal:
-            self._remove_root_logger_handlers()
-
         # use a new event loop to avoid interfering with the main event loop
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         try:
-            return loop.run_until_complete(self.async_client.initialize())
+            loop.run_until_complete(self.async_client.initialize())
         finally:
             asyncio.set_event_loop(None)

-    def _remove_root_logger_handlers(self):
+    def initialize(self):
         """
-        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        Deprecated method for backward compatibility.
         """
-        root_logger = logging.getLogger()
-
-        for handler in root_logger.handlers[:]:
-            root_logger.removeHandler(handler)
-            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+        pass

     def request(self, *args, **kwargs):
         loop = self.loop
@@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         config_path_or_distro_name: str,
         custom_provider_registry: ProviderRegistry | None = None,
         provider_data: dict[str, Any] | None = None,
+        skip_logger_removal: bool = False,
     ):
         super().__init__()
         # when using the library client, we should not log to console since many

@@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
         os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")

+        if in_notebook():
+            import nest_asyncio
+
+            nest_asyncio.apply()
+        if not skip_logger_removal:
+            self._remove_root_logger_handlers()
+
         if config_path_or_distro_name.endswith(".yaml"):
             config_path = Path(config_path_or_distro_name)
             if not config_path.exists():

@@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.provider_data = provider_data
         self.route_impls: RouteImpls | None = None  # Initialize to None to prevent AttributeError

+    def _remove_root_logger_handlers(self):
+        """
+        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        """
+        root_logger = logging.getLogger()
+
+        for handler in root_logger.handlers[:]:
+            root_logger.removeHandler(handler)
+            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+
     async def initialize(self) -> bool:
+        """
+        Initialize the async client.
+
+        Returns:
+            bool: True if initialization was successful
+        """
+
         try:
             self.route_impls = None
             self.impls = await construct_stack(self.config, self.custom_provider_registry)
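Note: with this change, notebook patching and root-logger cleanup move into the async client's constructor, gated by the new skip_logger_removal flag, while the sync initialize() becomes a backward-compatibility no-op. A minimal usage sketch; the "starter" distro name and the llama_stack.core.library_client module path are illustrative assumptions:

    from llama_stack.core.library_client import LlamaStackAsLibraryClient

    # Keep any root-logger handlers the host application already installed.
    client = LlamaStackAsLibraryClient("starter", skip_logger_removal=True)
    client.initialize()  # deprecated no-op; the stack is built and initialized in __init__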
@@ -6,15 +6,15 @@

 import contextvars
 import json
-import logging
 from contextlib import AbstractContextManager
 from typing import Any

 from llama_stack.core.datatypes import User
+from llama_stack.log import get_logger

 from .utils.dynamic import instantiate_class_type

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")

 # Context variable for request provider data and auth attributes
 PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
@@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class DatasetIORouter(DatasetIO):

@@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class ScoringRouter(Scoring):

@@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack.providers.utils.telemetry.tracing import get_current_span

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="core::routers")


 class InferenceRouter(Inference):

@@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class SafetyRouter(Safety):

@@ -22,7 +22,7 @@ from llama_stack.log import get_logger

 from ..routing_tables.toolgroups import ToolGroupsRoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class ToolRuntimeRouter(ToolRuntime):

@@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class VectorIORouter(VectorIO):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

@@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 def get_impl_api(p: Any) -> Api:

@@ -26,7 +26,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

@@ -17,7 +17,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl, lookup_model

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class ModelsRoutingTable(CommonRoutingTableImpl, Models):

@@ -19,7 +19,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

@@ -15,7 +15,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

@@ -30,7 +30,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl, lookup_model

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
@@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
 from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="auth")
+logger = get_logger(name=__name__, category="core::auth")


 class AuthenticationMiddleware:

@@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
 )
 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="auth")
+logger = get_logger(name=__name__, category="core::auth")


 class AuthResponse(BaseModel):

@@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl

-logger = get_logger(name=__name__, category="quota")
+logger = get_logger(name=__name__, category="core::server")


 class QuotaMiddleware:
@@ -9,7 +9,7 @@ import asyncio
 import functools
 import inspect
 import json
-import logging
+import logging  # allow-direct-logging
 import os
 import ssl
 import sys

@@ -28,6 +28,7 @@ from aiohttp import hdrs
 from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError

@@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
     AuthenticationRequiredError,
     LoggingConfig,
     StackRunConfig,
+    process_cors_config,
 )
 from llama_stack.core.distribution import builtin_automatically_routed_apis
 from llama_stack.core.external import ExternalApiSpec, load_external_apis

@@ -82,7 +84,7 @@ from .quota import QuotaMiddleware

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

-logger = get_logger(name=__name__, category="server")
+logger = get_logger(name=__name__, category="core::server")


 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):

@@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
-    logger = get_logger(name=__name__, category="server", config=logger_config)
+    logger = get_logger(name=__name__, category="core::server", config=logger_config)
     if args.env:
         for env_pair in args.env:
             try:

@@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
             window_seconds=window_seconds,
         )

+    if config.server.cors:
+        logger.info("Enabling CORS")
+        cors_config = process_cors_config(config.server.cors)
+        if cors_config:
+            app.add_middleware(CORSMiddleware, **cors_config.model_dump())
+
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
     else:
@@ -16,7 +16,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig

-logger = get_logger(__name__, category="core")
+logger = get_logger(__name__, category="core::registry")


 class DistributionRegistry(Protocol):

@@ -10,7 +10,7 @@ from pathlib import Path
 from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="config_resolution")
+logger = get_logger(name=__name__, category="core")


 DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
+import importlib
 import os
 import signal
 import subprocess

@@ -12,9 +12,9 @@ import sys

 from termcolor import cprint

-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger

-import importlib
+log = get_logger(name=__name__, category="core")


 def formulate_run_args(image_type: str, image_name: str) -> list:
@@ -6,7 +6,6 @@

 import inspect
 import json
-import logging
 from enum import Enum
 from typing import Annotated, Any, Literal, Union, get_args, get_origin

@@ -14,7 +13,9 @@ from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefinedType

-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
+
+log = get_logger(name=__name__, category="core")


 def is_list_of_primitives(field_type):
@@ -34,7 +34,7 @@ distribution_spec:
   telemetry:
   - provider_type: inline::meta-reference
   post_training:
-  - provider_type: inline::huggingface
+  - provider_type: inline::huggingface-cpu
   eval:
   - provider_type: inline::meta-reference
   datasetio:

@@ -156,8 +156,8 @@ providers:
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   post_training:
-  - provider_id: huggingface
-    provider_type: inline::huggingface
+  - provider_id: huggingface-cpu
+    provider_type: inline::huggingface-cpu
     config:
       checkpoint_format: huggingface
       distributed_backend: null
llama_stack/distributions/starter-gpu/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .starter_gpu import get_distribution_template  # noqa: F401
llama_stack/distributions/starter-gpu/build.yaml (new file)
@@ -0,0 +1,59 @@
+version: 2
+distribution_spec:
+  description: Quick start template for running Llama Stack with several popular providers.
+    This distribution is intended for GPU-enabled environments.
+  providers:
+    inference:
+    - provider_type: remote::cerebras
+    - provider_type: remote::ollama
+    - provider_type: remote::vllm
+    - provider_type: remote::tgi
+    - provider_type: remote::fireworks
+    - provider_type: remote::together
+    - provider_type: remote::bedrock
+    - provider_type: remote::nvidia
+    - provider_type: remote::openai
+    - provider_type: remote::anthropic
+    - provider_type: remote::gemini
+    - provider_type: remote::vertexai
+    - provider_type: remote::groq
+    - provider_type: remote::sambanova
+    - provider_type: inline::sentence-transformers
+    vector_io:
+    - provider_type: inline::faiss
+    - provider_type: inline::sqlite-vec
+    - provider_type: inline::milvus
+    - provider_type: remote::chromadb
+    - provider_type: remote::pgvector
+    files:
+    - provider_type: inline::localfs
+    safety:
+    - provider_type: inline::llama-guard
+    - provider_type: inline::code-scanner
+    agents:
+    - provider_type: inline::meta-reference
+    telemetry:
+    - provider_type: inline::meta-reference
+    post_training:
+    - provider_type: inline::torchtune-gpu
+    eval:
+    - provider_type: inline::meta-reference
+    datasetio:
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
+    scoring:
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
+    tool_runtime:
+    - provider_type: remote::brave-search
+    - provider_type: remote::tavily-search
+    - provider_type: inline::rag-runtime
+    - provider_type: remote::model-context-protocol
+    batches:
+    - provider_type: inline::reference
+image_type: venv
+additional_pip_packages:
+- aiosqlite
+- asyncpg
+- sqlalchemy[asyncio]
llama_stack/distributions/starter-gpu/run.yaml (new file)
@@ -0,0 +1,238 @@
+version: 2
+image_name: starter-gpu
+apis:
+- agents
+- batches
+- datasetio
+- eval
+- files
+- inference
+- post_training
+- safety
+- scoring
+- telemetry
+- tool_runtime
+- vector_io
+providers:
+  inference:
+  - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
+    provider_type: remote::cerebras
+    config:
+      base_url: https://api.cerebras.ai
+      api_key: ${env.CEREBRAS_API_KEY:=}
+  - provider_id: ${env.OLLAMA_URL:+ollama}
+    provider_type: remote::ollama
+    config:
+      url: ${env.OLLAMA_URL:=http://localhost:11434}
+  - provider_id: ${env.VLLM_URL:+vllm}
+    provider_type: remote::vllm
+    config:
+      url: ${env.VLLM_URL:=}
+      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
+      api_token: ${env.VLLM_API_TOKEN:=fake}
+      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+  - provider_id: ${env.TGI_URL:+tgi}
+    provider_type: remote::tgi
+    config:
+      url: ${env.TGI_URL:=}
+  - provider_id: fireworks
+    provider_type: remote::fireworks
+    config:
+      url: https://api.fireworks.ai/inference/v1
+      api_key: ${env.FIREWORKS_API_KEY:=}
+  - provider_id: together
+    provider_type: remote::together
+    config:
+      url: https://api.together.xyz/v1
+      api_key: ${env.TOGETHER_API_KEY:=}
+  - provider_id: bedrock
+    provider_type: remote::bedrock
+  - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
+    provider_type: remote::nvidia
+    config:
+      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      api_key: ${env.NVIDIA_API_KEY:=}
+      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
+  - provider_id: openai
+    provider_type: remote::openai
+    config:
+      api_key: ${env.OPENAI_API_KEY:=}
+      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
+  - provider_id: anthropic
+    provider_type: remote::anthropic
+    config:
+      api_key: ${env.ANTHROPIC_API_KEY:=}
+  - provider_id: gemini
+    provider_type: remote::gemini
+    config:
+      api_key: ${env.GEMINI_API_KEY:=}
+  - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
+    provider_type: remote::vertexai
+    config:
+      project: ${env.VERTEX_AI_PROJECT:=}
+      location: ${env.VERTEX_AI_LOCATION:=us-central1}
+  - provider_id: groq
+    provider_type: remote::groq
+    config:
+      url: https://api.groq.com
+      api_key: ${env.GROQ_API_KEY:=}
+  - provider_id: sambanova
+    provider_type: remote::sambanova
+    config:
+      url: https://api.sambanova.ai/v1
+      api_key: ${env.SAMBANOVA_API_KEY:=}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+  vector_io:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
+  - provider_id: sqlite-vec
+    provider_type: inline::sqlite-vec
+    config:
+      db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
+  - provider_id: ${env.MILVUS_URL:+milvus}
+    provider_type: inline::milvus
+    config:
+      db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
+  - provider_id: ${env.CHROMADB_URL:+chromadb}
+    provider_type: remote::chromadb
+    config:
+      url: ${env.CHROMADB_URL:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
+  - provider_id: ${env.PGVECTOR_DB:+pgvector}
+    provider_type: remote::pgvector
+    config:
+      host: ${env.PGVECTOR_HOST:=localhost}
+      port: ${env.PGVECTOR_PORT:=5432}
+      db: ${env.PGVECTOR_DB:=}
+      user: ${env.PGVECTOR_USER:=}
+      password: ${env.PGVECTOR_PASSWORD:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config:
+      excluded_categories: []
+  - provider_id: code-scanner
+    provider_type: inline::code-scanner
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db
+      responses_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
+      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
+      sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
+      otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
+  post_training:
+  - provider_id: torchtune-gpu
+    provider_type: inline::torchtune-gpu
+    config:
+      checkpoint_format: meta
+  eval:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db
+  datasetio:
+  - provider_id: huggingface
+    provider_type: remote::huggingface
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db
+  - provider_id: localfs
+    provider_type: inline::localfs
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db
+  scoring:
+  - provider_id: basic
+    provider_type: inline::basic
+  - provider_id: llm-as-judge
+    provider_type: inline::llm-as-judge
+  - provider_id: braintrust
+    provider_type: inline::braintrust
+    config:
+      openai_api_key: ${env.OPENAI_API_KEY:=}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:=}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:=}
+      max_results: 3
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+  - provider_id: model-context-protocol
+    provider_type: remote::model-context-protocol
+  batches:
+  - provider_id: reference
+    provider_type: inline::reference
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db
+metadata_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db
+inference_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
+models: []
+shields:
+- shield_id: llama-guard
+  provider_id: ${env.SAFETY_MODEL:+llama-guard}
+  provider_shield_id: ${env.SAFETY_MODEL:=}
+- shield_id: code-scanner
+  provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
+  provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
+vector_dbs: []
+datasets: []
+scoring_fns: []
+benchmarks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::rag
+  provider_id: rag-runtime
+server:
+  port: 8321
llama_stack/distributions/starter-gpu/starter_gpu.py (new file)
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from llama_stack.distributions.template import BuildProvider, DistributionTemplate
+
+from ..starter.starter import get_distribution_template as get_starter_distribution_template
+
+
+def get_distribution_template() -> DistributionTemplate:
+    template = get_starter_distribution_template()
+    name = "starter-gpu"
+    template.name = name
+    template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
+
+    template.providers["post_training"] = [
+        BuildProvider(provider_type="inline::torchtune-gpu"),
+    ]
+    return template
@@ -1,6 +1,7 @@
 version: 2
 distribution_spec:
-  description: Quick start template for running Llama Stack with several popular providers
+  description: Quick start template for running Llama Stack with several popular providers.
+    This distribution is intended for CPU-only environments.
 providers:
   inference:
   - provider_type: remote::cerebras

@@ -34,7 +35,7 @@ distribution_spec:
   telemetry:
   - provider_type: inline::meta-reference
   post_training:
-  - provider_type: inline::huggingface
+  - provider_type: inline::huggingface-cpu
   eval:
   - provider_type: inline::meta-reference
   datasetio:

@@ -156,8 +156,8 @@ providers:
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   post_training:
-  - provider_id: huggingface
-    provider_type: inline::huggingface
+  - provider_id: huggingface-cpu
+    provider_type: inline::huggingface-cpu
     config:
       checkpoint_format: huggingface
       distributed_backend: null

@@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
         ],
         "agents": [BuildProvider(provider_type="inline::meta-reference")],
         "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
-        "post_training": [BuildProvider(provider_type="inline::huggingface")],
+        "post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
        "eval": [BuildProvider(provider_type="inline::meta-reference")],
         "datasetio": [
             BuildProvider(provider_type="remote::huggingface"),

@@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate:
     return DistributionTemplate(
         name=name,
         distro_type="self_hosted",
-        description="Quick start template for running Llama Stack with several popular providers",
+        description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
         container_image=None,
         template_path=None,
         providers=providers,
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
+import logging  # allow-direct-logging
 import os
 import re
-from logging.config import dictConfig
+from logging.config import dictConfig  # allow-direct-logging

 from rich.console import Console
 from rich.errors import MarkupError
@@ -13,14 +13,15 @@

 # Copyright (c) Meta Platforms, Inc. and its affiliates.
 import math
-from logging import getLogger

 import torch
 import torch.nn.functional as F

+from llama_stack.log import get_logger
+
 from .utils import get_negative_inf_value, to_2tuple

-logger = getLogger()
+logger = get_logger(name=__name__, category="models::llama")


 def resize_local_position_embedding(orig_pos_embed, grid_size):
@@ -13,7 +13,6 @@

 import math
 from collections import defaultdict
-from logging import getLogger
 from typing import Any

 import torch

@@ -21,9 +20,11 @@ import torchvision.transforms as tv
 from PIL import Image
 from torchvision.transforms import functional as F

+from llama_stack.log import get_logger
+
 IMAGE_RES = 224

-logger = getLogger()
+logger = get_logger(name=__name__, category="models::llama")


 class VariableSizeImageTransform:
@@ -3,8 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import math
 from collections.abc import Callable
 from functools import partial

@@ -22,6 +20,8 @@ from PIL import Image as PIL_Image
 from torch import Tensor, nn
 from torch.distributed import _functional_collectives as funcol

+from llama_stack.log import get_logger
+
 from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
 from .encoder_utils import (
     build_encoder_attention_mask,

@@ -34,9 +34,10 @@ from .encoder_utils import (
 from .image_transform import VariableSizeImageTransform
 from .utils import get_negative_inf_value, to_2tuple

-logger = logging.getLogger(__name__)
 MP_SCALE = 8

+logger = get_logger(name=__name__, category="models::llama")
+

 def reduce_from_tensor_model_parallel_region(input_):
     """All-reduce the input tensor across model parallel group."""

@@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module):
         if embed is not None:
             # reshape the weights to the correct shape
             nt_old, nt_old, _, w = embed.shape
-            logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
+            logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
             embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
             # assign the weights to the module
             state_dict[prefix + "embedding"] = embed_new
@@ -4,8 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.


 from collections.abc import Collection, Iterator, Sequence, Set
-from logging import getLogger
 from pathlib import Path
 from typing import (
     Literal,

@@ -14,11 +14,9 @@ from typing import (

 import tiktoken

+from llama_stack.log import get_logger
 from llama_stack.models.llama.tokenizer_utils import load_bpe_file

-logger = getLogger(__name__)
-
-
 # The tiktoken tokenizer can handle <=400k chars without
 # pyo3_runtime.PanicException.
 TIKTOKEN_MAX_ENCODE_CHARS = 400_000

@@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000

 _INSTANCE = None

+logger = get_logger(name=__name__, category="models::llama")
+

 class Tokenizer:
     """
@@ -11,7 +11,7 @@ from llama_stack.log import get_logger

 from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="models::llama")

 BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
 CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import os
 from collections.abc import Callable

@@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
 from torch import Tensor, nn
 from torch.nn import functional as F

+from llama_stack.log import get_logger
+
 from ...datatypes import QuantizationMode
 from ..model import Transformer, TransformerBlock
 from ..moe import MoE

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="models::llama")


 def swiglu_wrapper_no_reduce(
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from collections.abc import Collection, Iterator, Sequence, Set
-from logging import getLogger
 from pathlib import Path
 from typing import (
     Literal,

@@ -14,11 +13,9 @@ from typing import (

 import tiktoken

+from llama_stack.log import get_logger
 from llama_stack.models.llama.tokenizer_utils import load_bpe_file

-logger = getLogger(__name__)
-
-
 # The tiktoken tokenizer can handle <=400k chars without
 # pyo3_runtime.PanicException.
 TIKTOKEN_MAX_ENCODE_CHARS = 400_000

@@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [
     "<|fim_suffix|>",
 ]

+logger = get_logger(name=__name__, category="models::llama")
+

 class Tokenizer:
     """
@@ -6,9 +6,10 @@

 # type: ignore
 import collections
-import logging

-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
+
+log = get_logger(name=__name__, category="models::llama")

 try:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
@@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
 WEB_SEARCH_TOOL = "web_search"
 RAG_TOOL_GROUP = "builtin::rag"

-logger = get_logger(name=__name__, category="agents")
+logger = get_logger(name=__name__, category="agents::meta_reference")


 class ChatAgent(ShieldRunnerMixin):
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import uuid
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime

@@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore

@@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
 from .persistence import AgentInfo
 from .responses.openai_responses import OpenAIResponsesImpl

-logger = logging.getLogger()
+logger = get_logger(name=__name__, category="agents::meta_reference")


 class MetaReferenceAgentsImpl(Agents):
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import json
-import logging
 import uuid
 from datetime import UTC, datetime

@@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
 from llama_stack.core.access_control.datatypes import AccessRule
 from llama_stack.core.datatypes import User
 from llama_stack.core.request_headers import get_authenticated_user
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="agents::meta_reference")


 class AgentSessionInfo(Session):
@@ -41,7 +41,7 @@ from .utils import (
     convert_response_text_to_chat_response_format,
 )

-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="openai::responses")


 class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
@@ -47,7 +47,7 @@ from llama_stack.log import get_logger
 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import convert_chat_choice_to_response_message, is_function_tool_call

-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="agents::meta_reference")


 class StreamingResponseOrchestrator:
@@ -38,7 +38,7 @@ from llama_stack.log import get_logger

 from .types import ChatCompletionContext, ToolExecutionResult

-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="agents::meta_reference")


 class ToolExecutor:
@@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageFunctionToolCall,
+    OpenAIResponseOutputMessageMCPCall,
+    OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseText,
 )
 from llama_stack.apis.inference import (
@@ -99,14 +101,22 @@ async def convert_response_input_to_chat_messages(
     """
     messages: list[OpenAIMessageParam] = []
     if isinstance(input, list):
+        # extract all OpenAIResponseInputFunctionToolCallOutput items
+        # so their corresponding OpenAIToolMessageParam instances can
+        # be added immediately following the corresponding
+        # OpenAIAssistantMessageParam
+        tool_call_results = {}
         for input_item in input:
             if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
-                messages.append(
-                    OpenAIToolMessageParam(
-                        content=input_item.output,
-                        tool_call_id=input_item.call_id,
-                    )
-                )
+                tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
+                    content=input_item.output,
+                    tool_call_id=input_item.call_id,
+                )
+
+        for input_item in input:
+            if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
+                # skip as these have been extracted and inserted in order
+                pass
             elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
                 tool_call = OpenAIChatCompletionToolCall(
                     index=0,
@@ -117,6 +127,28 @@ async def convert_response_input_to_chat_messages(
                     ),
                 )
                 messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+                if input_item.call_id in tool_call_results:
+                    messages.append(tool_call_results[input_item.call_id])
+                    del tool_call_results[input_item.call_id]
+            elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
+                tool_call = OpenAIChatCompletionToolCall(
+                    index=0,
+                    id=input_item.id,
+                    function=OpenAIChatCompletionToolCallFunction(
+                        name=input_item.name,
+                        arguments=input_item.arguments,
+                    ),
+                )
+                messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+                messages.append(
+                    OpenAIToolMessageParam(
+                        content=input_item.output,
+                        tool_call_id=input_item.id,
+                    )
+                )
+            elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
+                # the tool list will be handled separately
+                pass
             else:
                 content = await convert_response_content_to_chat_content(input_item.content)
                 message_type = await get_message_type_by_role(input_item.role)
@@ -125,6 +157,10 @@ async def convert_response_input_to_chat_messages(
                     f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
                 )
                 messages.append(message_type(content=content))
+        if len(tool_call_results):
+            raise ValueError(
+                f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
+            )
     else:
         messages.append(OpenAIUserMessageParam(content=input))
     return messages
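The convert_response_input_to_chat_messages change above reorders function-call outputs so each tool message lands immediately after its assistant function_call. A toy sketch of that two-pass ordering, using plain dicts instead of the real llama_stack message types (all names here are illustrative):

def order_messages(items: list[dict]) -> list[dict]:
    # pass 1: collect every function_call_output, keyed by call_id
    tool_call_results = {
        item["call_id"]: {"role": "tool", "tool_call_id": item["call_id"], "content": item["output"]}
        for item in items
        if item["type"] == "function_call_output"
    }

    # pass 2: emit messages; a tool result is appended right after its call
    messages: list[dict] = []
    for item in items:
        if item["type"] == "function_call_output":
            continue  # already captured in pass 1
        if item["type"] == "function_call":
            messages.append({"role": "assistant", "tool_calls": [{"id": item["call_id"], "name": item["name"]}]})
            if item["call_id"] in tool_call_results:
                messages.append(tool_call_results.pop(item["call_id"]))
        else:
            messages.append({"role": item["role"], "content": item["content"]})

    if tool_call_results:
        raise ValueError(f"function_call_output without a matching function_call: {sorted(tool_call_results)}")
    return messages

ordered = order_messages(
    [
        {"type": "message", "role": "user", "content": "What is the weather?"},
        {"type": "function_call_output", "call_id": "c1", "output": "72F"},
        {"type": "function_call", "call_id": "c1", "name": "get_weather"},
    ]
)
# -> user message, assistant tool call, then the "72F" tool result right after it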
@@ -5,13 +5,13 @@
 # the root directory of this source tree.

 import asyncio
-import logging

 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry import tracing

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="agents::meta_reference")


 class SafetyException(Exception):  # noqa: N818
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import asyncio
+import hashlib
 import itertools
 import json
 import time
@@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches):
         endpoint: str,
         completion_window: Literal["24h"],
         metadata: dict[str, str] | None = None,
+        idempotency_key: str | None = None,
     ) -> BatchObject:
         """
         Create a new batch for processing multiple API requests.

-        Error handling by levels -
-        0. Input param handling, results in 40x errors before processing, e.g.
-         - Wrong completion_window
-         - Invalid metadata types
-         - Unknown endpoint
-         -> no batch created
-        1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
-         - input_file_id missing
-         - invalid json in file
-         - missing custom_id, method, url, body
-         - invalid model
-         - streaming
-         -> batch created, validation sends to failed status
-        2. Processing errors, result in error_file_id entries, e.g.
-         - Any error returned from inference endpoint
-         -> batch created, goes to completed status
+        This implementation provides optional idempotency: when an idempotency key
+        (idempotency_key) is provided, a deterministic ID is generated based on the input
+        parameters. If a batch with the same parameters already exists, it will be
+        returned instead of creating a duplicate. Without an idempotency key,
+        each request creates a new batch with a unique ID.
+
+        Args:
+            input_file_id: The ID of an uploaded file containing requests for the batch.
+            endpoint: The endpoint to be used for all requests in the batch.
+            completion_window: The time window within which the batch should be processed.
+            metadata: Optional metadata for the batch.
+            idempotency_key: Optional idempotency key for enabling idempotent behavior.
+
+        Returns:
+            The created or existing batch object.
         """

+        # Error handling by levels -
+        # 0. Input param handling, results in 40x errors before processing, e.g.
+        #  - Wrong completion_window
+        #  - Invalid metadata types
+        #  - Unknown endpoint
+        #  -> no batch created
+        # 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
+        #  - input_file_id missing
+        #  - invalid json in file
+        #  - missing custom_id, method, url, body
+        #  - invalid model
+        #  - streaming
+        #  -> batch created, validation sends to failed status
+        # 2. Processing errors, result in error_file_id entries, e.g.
+        #  - Any error returned from inference endpoint
+        #  -> batch created, goes to completed status
+
         # TODO: set expiration time for garbage collection

         if endpoint not in ["/v1/chat/completions"]:
@@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches):
             )

         batch_id = f"batch_{uuid.uuid4().hex[:16]}"
+
+        # For idempotent requests, use the idempotency key for the batch ID
+        # This ensures the same key always maps to the same batch ID,
+        # allowing us to detect parameter conflicts
+        if idempotency_key is not None:
+            hash_input = idempotency_key.encode("utf-8")
+            hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
+            batch_id = f"batch_{hash_digest}"
+
+            try:
+                existing_batch = await self.retrieve_batch(batch_id)
+
+                if (
+                    existing_batch.input_file_id != input_file_id
+                    or existing_batch.endpoint != endpoint
+                    or existing_batch.completion_window != completion_window
+                    or existing_batch.metadata != metadata
+                ):
+                    raise ConflictError(
+                        f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
+                        "Either use a new idempotency key or ensure all parameters match the original request."
+                    )
+
+                logger.info(f"Returning existing batch with ID: {batch_id}")
+                return existing_batch
+            except ResourceNotFoundError:
+                # Batch doesn't exist, continue with creation
+                pass
+
         current_time = int(time.time())

         batch = BatchObject(
@@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches):
         )

         await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
+        logger.info(f"Created new batch with ID: {batch_id}")

         if self.process_batches:
             task = asyncio.create_task(self._process_batch(batch_id))
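With the idempotency change above, the batch ID is derived deterministically from the caller-supplied idempotency_key (SHA-256, truncated to 24 hex characters), so a retried create with the same key resolves to the same batch and mismatched parameters raise ConflictError. A standalone sketch of just the ID derivation (the example keys are hypothetical):

import hashlib


def derive_batch_id(idempotency_key: str) -> str:
    # same derivation as in the hunk above: sha256 of the key, truncated to 24 hex chars
    return "batch_" + hashlib.sha256(idempotency_key.encode("utf-8")).hexdigest()[:24]


assert derive_batch_id("nightly-eval") == derive_batch_id("nightly-eval")
assert derive_batch_id("nightly-eval") != derive_batch_id("weekly-eval")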
@@ -11,6 +11,7 @@ from typing import Annotated

 from fastapi import File, Form, Response, UploadFile

+from llama_stack.apis.common.errors import ResourceNotFoundError
 from llama_stack.apis.common.responses import Order
 from llama_stack.apis.files import (
     Files,
@@ -20,12 +21,15 @@ from llama_stack.apis.files import (
     OpenAIFilePurpose,
 )
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
 from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl

 from .config import LocalfsFilesImplConfig

+logger = get_logger(name=__name__, category="files")
+

 class LocalfsFilesImpl(Files):
     def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
@@ -65,6 +69,18 @@ class LocalfsFilesImpl(Files):
         """Get the filesystem path for a file ID."""
         return Path(self.config.storage_dir) / file_id

+    async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]:
+        """Look up a OpenAIFileObject and filesystem path from its ID."""
+        if not self.sql_store:
+            raise RuntimeError("Files provider not initialized")
+
+        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
+        if not row:
+            raise ResourceNotFoundError(file_id, "File", "client.files.list()")
+
+        file_path = Path(row.pop("file_path"))
+        return OpenAIFileObject(**row), file_path
+
     # OpenAI Files API Implementation
     async def openai_upload_file(
         self,
@@ -157,37 +173,19 @@ class LocalfsFilesImpl(Files):

     async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
         """Returns information about a specific file."""
-        if not self.sql_store:
-            raise RuntimeError("Files provider not initialized")
-
-        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
-        if not row:
-            raise ValueError(f"File with id {file_id} not found")
+        file_obj, _ = await self._lookup_file_id(file_id)

-        return OpenAIFileObject(
-            id=row["id"],
-            filename=row["filename"],
-            purpose=OpenAIFilePurpose(row["purpose"]),
-            bytes=row["bytes"],
-            created_at=row["created_at"],
-            expires_at=row["expires_at"],
-        )
+        return file_obj

     async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
         """Delete a file."""
-        if not self.sql_store:
-            raise RuntimeError("Files provider not initialized")
-
-        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
-        if not row:
-            raise ValueError(f"File with id {file_id} not found")
-
         # Delete physical file
-        file_path = Path(row["file_path"])
+        _, file_path = await self._lookup_file_id(file_id)
         if file_path.exists():
             file_path.unlink()

         # Delete metadata from database
+        assert self.sql_store is not None, "Files provider not initialized"
         await self.sql_store.delete("openai_files", where={"id": file_id})

         return OpenAIFileDeleteResponse(
@@ -197,25 +195,17 @@ class LocalfsFilesImpl(Files):

     async def openai_retrieve_file_content(self, file_id: str) -> Response:
         """Returns the contents of the specified file."""
-        if not self.sql_store:
-            raise RuntimeError("Files provider not initialized")
-
-        # Get file metadata
-        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
-        if not row:
-            raise ValueError(f"File with id {file_id} not found")
-
         # Read file content
-        file_path = Path(row["file_path"])
-        if not file_path.exists():
-            raise ValueError(f"File content not found on disk: {file_path}")
+        file_obj, file_path = await self._lookup_file_id(file_id)

-        with open(file_path, "rb") as f:
-            content = f.read()
+        if not file_path.exists():
+            logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.")
+            await self.openai_delete_file(file_id)
+            raise ResourceNotFoundError(file_id, "File", "client.files.list()")

         # Return as binary response with appropriate content type
         return Response(
-            content=content,
+            content=file_path.read_bytes(),
             media_type="application/octet-stream",
-            headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
+            headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'},
         )
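The localfs refactor above routes retrieve, delete, and content reads through a single _lookup_file_id helper and treats a file missing on disk as stale metadata to be deleted before reporting not-found. A toy sketch of that behavior with an in-memory dict standing in for the SQL store (class and method names here are illustrative, not the provider's API):

from pathlib import Path


class ToyFileStore:
    def __init__(self) -> None:
        self.rows: dict[str, dict] = {}  # file_id -> metadata, including "file_path"

    def _lookup(self, file_id: str) -> tuple[dict, Path]:
        row = self.rows.get(file_id)
        if row is None:
            raise FileNotFoundError(file_id)  # the provider raises ResourceNotFoundError here
        return row, Path(row["file_path"])

    def retrieve(self, file_id: str) -> dict:
        row, _ = self._lookup(file_id)
        return row

    def content(self, file_id: str) -> bytes:
        _, path = self._lookup(file_id)
        if not path.exists():
            self.rows.pop(file_id, None)  # stale metadata: clean it up, then report not found
            raise FileNotFoundError(file_id)
        return path.read_bytes()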
@@ -12,7 +12,6 @@

 import copy
 import json
-import logging
 import multiprocessing
 import os
 import tempfile
@@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
 from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch

+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")


 class ProcessingMessageName(str, Enum):
@@ -4,13 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from collections.abc import AsyncGenerator

 from llama_stack.apis.inference import (
     CompletionResponse,
     InferenceProvider,
-    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -21,6 +19,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -32,7 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (

 from .config import SentenceTransformersInferenceConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")


 class SentenceTransformersInferenceImpl(
@@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
         tool_config: ToolConfig | None = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")
-
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
@@ -6,7 +6,6 @@

 import gc
 import json
-import logging
 import multiprocessing
 from pathlib import Path
 from typing import Any
@@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device

 from ..config import HuggingFacePostTrainingConfig
@@ -44,7 +44,7 @@ from ..utils import (
     split_dataset,
 )

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")


 class HFFinetuningSingleDevice:
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import gc
-import logging
 import multiprocessing
 from pathlib import Path
 from typing import Any
@@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
     DPOAlignmentConfig,
     TrainingConfig,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device

 from ..config import HuggingFacePostTrainingConfig
@@ -40,7 +40,7 @@ from ..utils import (
     split_dataset,
 )

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")


 class HFDPOAlignmentSingleDevice:
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import os
 import signal
 import sys
@@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM

 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.post_training import Checkpoint, TrainingConfig
+from llama_stack.log import get_logger

 from .config import HuggingFacePostTrainingConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")


 def setup_environment():
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import os
 import time
 from datetime import UTC, datetime
@@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
 from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
     get_adapter_params,
@@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
 )
 from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from llama_stack.providers.inline.post_training.torchtune.common import utils
@@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="post_training")

-from torchtune.models.llama3._tokenizer import Llama3Tokenizer
-

 class LoraFinetuningSingleDevice:
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import uuid
 from typing import TYPE_CHECKING, Any

@@ -20,13 +19,14 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

 from .config import CodeScannerConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="safety")

 ALLOWED_CODE_SCANNER_MODEL_IDS = [
     "code-scanner",
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import re
 import uuid
 from string import Template
@@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
 from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import Role
 from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@@ -132,6 +132,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov

 PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")

+logger = get_logger(name=__name__, category="safety")
+

 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     def __init__(self, config: LlamaGuardConfig, deps) -> None:
@@ -407,7 +409,7 @@ class LlamaGuardShield:
         unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
         invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
         if invalid_codes:
-            logging.warning(f"Invalid safety codes returned: {invalid_codes}")
+            logger.warning(f"Invalid safety codes returned: {invalid_codes}")
             # just returning safe object, as we don't know what the invalid codes can map to
             return ModerationObject(
                 id=f"modr-{uuid.uuid4()}",
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from typing import Any

 import torch
@@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
 from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
 from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import PromptGuardConfig, PromptGuardType

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="safety")

 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"

@@ -7,7 +7,6 @@
 import collections
 import functools
 import json
-import logging
 import random
 import re
 import string
@@ -20,7 +19,9 @@ import nltk
 from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
 from pythainlp.tokenize import word_tokenize as word_tokenize_thai

-logger = logging.getLogger()
+from llama_stack.log import get_logger
+
+logger = get_logger(name=__name__, category="scoring")

 WORD_LIST = [
     "western",
Some files were not shown because too many files have changed in this diff