forked from phoenix-oss/llama-stack-mirror
Compare commits
220 commits
SHA1
b174effe05, 8943b283e9, 08905fc937, 8b5b1c937b, 205fc2cbd1, 4a122bbaca, a77b554bcf, 51816af52e, 96003b55de, 3bde47e562,
ed31462499, 43a7713140, ad9860c312, 9b70e01c99, 7bba685dee, 4603206065, 16abfaeb69, b2ac7f69cc, 00fc43ae96, 65936f7933,
226e443e03, 5b057d60ee, 95a56b62a0, c642ea2dd5, 7e1725f72b, b414fe5566, cfa38bd13b, b21050935e, 277f8690ef, f328436831,
31ce208bda, ad15276da1, 2603f10f95, 168c7113df, f0d8ceb242, bfdd15d1fa, a654467552, 63a9f08c9e, 56e5ddb39f, 6352078e4b,
a7ecc92be1, 9b7f9db05c, 0b695538af, 1d46f3102e, 4f3f28f718, 484abe3116, 7105a25b0f, 5cdb29758a, 6ee319ae08, a8f75d3897,
e7e9ec0379, b2adaa3f60, 448f00903d, 28930cdab6, 7504c2f430, 51e6f529f3, 39b33a3b01, 7710b2f43b, 9623d5d230, ce33d02443,
5a422e236c, c25bd0ad58, 298721c238, eedf21f19c, ae7272d8ff, a2160dc0af, c290999c63, 3faf1e4a79, 66f09f24ed, 84751f3e55,
a411029d7e, 15b0a67555, 055f48b6a2, ca65617a71, 5844c2da68, 6463ee7633, 558d109ab7, b054023800, 51945f1e57, 2708312168,
d8c6ab9bfc, 8feb1827c8, 549812f51e, 633bb9c5b3, 02e5e8a633, 37f1e8a7f7, e92301f2d7, 85b5f3172b, 6a62e783b9, 1862de4be5,
c25acedbcd, 2890243107, 5a3d777b20, 091d8c48f2, 87a4b9cb28, 3339844fda, 1a770cf8ac, 2eae8568e1, 3f6368d56c, 90d7612f5f,
ed7b4731aa, 6d20b720b8, 82778ecbb0, 0cc0731189, 047303e339, c7015d3d60, 1341916caf, f40693e720, f02f7b28c1, 8f9964f46b,
1ae61e8d5f, 65cf076f13, b8f7e1504d, 64f8d4c3ad, 953ccffca2, 7f1f21fd6c, 7aae8fadbf, 3cc15f7d15, 1a6d4af5e9, 87e284f1a0,
10b1056dea, bb5fca9521, e46de23be6, 7e25c8df28, c3f27de3ea, 354faa15ce, 8e7ab146f8, ff247e35be, b42eb1ccbc, aa5bef8e05,
5052c3cbf3, 268725868e, a1fbfb51e2, 43d4447ff0, 1de0dfaab5, dd07c7a5b5, 26dffff92a, 8e316c9b1e, e0d10dd0b1, 62476a5373,
e3ad17ec5e, a5d14749a5, 23d9f3b1fb, c985ea6326, 136e6b3cf7, 80c349965f, 53b7f50828, 43e623eea6, 675f34e79d, 9a6e91cd93,
db21eab713, dd7be274b9, f2b83800cc, 473a07f624, 0f878ad87a, fe5f5e530c, 6371bb1b33, c91e3552a3, 40e71758d9, 6f1badc934,
664161c462, b2b00a216b, dd49ef31f1, a57985eeac, 1a529705da, feb9eb8b0d, c219a74fa0, 7377a5c83e, b9b13a3670, 2413447467,
3022f7b642, 65cc971877, 18d2312690, 2e807b38cc, 4597145011, a5d151e912, a4247ce0a8, 1fbda6bfaa, 16e163da0e, 15a1648be6,
d27a0f276c, 6b4c218788, c69f14bfaa, 9f27578929, f1b103e6c8, 272d3359ee, 9e6561a1ec, ffe3d0b2cd, 88a796ca5a, 64829947d0,
f36f68c590, 6378c2a2f3, 293d95b955, dc94433072, d897313e0b, 2c7aba4158, eab550f7d2, 4412694018, 653e8526ec, 78ef6a6099,
17b5302543, afd7e750d9, 5a2bfd6ad5, 7532f4cdb2, 799286fe52, 4d0bfbf984, 934446ddb4, 2aca7265b3, fe9b5ef08b, 7807a86358,
8dfce2f596, 79851d93aa, e6bbf8d20b, c149cf2e0f, 1050837622, 921ce36480, 28687b0e85, 6cf6791de1, 0266b20535, bb1a85c9a0
727 changed files with 48483 additions and 72751 deletions
.github/CODEOWNERS (2 changes)

@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722 @leseb
+* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning
.github/PULL_REQUEST_TEMPLATE.md (10 changes)

@@ -1,10 +1,8 @@
 # What does this PR do?
-[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]
+<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->

-[//]: # (If resolving an issue, uncomment and update the line below)
-[//]: # (Closes #[issue-number])
+<!-- If resolving an issue, uncomment and update the line below -->
+<!-- Closes #[issue-number] -->

 ## Test Plan
-[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]
+<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->
-
-[//]: # (## Documentation)
.github/TRIAGERS.md (2 changes)

@@ -1,2 +1,2 @@
 # This file documents Triage members in the Llama Stack community
-@franciscojavierarceo @leseb
+@bbrowning @booxter @franciscojavierarceo @leseb
.github/actions/setup-ollama/action.yml (new file, 26 lines)

name: Setup Ollama
description: Start Ollama and cache model
inputs:
  models:
    description: Comma-separated list of models to pull
    default: "llama3.2:3b-instruct-fp16,all-minilm:latest"
runs:
  using: "composite"
  steps:
    - name: Install and start Ollama
      shell: bash
      run: |
        # the ollama installer also starts the ollama service
        curl -fsSL https://ollama.com/install.sh | sh

    # Do NOT cache models - pulling the cache is actually slower than just pulling the model.
    # It takes ~45 seconds to pull the models from the cache and unpack it, but only 30 seconds to
    # pull them directly.
    # Maybe this is because the cache is being pulled at the same time by all the matrix jobs?
    - name: Pull requested models
      if: inputs.models != ''
      shell: bash
      run: |
        for model in $(echo "${{ inputs.models }}" | tr ',' ' '); do
          ollama pull "$model"
        done
.github/actions/setup-runner/action.yml (new file, 22 lines)

name: Setup runner
description: Prepare a runner for the tests (install uv, python, project dependencies, etc.)
runs:
  using: "composite"
  steps:
    - name: Install uv
      uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
      with:
        python-version: "3.10"
        activate-environment: true
        version: 0.7.6

    - name: Install dependencies
      shell: bash
      run: |
        uv sync --all-groups
        uv pip install ollama faiss-cpu
        # always test against the latest version of the client
        # TODO: this is not necessarily a good idea. we need to test against both published and latest
        # to find out backwards compatibility issues.
        uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
        uv pip install -e .
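The two composite actions above are what the workflow changes later in this compare switch to (for example the integration-test and build jobs). A minimal consuming job would look roughly like the sketch below; the job name is illustrative, and the `models` value merely overrides the action's default shown above:

```yaml
jobs:
  integration-tests:                 # hypothetical job name
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      # installs uv, Python 3.10 and the project dependencies
      - name: Install dependencies
        uses: ./.github/actions/setup-runner

      # installs Ollama and pulls the requested models
      - name: Setup ollama
        uses: ./.github/actions/setup-ollama
        with:
          models: "llama3.2:3b-instruct-fp16"   # optional; the action has a default
```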
.github/workflows/Dockerfile (new file, 1 line)

FROM localhost:5000/distribution-kvant:dev
.github/workflows/ci-playground.yaml (new file, 73 lines)

name: Build and Push playground container
run-name: Build and Push playground container
on:
  workflow_dispatch:
  #schedule:
  #  - cron: "0 10 * * *"
  push:
    branches:
      - main
      - kvant
    tags:
      - 'v*'
  pull_request:
    branches:
      - main
      - kvant

env:
  IMAGE: git.kvant.cloud/${{github.repository}}-playground

jobs:
  build-playground:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set current time
        uses: https://github.com/gerred/actions/current-time@master
        id: current_time

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to git.kvant.cloud registry
        uses: docker/login-action@v3
        with:
          registry: git.kvant.cloud
          username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
          password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          # list of Docker images to use as base name for tags
          images: |
            ${{env.IMAGE}}
          # generate Docker tags based on the following events/attributes
          tags: |
            type=schedule
            type=ref,event=branch
            type=ref,event=pr
            type=ref,event=tag
            type=semver,pattern={{version}}

      - name: Build and push to gitea registry
        uses: docker/build-push-action@v6
        with:
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          context: .
          file: llama_stack/distribution/ui/Containerfile
          provenance: mode=max
          sbom: true
          build-args: |
            BUILD_DATE=${{ steps.current_time.outputs.time }}
          cache-from: |
            type=registry,ref=${{ env.IMAGE }}:buildcache
            type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
            type=registry,ref=${{ env.IMAGE }}:main
          cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true
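As a rough illustration of the Docker meta step above, docker/metadata-action would resolve tags along the following lines for this configuration; the repository path is hypothetical:

```yaml
# push to the kvant branch  ->  git.kvant.cloud/acme/llama-stack-mirror-playground:kvant
# pull request #42          ->  git.kvant.cloud/acme/llama-stack-mirror-playground:pr-42
# push of tag v0.2.7        ->  git.kvant.cloud/acme/llama-stack-mirror-playground:v0.2.7
#                               git.kvant.cloud/acme/llama-stack-mirror-playground:0.2.7
```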
.github/workflows/ci.yaml (new file, 98 lines)

name: Build and Push container
run-name: Build and Push container
on:
  workflow_dispatch:
  #schedule:
  #  - cron: "0 10 * * *"
  push:
    branches:
      - main
      - kvant
    tags:
      - 'v*'
  pull_request:
    branches:
      - main
      - kvant

env:
  IMAGE: git.kvant.cloud/${{github.repository}}

jobs:
  build:
    runs-on: ubuntu-latest
    services:
      registry:
        image: registry:2
        ports:
          - 5000:5000
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set current time
        uses: https://github.com/gerred/actions/current-time@master
        id: current_time

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          driver-opts: network=host

      - name: Login to git.kvant.cloud registry
        uses: docker/login-action@v3
        with:
          registry: git.kvant.cloud
          username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
          password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          # list of Docker images to use as base name for tags
          images: |
            ${{env.IMAGE}}
          # generate Docker tags based on the following events/attributes
          tags: |
            type=schedule
            type=ref,event=branch
            type=ref,event=pr
            type=ref,event=tag
            type=semver,pattern={{version}}

      - name: Install uv
        uses: https://github.com/astral-sh/setup-uv@v5
        with:
          # Install a specific version of uv.
          version: "0.7.8"

      - name: Build
        env:
          USE_COPY_NOT_MOUNT: true
          LLAMA_STACK_DIR: .
        run: |
          uvx --from . llama stack build --template kvant --image-type container

          # docker tag distribution-kvant:dev ${{env.IMAGE}}:kvant
          # docker push ${{env.IMAGE}}:kvant

          docker tag distribution-kvant:dev localhost:5000/distribution-kvant:dev
          docker push localhost:5000/distribution-kvant:dev

      - name: Build and push to gitea registry
        uses: docker/build-push-action@v6
        with:
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          context: .github/workflows
          provenance: mode=max
          sbom: true
          build-args: |
            BUILD_DATE=${{ steps.current_time.outputs.time }}
          cache-from: |
            type=registry,ref=${{ env.IMAGE }}:buildcache
            type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
            type=registry,ref=${{ env.IMAGE }}:main
          cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true
.github/workflows_upstream/install-script-ci.yml (new file, 26 lines)

name: Installer CI

on:
  pull_request:
    paths:
      - 'install.sh'
  push:
    paths:
      - 'install.sh'
  schedule:
    - cron: '0 2 * * *'  # every day at 02:00 UTC

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
      - name: Run ShellCheck on install.sh
        run: shellcheck install.sh
  smoke-test:
    needs: lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
      - name: Run installer end-to-end
        run: ./install.sh
.github/workflows_upstream/integration-auth-tests.yml (new file, 132 lines)

name: Integration Auth Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
    paths:
      - 'distributions/**'
      - 'llama_stack/**'
      - 'tests/integration/**'
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
      - '.github/workflows/integration-auth-tests.yml' # This workflow

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test-matrix:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        auth-provider: [oauth2_token]
      fail-fast: false # we want to run all tests regardless of failure

    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Install dependencies
        uses: ./.github/actions/setup-runner

      - name: Build Llama Stack
        run: |
          llama stack build --template ollama --image-type venv

      - name: Install minikube
        if: ${{ matrix.auth-provider == 'kubernetes' }}
        uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19

      - name: Start minikube
        if: ${{ matrix.auth-provider == 'oauth2_token' }}
        run: |
          minikube start
          kubectl get pods -A

      - name: Configure Kube Auth
        if: ${{ matrix.auth-provider == 'oauth2_token' }}
        run: |
          kubectl create namespace llama-stack
          kubectl create serviceaccount llama-stack-auth -n llama-stack
          kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
          kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
          cat <<EOF | kubectl apply -f -
          apiVersion: rbac.authorization.k8s.io/v1
          kind: ClusterRole
          metadata:
            name: allow-anonymous-openid
          rules:
          - nonResourceURLs: ["/openid/v1/jwks"]
            verbs: ["get"]
          ---
          apiVersion: rbac.authorization.k8s.io/v1
          kind: ClusterRoleBinding
          metadata:
            name: allow-anonymous-openid
          roleRef:
            apiGroup: rbac.authorization.k8s.io
            kind: ClusterRole
            name: allow-anonymous-openid
          subjects:
          - kind: User
            name: system:anonymous
            apiGroup: rbac.authorization.k8s.io
          EOF

      - name: Set Kubernetes Config
        if: ${{ matrix.auth-provider == 'oauth2_token' }}
        run: |
          echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV
          echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
          echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
          echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV

      - name: Set Kube Auth Config and run server
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        if: ${{ matrix.auth-provider == 'oauth2_token' }}
        run: |
          run_dir=$(mktemp -d)
          cat <<'EOF' > $run_dir/run.yaml
          version: '2'
          image_name: kube
          apis: []
          providers: {}
          server:
            port: 8321
          EOF
          yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
          yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
          yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
          cat $run_dir/run.yaml

          nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
          echo "Waiting for Llama Stack server..."
          for i in {1..30}; do
            if curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://localhost:8321/v1/health | grep -q "OK"; then
              echo "Llama Stack server is up!"
              if grep -q "Enabling authentication with provider: ${{ matrix.auth-provider }}" server.log; then
                echo "Llama Stack server is configured to use ${{ matrix.auth-provider }} auth"
                exit 0
              else
                echo "Llama Stack server is not configured to use ${{ matrix.auth-provider }} auth"
                cat server.log
                exit 1
              fi
            fi
            sleep 1
          done
          echo "Llama Stack server failed to start"
          cat server.log
          exit 1

      - name: Test auth
        run: |
          curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers|jq
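For readability, the run.yaml assembled by the heredoc and the three yq edits in the "Set Kube Auth Config and run server" step would end up roughly like the sketch below; the concrete values are placeholders for the KUBERNETES_* variables resolved from the minikube cluster at runtime:

```yaml
version: '2'
image_name: kube
apis: []
providers: {}
server:
  port: 8321
  auth:
    provider_type: oauth2_token
    config:
      tls_cafile: /home/runner/.minikube/ca.crt    # ${KUBERNETES_CA_CERT_PATH}, placeholder
      issuer: https://kubernetes.default.svc       # ${KUBERNETES_ISSUER}, placeholder
      audience: https://kubernetes.default.svc     # ${KUBERNETES_AUDIENCE}, placeholder
      jwks:
        uri: https://<api-server>/openid/v1/jwks   # ${KUBERNETES_API_SERVER_URL}, placeholder
```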
@@ -24,7 +24,7 @@ jobs:
     matrix:
       # Listing tests manually since some of them currently fail
       # TODO: generate matrix list from tests/integration when fixed
-      test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
+      test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
       client-type: [library, http]
     fail-fast: false # we want to run all tests regardless of failure

@@ -32,30 +32,14 @@
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install uv
-        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
-        with:
-          python-version: "3.10"
-
-      - name: Install and start Ollama
-        run: |
-          # the ollama installer also starts the ollama service
-          curl -fsSL https://ollama.com/install.sh | sh
-
-      - name: Pull Ollama image
-        run: |
-          # TODO: cache the model. OLLAMA_MODELS defaults to ~ollama/.ollama/models.
-          ollama pull llama3.2:3b-instruct-fp16
-
-      - name: Set Up Environment and Install Dependencies
-        run: |
-          uv sync --extra dev --extra test
-          uv pip install ollama faiss-cpu
-          # always test against the latest version of the client
-          # TODO: this is not necessarily a good idea. we need to test against both published and latest
-          # to find out backwards compatibility issues.
-          uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
-          uv pip install -e .
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner
+
+      - name: Setup ollama
+        uses: ./.github/actions/setup-ollama
+
+      - name: Build Llama Stack
+        run: |
           llama stack build --template ollama --image-type venv

       - name: Start Llama Stack server in background

@@ -63,8 +47,7 @@
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
-          source .venv/bin/activate
-          nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &
+          LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &

       - name: Wait for Llama Stack server to be ready
         if: matrix.client-type == 'http'

@@ -92,6 +75,12 @@
             exit 1
           fi

+      - name: Check Storage and Memory Available Before Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h
+
       - name: Run Integration Tests
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"

@@ -101,7 +90,27 @@
           else
             stack_config="http://localhost:8321"
           fi
-          uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
+          uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
             -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
             --text-model="meta-llama/Llama-3.2-3B-Instruct" \
             --embedding-model=all-MiniLM-L6-v2
+
+      - name: Check Storage and Memory Available After Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h
+
+      - name: Write ollama logs to file
+        if: ${{ always() }}
+        run: |
+          sudo journalctl -u ollama.service > ollama.log
+
+      - name: Upload all logs to artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
+        with:
+          name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}
+          path: |
+            *.log
+          retention-days: 1

@@ -18,7 +18,7 @@
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

       - name: Set up Python
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
         with:
           python-version: '3.11'
           cache: pip

@@ -27,6 +27,9 @@
           .pre-commit-config.yaml

       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
+        env:
+          SKIP: no-commit-to-branch
+          RUFF_OUTPUT_FORMAT: github

       - name: Verify if there are any diff files after pre-commit
         run: |

@@ -50,21 +50,8 @@
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Set up Python
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
-        with:
-          python-version: '3.10'
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
-        with:
-          python-version: "3.10"
-
-      - name: Install LlamaStack
-        run: |
-          uv venv
-          source .venv/bin/activate
-          uv pip install -e .
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Print build dependencies
         run: |

@@ -79,7 +66,6 @@
       - name: Print dependencies in the image
         if: matrix.image-type == 'venv'
         run: |
-          source test/bin/activate
           uv pip list

   build-single-provider:

@@ -88,21 +74,8 @@
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Set up Python
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
-        with:
-          python-version: '3.10'
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
-        with:
-          python-version: "3.10"
-
-      - name: Install LlamaStack
-        run: |
-          uv venv
-          source .venv/bin/activate
-          uv pip install -e .
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Build a single provider
         run: |

@@ -114,27 +87,14 @@
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Set up Python
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
-        with:
-          python-version: '3.10'
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
-        with:
-          python-version: "3.10"
-
-      - name: Install LlamaStack
-        run: |
-          uv venv
-          source .venv/bin/activate
-          uv pip install -e .
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Build a single provider
         run: |
-          yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml
-          yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml
-          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml
+          yq -i '.image_type = "container"' llama_stack/templates/starter/build.yaml
+          yq -i '.image_name = "test"' llama_stack/templates/starter/build.yaml
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/starter/build.yaml

       - name: Inspect the container image entrypoint
         run: |

@@ -145,3 +105,43 @@
           echo "Entrypoint is not correct"
           exit 1
         fi
+
+  build-ubi9-container-distribution:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner
+
+      - name: Pin template to UBI9 base
+        run: |
+          yq -i '
+            .image_type = "container" |
+            .image_name = "ubi9-test" |
+            .distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest"
+          ' llama_stack/templates/starter/build.yaml
+
+      - name: Build dev container (UBI9)
+        env:
+          USE_COPY_NOT_MOUNT: "true"
+          LLAMA_STACK_DIR: "."
+        run: |
+          uv run llama stack build --config llama_stack/templates/starter/build.yaml
+
+      - name: Inspect UBI9 image
+        run: |
+          IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
+          echo "Entrypoint: $entrypoint"
+          if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then
+            echo "Entrypoint is not correct"
+            exit 1
+          fi
+
+          echo "Checking /etc/os-release in $IMAGE_ID"
+          docker run --rm --entrypoint sh "$IMAGE_ID" -c \
+            'source /etc/os-release && echo "$ID"' \
+            | grep -qE '^(rhel|ubi)$' \
+            || { echo "Base image is not UBI 9!"; exit 1; }
@@ -23,29 +23,10 @@
       # container and point 'uv pip install' to the correct path...
       steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Install uv
-        uses: astral-sh/setup-uv@v5
-        with:
-          python-version: "3.10"
-
-      - name: Install Ollama
-        run: |
-          curl -fsSL https://ollama.com/install.sh | sh
-
-      - name: Pull Ollama image
-        run: |
-          ollama pull llama3.2:3b-instruct-fp16
-
-      - name: Start Ollama in background
-        run: |
-          nohup ollama run llama3.2:3b-instruct-fp16 --keepalive=30m > ollama.log 2>&1 &
-
-      - name: Set Up Environment and Install Dependencies
-        run: |
-          uv sync --extra dev --extra test
-          uv pip install -e .
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Apply image type to config file
         run: |

@@ -59,57 +40,32 @@
       - name: Create provider configuration
         run: |
-          mkdir -p /tmp/providers.d/remote/inference
-          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
+          mkdir -p /home/runner/.llama/providers.d/remote/inference
+          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml

       - name: Build distro from config file
         run: |
           USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml

-      - name: Wait for Ollama to start
-        run: |
-          echo "Waiting for Ollama..."
-          for i in {1..30}; do
-            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
-              echo "Ollama is running!"
-              exit 0
-            fi
-            sleep 1
-          done
-          echo "Ollama failed to start"
-          ollama ps
-          ollama.log
-          exit 1
-
       - name: Start Llama Stack server in background
         if: ${{ matrix.image-type }} == 'venv'
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
-          source ci-test/bin/activate
           uv run pip list
           nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &

       - name: Wait for Llama Stack server to be ready
         run: |
-          echo "Waiting for Llama Stack server..."
           for i in {1..30}; do
-            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
-              echo "Llama Stack server is up!"
-              if grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
-                echo "Llama Stack server is using custom Ollama provider"
-                exit 0
-              else
-                echo "Llama Stack server is not using custom Ollama provider"
-                exit 1
-              fi
-            fi
-            sleep 1
+            if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+              echo "Waiting for Llama Stack server to load the provider..."
+              sleep 1
+            else
+              echo "Provider loaded"
+              exit 0
+            fi
           done
-          echo "Llama Stack server failed to start"
+          echo "Provider failed to load"
           cat server.log
           exit 1
-
-      - name: run inference tests
-        run: |
-          uv run pytest -v tests/integration/inference/test_text_inference.py --stack-config="http://localhost:8321" --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2

@@ -30,17 +30,11 @@
         - "3.12"
         - "3.13"
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - name: Checkout repository
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Set up Python ${{ matrix.python }}
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
-        with:
-          python-version: ${{ matrix.python }}
-
-      - uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
-        with:
-          python-version: ${{ matrix.python }}
-          enable-cache: false
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Run unit tests
         run: |

@@ -14,6 +14,8 @@ on:
       - 'docs/**'
       - 'pyproject.toml'
       - '.github/workflows/update-readthedocs.yml'
+    tags:
+      - '*'
   pull_request:
     branches:
       - main

@@ -35,16 +37,8 @@
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-      - name: Set up Python
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
-        with:
-          python-version: '3.11'
-
-      - name: Install the latest version of uv
-        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
-
-      - name: Sync with uv
-        run: uv sync --extra docs
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner

       - name: Build HTML
         run: |

@@ -61,7 +55,10 @@
           response=$(curl -X POST \
             -H "Content-Type: application/json" \
-            -d "{\"token\": \"$TOKEN\"}" \
+            -d "{
+              \"token\": \"$TOKEN\",
+              \"version\": \"$GITHUB_REF_NAME\"
+            }" \
             https://readthedocs.org/api/v2/webhook/llama-stack/289768/)

           echo "Response: $response"
.gitignore (2 changes)

@@ -6,6 +6,7 @@ dev_requirements.txt
 build
 .DS_Store
 llama_stack/configs/*
+.cursor/
 xcuserdata/
 *.hmap
 .DS_Store

@@ -23,3 +24,4 @@
 pytest-report.xml
 .coverage
 .python-version
+data

@@ -15,6 +15,18 @@ repos:
       args: ['--maxkb=1000']
     - id: end-of-file-fixer
       exclude: '^(.*\.svg)$'
+    - id: no-commit-to-branch
+    - id: check-yaml
+      args: ["--unsafe"]
+    - id: detect-private-key
+    - id: requirements-txt-fixer
+    - id: mixed-line-ending
+      args: [--fix=lf] # Forces to replace line ending by LF (line feed)
+    - id: check-executables-have-shebangs
+    - id: check-json
+    - id: check-shebang-scripts-are-executable
+    - id: check-symlinks
+    - id: check-toml

 - repo: https://github.com/Lucas-C/pre-commit-hooks
   rev: v1.5.4

@@ -41,7 +53,7 @@
       - black==24.3.0

 - repo: https://github.com/astral-sh/uv-pre-commit
-  rev: 0.6.3
+  rev: 0.7.8
   hooks:
   - id: uv-lock
   - id: uv-export

@@ -49,6 +61,7 @@
         "--frozen",
         "--no-hashes",
         "--no-emit-project",
+        "--no-default-groups",
         "--output-file=requirements.txt"
       ]

@@ -76,24 +89,29 @@
   - id: distro-codegen
     name: Distribution Template Codegen
     additional_dependencies:
-      - uv==0.6.0
-    entry: uv run --extra codegen ./scripts/distro_codegen.py
+      - uv==0.7.8
+    entry: uv run --group codegen ./scripts/distro_codegen.py
     language: python
     pass_filenames: false
     require_serial: true
     files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$

-- repo: local
-  hooks:
   - id: openapi-codegen
     name: API Spec Codegen
     additional_dependencies:
-      - uv==0.6.2
-    entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
+      - uv==0.7.8
+    entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
     language: python
     pass_filenames: false
     require_serial: true
    files: ^llama_stack/apis/|^docs/openapi_generator/
+  - id: check-workflows-use-hashes
+    name: Check GitHub Actions use SHA-pinned actions
+    entry: ./scripts/check-workflows-use-hashes.sh
+    language: system
+    pass_filenames: false
+    require_serial: true
+    always_run: true
+    files: ^\.github/workflows/.*\.ya?ml$

 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

@@ -5,28 +5,21 @@
 # Required
 version: 2

+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+  configuration: docs/source/conf.py
+
 # Set the OS, Python version and other tools you might need
 build:
   os: ubuntu-22.04
   tools:
     python: "3.12"
-    # You can also specify other tool versions:
-    # nodejs: "19"
-    # rust: "1.64"
-    # golang: "1.19"
-
-# Build documentation in the "docs/" directory with Sphinx
-sphinx:
-  configuration: docs/source/conf.py
-
-# Optionally build your docs in additional formats such as PDF and ePub
-# formats:
-#   - pdf
-#   - epub
-
-# Optional but recommended, declare the Python requirements required
-# to build your documentation
-# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
-python:
-  install:
-    - requirements: docs/requirements.txt
+  jobs:
+    pre_create_environment:
+      - asdf plugin add uv
+      - asdf install uv latest
+      - asdf global uv latest
+    create_environment:
+      - uv venv "${READTHEDOCS_VIRTUALENV_PATH}"
+    install:
+      - UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --group docs
CHANGELOG.md (70 changes)

@@ -1,5 +1,75 @@
 # Changelog

+# v0.2.7
+Published on: 2025-05-16T20:38:10Z
+
+## Highlights
+
+This is a small update. But a couple highlights:
+
+* feat: function tools in OpenAI Responses by @bbrowning in https://github.com/meta-llama/llama-stack/pull/2094, getting closer to ready. Streaming is the next missing piece.
+* feat: Adding support for customizing chunk context in RAG insertion and querying by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2134
+* feat: scaffolding for Llama Stack UI by @ehhuang in https://github.com/meta-llama/llama-stack/pull/2149, more to come in the coming releases.
+
+---
+
+# v0.2.6
+Published on: 2025-05-12T18:06:52Z
+
+---
+
+# v0.2.5
+Published on: 2025-05-04T20:16:49Z
+
+---
+
+# v0.2.4
+Published on: 2025-04-29T17:26:01Z
+
+## Highlights
+
+* One-liner to install and run Llama Stack yay! by @reluctantfuturist in https://github.com/meta-llama/llama-stack/pull/1383
+* support for NVIDIA NeMo datastore by @raspawar in https://github.com/meta-llama/llama-stack/pull/1852
+* (yuge!) Kubernetes authentication by @leseb in https://github.com/meta-llama/llama-stack/pull/1778
+* (yuge!) OpenAI Responses API by @bbrowning in https://github.com/meta-llama/llama-stack/pull/1989
+* add api.llama provider, llama-guard-4 model by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2058
+
+---
+
+# v0.2.3
+Published on: 2025-04-25T22:46:21Z
+
+## Highlights
+
+* OpenAI compatible inference endpoints and client-SDK support. `client.chat.completions.create()` now works.
+* significant improvements and functionality added to the nVIDIA distribution
+* many improvements to the test verification suite.
+* new inference providers: Ramalama, IBM WatsonX
+* many improvements to the Playground UI
+
+---
+
+# v0.2.2
+Published on: 2025-04-13T01:19:49Z
+
+## Main changes
+
+- Bring Your Own Provider (@leseb) - use out-of-tree provider code to execute the distribution server
+- OpenAI compatible inference API in progress (@bbrowning)
+- Provider verifications (@ehhuang)
+- Many updates and fixes to playground
+- Several llama4 related fixes
+
+---
+
 # v0.2.1
 Published on: 2025-04-05T23:13:00Z

@@ -110,25 +110,9 @@ uv run pre-commit run --all-files
 > [!CAUTION]
 > Before pushing your changes, make sure that the pre-commit hooks have passed successfully.

-## Running unit tests
+## Running tests

-You can run the unit tests by running:
-
-```bash
-source .venv/bin/activate
-./scripts/unit-tests.sh
-```
-
-If you'd like to run for a non-default version of Python (currently 3.10), pass `PYTHON_VERSION` variable as follows:
-
-```
-source .venv/bin/activate
-PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
-```
-
-## Running integration tests
-
-You can run integration tests following the instructions [here](tests/integration/README.md).
+You can find the Llama Stack testing documentation here [here](tests/README.md).

 ## Adding a new dependency to the project

@@ -141,11 +125,20 @@ uv sync
 ## Coding Style

-* Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter, same goes for docstrings.
-* Prefer comments to clarify surprising behavior and/or relationships between parts of the code rather than explain what the next line of code does.
-* Catching exceptions, prefer using a specific exception type rather than a broad catch-all like `Exception`.
+* Comments should provide meaningful insights into the code. Avoid filler comments that simply
+  describe the next step, as they create unnecessary clutter, same goes for docstrings.
+* Prefer comments to clarify surprising behavior and/or relationships between parts of the code
+  rather than explain what the next line of code does.
+* Catching exceptions, prefer using a specific exception type rather than a broad catch-all like
+  `Exception`.
 * Error messages should be prefixed with "Failed to ..."
-* 4 spaces for indentation rather than tabs
+* 4 spaces for indentation rather than tab
+* When using `# noqa` to suppress a style or linter warning, include a comment explaining the
+  justification for bypassing the check.
+* When using `# type: ignore` to suppress a mypy warning, include a comment explaining the
+  justification for bypassing the check.
+* Don't use unicode characters in the codebase. ASCII-only is preferred for compatibility or
+  readability reasons.

 ## Common Tasks

@@ -174,14 +167,11 @@ If you have made changes to a provider's configuration in any form (introducing
 If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.

 ```bash
-cd docs
-uv sync --extra docs
-
 # This rebuilds the documentation pages.
-uv run make html
+uv run --group docs make -C docs/ html

 # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
-uv run sphinx-autobuild source build/html --write-all
+uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
 ```

 ### Update API Documentation

@@ -189,7 +179,7 @@ uv run sphinx-autobuild source build/html --write-all
 If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:

 ```bash
-uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
+uv run ./docs/openapi_generator/run_openapi_generator.sh
 ```

 The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.

@@ -1,5 +1,4 @@
 include pyproject.toml
-include llama_stack/templates/dependencies.json
 include llama_stack/models/llama/llama3/tokenizer.model
 include llama_stack/models/llama/llama4/tokenizer.model
 include llama_stack/distribution/*.sh
README.md (52 changes)

@@ -7,7 +7,7 @@
 [](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
 [](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)

-[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
+[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)

 ### ✨🎉 Llama 4 Support 🎉✨
 We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.

@@ -70,6 +70,13 @@ As more providers start supporting Llama 4, you can use them in Llama Stack as w
 </details>

+### 🚀 One-Line Installer 🚀
+
+To try Llama Stack locally, run:
+
+```bash
+curl -LsSf https://github.com/meta-llama/llama-stack/raw/main/install.sh | sh
+```
+
 ### Overview

@@ -100,26 +107,29 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on
 ### API Providers
 Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.

-| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
-|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
-| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
-| SambaNova | Hosted | | ✅ | | | |
-| Cerebras | Hosted | | ✅ | | | |
-| Fireworks | Hosted | ✅ | ✅ | ✅ | | |
-| AWS Bedrock | Hosted | | ✅ | | ✅ | |
-| Together | Hosted | ✅ | ✅ | | ✅ | |
-| Groq | Hosted | | ✅ | | | |
-| Ollama | Single Node | | ✅ | | | |
-| TGI | Hosted and Single Node | | ✅ | | | |
-| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
-| Chroma | Single Node | | | ✅ | | |
-| PG Vector | Single Node | | | ✅ | | |
-| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
-| vLLM | Hosted and Single Node | | ✅ | | | |
-| OpenAI | Hosted | | ✅ | | | |
-| Anthropic | Hosted | | ✅ | | | |
-| Gemini | Hosted | | ✅ | | | |
-| watsonx | Hosted | | ✅ | | | |
+| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** |
+|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:|
+| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| SambaNova | Hosted | | ✅ | | ✅ | | |
+| Cerebras | Hosted | | ✅ | | | | |
+| Fireworks | Hosted | ✅ | ✅ | ✅ | | | |
+| AWS Bedrock | Hosted | | ✅ | | ✅ | | |
+| Together | Hosted | ✅ | ✅ | | ✅ | | |
+| Groq | Hosted | | ✅ | | | | |
+| Ollama | Single Node | | ✅ | | | | |
+| TGI | Hosted and Single Node | | ✅ | | | | |
+| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | |
+| Chroma | Single Node | | | ✅ | | | |
+| PG Vector | Single Node | | | ✅ | | | |
+| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | |
+| vLLM | Hosted and Single Node | | ✅ | | | | |
+| OpenAI | Hosted | | ✅ | | | | |
+| Anthropic | Hosted | | ✅ | | | | |
+| Gemini | Hosted | | ✅ | | | | |
+| watsonx | Hosted | | ✅ | | | | |
+| HuggingFace | Single Node | | | | | | ✅ |
+| TorchTune | Single Node | | | | | | ✅ |
+| NVIDIA NEMO | Hosted | | | | | | ✅ |

 ### Distributions
docs/_static/css/my_theme.css (vendored, 6 changes)

@@ -27,3 +27,9 @@ pre {
 white-space: pre-wrap !important;
 word-break: break-all;
 }
+
+[data-theme="dark"] .mermaid {
+  background-color: #f4f4f6 !important;
+  border-radius: 6px;
+  padding: 0.5em;
+}
docs/_static/llama-stack-spec.html (vendored, 5001 changes) — file diff suppressed because it is too large

docs/_static/llama-stack-spec.yaml (vendored, 3671 changes) — file diff suppressed because it is too large
@@ -1050,8 +1050,6 @@
 "text/html": [
 "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ToolGroup</span><span style=\"font-weight: bold\">(</span>\n",
 "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">identifier</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'builtin::code_interpreter'</span>,\n",
-"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">provider_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'code-interpreter'</span>,\n",
-"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">provider_resource_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'builtin::code_interpreter'</span>,\n",
 "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">type</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'tool_group'</span>,\n",
 "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">args</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>,\n",
 "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">mcp_endpoint</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>\n",

@@ -1061,7 +1059,6 @@
 "text/plain": [
 "\u001b[1;35mToolGroup\u001b[0m\u001b[1m(\u001b[0m\n",
 "\u001b[2;32m│ \u001b[0m\u001b[33midentifier\u001b[0m=\u001b[32m'builtin::code_interpreter'\u001b[0m,\n",
-"\u001b[2;32m│ \u001b[0m\u001b[33mprovider_id\u001b[0m=\u001b[32m'code-interpreter'\u001b[0m,\n",
 "\u001b[2;32m│ \u001b[0m\u001b[33mprovider_resource_id\u001b[0m=\u001b[32m'builtin::code_interpreter'\u001b[0m,\n",
 "\u001b[2;32m│ \u001b[0m\u001b[33mtype\u001b[0m=\u001b[32m'tool_group'\u001b[0m,\n",
 "\u001b[2;32m│ \u001b[0m\u001b[33margs\u001b[0m=\u001b[3;35mNone\u001b[0m,\n",
docs/getting_started_llama_api.ipynb (new file, 907 additions) — file diff suppressed because it is too large and one or more lines are too long
@@ -840,7 +840,6 @@
 " \"memory_optimizations.rst\",\n",
 " \"chat.rst\",\n",
 " \"llama3.rst\",\n",
-" \"datasets.rst\",\n",
 " \"qat_finetune.rst\",\n",
 " \"lora_finetune.rst\",\n",
 "]\n",

@@ -1586,7 +1585,6 @@
 " \"memory_optimizations.rst\",\n",
 " \"chat.rst\",\n",
 " \"llama3.rst\",\n",
-" \"datasets.rst\",\n",
 " \"qat_finetune.rst\",\n",
 " \"lora_finetune.rst\",\n",
 "]\n",
@@ -44,7 +44,7 @@ def main(output_dir: str):
 if return_type_errors:
 print("\nAPI Method Return Type Validation Errors:\n")
 for error in return_type_errors:
-print(error)
+print(error, file=sys.stderr)
 sys.exit(1)
 now = str(datetime.now())
 print(
@@ -6,6 +6,7 @@
 import hashlib
 import ipaddress
+import types
 import typing
 from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union

@@ -179,7 +180,7 @@ class ContentBuilder:
 "Creates the content subtree for a request or response."

 def is_iterator_type(t):
-return "StreamChunk" in str(t)
+return "StreamChunk" in str(t) or "OpenAIResponseObjectStream" in str(t)

 def get_media_type(t):
 if is_generic_list(t):

@@ -189,7 +190,7 @@ class ContentBuilder:
 else:
 return "application/json"

-if typing.get_origin(payload_type) is typing.Union:
+if typing.get_origin(payload_type) in (typing.Union, types.UnionType):
 media_types = []
 item_types = []
 for x in typing.get_args(payload_type):
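Worth noting why the `types.UnionType` case was added: on Python 3.10+, a union written with the `|` operator reports a different origin than `typing.Union`, so the old check silently skipped those payload types. A small standalone sketch (not part of the diff; assumes Python 3.10+) illustrating the difference:

```python
import types
import typing
from typing import Union

# The same union spelled two ways reports two different origins on Python 3.10+.
print(typing.get_origin(Union[int, str]))  # typing.Union
print(typing.get_origin(int | str))        # <class 'types.UnionType'>

# Checking membership against both origins handles either spelling,
# while typing.get_args() behaves the same for both.
for payload_type in (Union[int, str], int | str):
    assert typing.get_origin(payload_type) in (typing.Union, types.UnionType)
    print(typing.get_args(payload_type))   # (<class 'int'>, <class 'str'>)
```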
@@ -758,7 +759,7 @@ class Generator:
 )

 return Operation(
-tags=[op.defining_class.__name__],
+tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)],
 summary=None,
 # summary=doc_string.short_description,
 description=description,

@@ -804,6 +805,8 @@ class Generator:
 operation_tags: List[Tag] = []
 for cls in endpoint_classes:
 doc_string = parse_type(cls)
+if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__:
+    continue
 operation_tags.append(
 Tag(
 name=cls.__name__,
@@ -174,14 +174,64 @@ def _validate_list_parameters_contain_data(method) -> str | None:
 return "does not have a mandatory data attribute containing the list of objects"


+def _validate_has_ellipsis(method) -> str | None:
+    source = inspect.getsource(method)
+    if "..." not in source and not "NotImplementedError" in source:
+        return "does not contain ellipsis (...) in its implementation"
+
+
+def _validate_has_return_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is not None and return_type != type(None) and ":returns:" not in source:
+        return "does not have a ':returns:' in its docstring"
+
+
+def _validate_has_params_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    sig = inspect.signature(method)
+    # Only check if the method has more than one parameter
+    if len(sig.parameters) > 1 and ":param" not in source:
+        return "does not have a ':param' in its docstring"
+
+
+def _validate_has_no_return_none_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is None and ":returns: None" in source:
+        return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
+
+
+def _validate_docstring_lines_end_with_dot(method) -> str | None:
+    docstring = inspect.getdoc(method)
+    if docstring is None:
+        return None
+
+    lines = docstring.split('\n')
+    for line in lines:
+        line = line.strip()
+        if line and not any(line.endswith(char) for char in '.:{}[]()",'):
+            return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
+
+
 _VALIDATORS = {
 "GET": [
 _validate_api_method_return_type,
 _validate_list_parameters_contain_data,
 _validate_api_method_doesnt_return_list,
+_validate_has_ellipsis,
+_validate_has_return_in_docstring,
+_validate_has_params_in_docstring,
+_validate_docstring_lines_end_with_dot,
 ],
 "DELETE": [
 _validate_api_delete_method_returns_none,
+_validate_has_ellipsis,
+_validate_has_return_in_docstring,
+_validate_has_params_in_docstring,
+_validate_has_no_return_none_in_docstring
+],
+"POST": [
+_validate_has_ellipsis,
+_validate_has_return_in_docstring,
+_validate_has_params_in_docstring,
+_validate_has_no_return_none_in_docstring,
+_validate_docstring_lines_end_with_dot,
 ],
 }
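For reference, here is a hedged sketch of what the new GET checks expect from an endpoint definition (a hypothetical protocol written only for illustration, not taken from the repository): the body is an ellipsis, every parameter has a `:param`, the return value has a `:returns:`, and each docstring line ends with an accepted character.

```python
from typing import Protocol


class Model:
    """Placeholder return type used only for this sketch."""


class Models(Protocol):  # hypothetical API protocol, not from the codebase
    async def get_model(self, model_id: str) -> Model:
        """Fetch a single model by its identifier.

        :param model_id: The identifier of the model to fetch.
        :returns: The registered model.
        """
        ...
```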
@@ -3,10 +3,10 @@
 Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html).

 ## Render locally

+From the llama-stack root directory, run the following command to render the docs locally:
 ```bash
-pip install -r requirements.txt
-cd docs
-python -m sphinx_autobuild source _build
+uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
 ```
 You can open up the docs in your browser at http://localhost:8000
@@ -1,16 +0,0 @@
-sphinx==8.1.3
-myst-parser
-linkify
--e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-sphinx-rtd-theme>=1.0.0
-sphinx_autobuild
-sphinx-copybutton
-sphinx-design
-sphinx-pdj-theme
-sphinx_rtd_dark_mode
-sphinx-tabs
-sphinxcontrib-openapi
-sphinxcontrib-redoc
-sphinxcontrib-mermaid
-sphinxcontrib-video
-tomli
@@ -51,11 +51,37 @@ chunks = [
 "mime_type": "text/plain",
 "metadata": {
 "document_id": "doc1",
+"author": "Jane Doe",
 },
 },
 ]
 client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
 ```

+#### Using Precomputed Embeddings
+If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by
+including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you
+want to customize the ingestion process.
+```python
+chunks_with_embeddings = [
+    {
+        "content": "First chunk of text",
+        "mime_type": "text/plain",
+        "embedding": [0.1, 0.2, 0.3, ...],  # Your precomputed embedding vector
+        "metadata": {"document_id": "doc1", "section": "introduction"},
+    },
+    {
+        "content": "Second chunk of text",
+        "mime_type": "text/plain",
+        "embedding": [0.2, 0.3, 0.4, ...],  # Your precomputed embedding vector
+        "metadata": {"document_id": "doc1", "section": "methodology"},
+    },
+]
+client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
+```
+When providing precomputed embeddings, ensure the embedding dimension matches the embedding_dimension specified when
+registering the vector database.
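For context on that constraint, here is a minimal sketch (not part of the diff; the model name, dimension, and provider are illustrative and assume the standard `llama_stack_client`) of the vector database registration whose `embedding_dimension` must match the vectors you insert:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# The embedding_dimension chosen here must equal len(chunk["embedding"]) for
# every chunk inserted with precomputed embeddings.
client.vector_dbs.register(
    vector_db_id="my_documents",
    embedding_model="all-MiniLM-L6-v2",  # illustrative model name
    embedding_dimension=384,
    provider_id="faiss",
)
```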
 ### Retrieval
 You can query the vector database to retrieve documents based on their embeddings.
 ```python

@@ -98,6 +124,17 @@ results = client.tool_runtime.rag_tool.query(
 )
 ```

+You can configure how the RAG tool adds metadata to the context if you find it useful for your application. Simply add:
+```python
+# Query documents
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="What do you know about...",
+    query_config={
+        "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+    },
+)
+```
 ### Building RAG-Enhanced Agents

 One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:

@@ -115,6 +152,12 @@ agent = Agent(
 "name": "builtin::rag/knowledge_search",
 "args": {
 "vector_db_ids": [vector_db_id],
+# Defaults
+"query_config": {
+    "chunk_size_in_tokens": 512,
+    "chunk_overlap_in_tokens": 0,
+    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+},
 },
 }
 ],
@@ -43,27 +43,6 @@ The tool requires an API key which can be provided either in the configuration o

 > **NOTE:** When using Tavily Search and Bing Search, the inference output will still display "Brave Search." This is because Llama models have been trained with Brave Search as a built-in tool. Tavily and bing is just being used in lieu of Brave search.

-#### Code Interpreter
-
-The Code Interpreter allows execution of Python code within a controlled environment.
-
-```python
-# Register Code Interpreter tool group
-client.toolgroups.register(
-    toolgroup_id="builtin::code_interpreter", provider_id="code_interpreter"
-)
-```
-
-Features:
-- Secure execution environment using `bwrap` sandboxing
-- Matplotlib support for generating plots
-- Disabled dangerous system operations
-- Configurable execution timeouts
-
-> ⚠️ Important: The code interpreter tool can operate in a controlled environment locally or on Podman containers. To ensure proper functionality in containerized environments:
-> - The container requires privileged access (e.g., --privileged).
-> - Users without sufficient permissions may encounter permission errors. (`bwrap: Can't mount devpts on /newroot/dev/pts: Permission denied`)
-> - 🔒 Security Warning: Privileged mode grants elevated access and bypasses security restrictions. Use only in local, isolated, or controlled environments.

 #### WolframAlpha

@@ -102,7 +81,7 @@ Features:
 - Context retrieval with token limits

-> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and rag, that are provided by tavily-search, code-interpreter and rag providers.
+> **Note:** By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers.

 ## Model Context Protocol (MCP) Tools

@@ -186,34 +165,6 @@ all_tools = client.tools.list_tools()
 group_tools = client.tools.list_tools(toolgroup_id="search_tools")
 ```

-## Simple Example: Using an Agent with the Code-Interpreter Tool
-
-```python
-from llama_stack_client import Agent
-
-# Instantiate the AI agent with the given configuration
-agent = Agent(
-    client,
-    name="code-interpreter",
-    description="A code interpreter agent for executing Python code snippets",
-    instructions="""
-    You are a highly reliable, concise, and precise assistant.
-    Always show the generated code, never generate your own code, and never anticipate results.
-    """,
-    model="meta-llama/Llama-3.2-3B-Instruct",
-    tools=["builtin::code_interpreter"],
-    max_infer_iters=5,
-)
-
-# Start a session
-session_id = agent.create_session("tool_session")
-
-# Send a query to the AI agent for code execution
-response = agent.create_turn(
-    messages=[{"role": "user", "content": "Run this code: print(3 ** 4 - 5 * 2)"}],
-    session_id=session_id,
-)
-```
 ## Simple Example 2: Using an Agent with the Web Search Tool
 1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
 2. [Optional] Provide the API key directly to the Llama Stack server
@@ -22,7 +22,11 @@ from docutils import nodes
 # Read version from pyproject.toml
 with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
 pypi_url = "https://pypi.org/pypi/llama-stack/json"
-version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
+headers = {
+    'User-Agent': 'pip/23.0.1 (python 3.11)',  # Mimic pip's user agent
+    'Accept': 'application/json'
+}
+version_tag = json.loads(requests.get(pypi_url, headers=headers).text)["info"]["version"]
 print(f"{version_tag=}")

 # generate the full link including text and url here

@@ -53,14 +57,6 @@ myst_enable_extensions = ["colon_fence"]

 html_theme = "sphinx_rtd_theme"
 html_use_relative_paths = True

-# html_theme = "sphinx_pdj_theme"
-# html_theme_path = [sphinx_pdj_theme.get_html_theme_path()]
-
-# html_theme = "pytorch_sphinx_theme"
-# html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
-
 templates_path = ["_templates"]
 exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

@@ -110,6 +106,8 @@ html_theme_options = {
 "canonical_url": "https://github.com/meta-llama/llama-stack",
 "collapse_navigation": False,
 # "style_nav_header_background": "#c3c9d4",
+'display_version': True,
+'version_selector': True,
 }

 default_dark_mode = False
@@ -6,7 +6,7 @@ This guide will walk you through the process of adding a new API provider to Lla
 - Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.)
 - Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally.
 - Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary.
-- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.
+- Update any distribution {repopath}`Templates::llama_stack/templates/` `build.yaml` and `run.yaml` files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.

 Here are some example PRs to help you get started:

@@ -33,6 +33,7 @@ Note that each provider's `sample_run_config()` method (in the configuration cla

 Unit tests are located in {repopath}`tests/unit`. Provider-specific unit tests are located in {repopath}`tests/unit/providers`. These tests are all run automatically as part of the CI process.

+Consult {repopath}`tests/unit/README.md` for more details on how to run the tests manually.

 ### 3. Additional end-to-end testing
@@ -178,7 +178,7 @@ image_name: ollama
 image_type: conda

 # If some providers are external, you can specify the path to the implementation
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```

 ```

@@ -206,7 +206,7 @@ distribution_spec:
 image_type: container
 image_name: ci-test
 # Path to external provider implementations
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```

 Here's an example for a custom Ollama provider:

@@ -271,7 +271,7 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con

 ```
 llama stack run -h
-usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
+usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
 [--image-type {conda,container,venv}]
 config

@@ -285,7 +285,6 @@ options:
 --port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
 --image-name IMAGE_NAME
 Name of the image to run. Defaults to the current environment (default: None)
---disable-ipv6 Disable IPv6 support (default: False)
 --env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
 --tls-keyfile TLS_KEYFILE
 Path to TLS key file for HTTPS (default: None)

@@ -339,6 +338,48 @@ INFO: Application startup complete.
 INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
 INFO: 2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK
 ```
+### Listing Distributions
+Using the list command, you can view all existing Llama Stack distributions, including stacks built from templates, from scratch, or using custom configuration files.
+
+```
+llama stack list -h
+usage: llama stack list [-h]
+
+list the build stacks
+
+options:
+  -h, --help  show this help message and exit
+```
+
+Example Usage
+
+```
+llama stack list
+```
+
+### Removing a Distribution
+Use the remove command to delete a distribution you've previously built.
+
+```
+llama stack rm -h
+usage: llama stack rm [-h] [--all] [name]
+
+Remove the build stack
+
+positional arguments:
+  name        Name of the stack to delete (default: None)
+
+options:
+  -h, --help  show this help message and exit
+  --all, -a   Delete all stacks (use with caution) (default: False)
+```
+
+Example
+```
+llama stack rm llamastack-test
+```
+
+To keep your environment organized and avoid clutter, consider using `llama stack list` to review old or unused distributions and `llama stack rm <name>` to delete them when they're no longer needed.

 ### Troubleshooting
@@ -53,6 +53,13 @@ models:
 provider_id: ollama
 provider_model_id: null
 shields: []
+server:
+  port: 8321
+  auth:
+    provider_type: "kubernetes"
+    config:
+      api_server_url: "https://kubernetes.default.svc"
+      ca_cert_path: "/path/to/ca.crt"
 ```

 Let's break this down into the different sections. The first section specifies the set of APIs that the stack server will serve:

@@ -102,6 +109,227 @@ A Model is an instance of a "Resource" (see [Concepts](../concepts/index)) and i

 What's with the `provider_model_id` field? This is an identifier for the model inside the provider's model catalog. Contrast it with `model_id` which is the identifier for the same model for Llama Stack's purposes. For example, you may want to name "llama3.2:vision-11b" as "image_captioning_model" when you use it in your Stack interactions. When omitted, the server will set `provider_model_id` to be the same as `model_id`.
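To make the `model_id` / `provider_model_id` distinction concrete, here is a small hedged sketch (hypothetical names, assuming the standard `llama_stack_client`; it is not part of the diff) that registers a provider model under a stack-facing alias:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# model_id is the name used in Stack calls; provider_model_id is the name the
# provider (assumed here to be ollama) knows the model by.
client.models.register(
    model_id="image_captioning_model",
    provider_model_id="llama3.2:vision-11b",
    provider_id="ollama",
    model_type="llm",
)
```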
+## Server Configuration
+
+The `server` section configures the HTTP server that serves the Llama Stack APIs:
+
+```yaml
+server:
+  port: 8321  # Port to listen on (default: 8321)
+  tls_certfile: "/path/to/cert.pem"  # Optional: Path to TLS certificate for HTTPS
+  tls_keyfile: "/path/to/key.pem"  # Optional: Path to TLS key for HTTPS
+```
+
+### Authentication Configuration
+
+The `auth` section configures authentication for the server. When configured, all API requests must include a valid Bearer token in the Authorization header:
+
+```
+Authorization: Bearer <token>
+```
+
+The server supports multiple authentication providers:
+
+#### OAuth 2.0/OpenID Connect Provider with Kubernetes
+
+The Kubernetes cluster must be configured to use a service account for authentication.
+
+```bash
+kubectl create namespace llama-stack
+kubectl create serviceaccount llama-stack-auth -n llama-stack
+kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
+kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
+```
+
+Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests
+and that the correct RoleBinding is created to allow the service account to access the necessary
+resources. If that is not the case, you can create a RoleBinding for the service account to access
+the necessary resources:
+
+```yaml
+# allow-anonymous-openid.yaml
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRole
+metadata:
+  name: allow-anonymous-openid
+rules:
+- nonResourceURLs: ["/openid/v1/jwks"]
+  verbs: ["get"]
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRoleBinding
+metadata:
+  name: allow-anonymous-openid
+roleRef:
+  apiGroup: rbac.authorization.k8s.io
+  kind: ClusterRole
+  name: allow-anonymous-openid
+subjects:
+- kind: User
+  name: system:anonymous
+  apiGroup: rbac.authorization.k8s.io
+```
+
+And then apply the configuration:
+```bash
+kubectl apply -f allow-anonymous-openid.yaml
+```
+
+Validates tokens against the Kubernetes API server through the OIDC provider:
+```yaml
+server:
+  auth:
+    provider_type: "oauth2_token"
+    config:
+      jwks:
+        uri: "https://kubernetes.default.svc"
+        key_recheck_period: 3600
+      tls_cafile: "/path/to/ca.crt"
+      issuer: "https://kubernetes.default.svc"
+      audience: "https://kubernetes.default.svc"
+```
+
+To find your cluster's audience, run:
+```bash
+kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
+```
+
+For the issuer, you can use the OIDC provider's URL:
+```bash
+kubectl get --raw /.well-known/openid-configuration| jq .issuer
+```
+
+For the tls_cafile, you can use the CA certificate of the OIDC provider:
+```bash
+kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
+```
+
+The provider extracts user information from the JWT token:
+- Username from the `sub` claim becomes a role
+- Kubernetes groups become teams
+
+You can easily validate a request by running:
+
+```bash
+curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
+```
+#### Custom Provider
+Validates tokens against a custom authentication endpoint:
+```yaml
+server:
+  auth:
+    provider_type: "custom"
+    config:
+      endpoint: "https://auth.example.com/validate"  # URL of the auth endpoint
+```
+
+The custom endpoint receives a POST request with:
+```json
+{
+  "api_key": "<token>",
+  "request": {
+    "path": "/api/v1/endpoint",
+    "headers": {
+      "content-type": "application/json",
+      "user-agent": "curl/7.64.1"
+    },
+    "params": {
+      "key": ["value"]
+    }
+  }
+}
+```
+
+And must respond with:
+```json
+{
+  "access_attributes": {
+    "roles": ["admin", "user"],
+    "teams": ["ml-team", "nlp-team"],
+    "projects": ["llama-3", "project-x"],
+    "namespaces": ["research"]
+  },
+  "message": "Authentication successful"
+}
+```
+
+If no access attributes are returned, the token is used as a namespace.
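As a rough illustration of what such a validation service could look like, here is a hedged sketch (a hypothetical FastAPI app with a made-up token table; it is not part of the diff or of Llama Stack itself) that accepts the request body above and returns `access_attributes`:

```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()


class AuthRequest(BaseModel):
    api_key: str
    request: dict


# Hypothetical token table standing in for a real identity backend.
KNOWN_TOKENS = {"secret-token": {"roles": ["admin"], "teams": ["ml-team"]}}


@app.post("/validate")
def validate(body: AuthRequest) -> dict:
    attrs = KNOWN_TOKENS.get(body.api_key)
    if attrs is None:
        # Assumed behavior: a non-2xx response should cause the Stack server
        # to reject the incoming request.
        raise HTTPException(status_code=401, detail="invalid token")
    return {"access_attributes": attrs, "message": "Authentication successful"}
```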
+### Quota Configuration
+
+The `quota` section allows you to enable server-side request throttling for both
+authenticated and anonymous clients. This is useful for preventing abuse, enforcing
+fairness across tenants, and controlling infrastructure costs without requiring
+client-side rate limiting or external proxies.
+
+Quotas are disabled by default. When enabled, each client is tracked using either:
+
+* Their authenticated `client_id` (derived from the Bearer token), or
+* Their IP address (fallback for anonymous requests)
+
+Quota state is stored in a SQLite-backed key-value store, and rate limits are applied
+within a configurable time window (currently only `day` is supported).
+
+#### Example
+
+```yaml
+server:
+  quota:
+    kvstore:
+      type: sqlite
+      db_path: ./quotas.db
+    anonymous_max_requests: 100
+    authenticated_max_requests: 1000
+    period: day
+```
+
+#### Configuration Options
+
+| Field | Description |
+| ---------------------------- | -------------------------------------------------------------------------- |
+| `kvstore` | Required. Backend storage config for tracking request counts. |
+| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. |
+| `kvstore.db_path` | File path to the SQLite database. |
+| `anonymous_max_requests` | Max requests per period for unauthenticated clients. |
+| `authenticated_max_requests` | Max requests per period for authenticated clients. |
+| `period` | Time window for quota enforcement. Only `"day"` is supported. |
+
+> Note: if `authenticated_max_requests` is set but no authentication provider is
+configured, the server will fall back to applying `anonymous_max_requests` to all
+clients.
+
+#### Example with Authentication Enabled
+
+```yaml
+server:
+  port: 8321
+  auth:
+    provider_type: custom
+    config:
+      endpoint: https://auth.example.com/validate
+  quota:
+    kvstore:
+      type: sqlite
+      db_path: ./quotas.db
+    anonymous_max_requests: 100
+    authenticated_max_requests: 1000
+    period: day
+```
+
+If a client exceeds their limit, the server responds with:
+
+```http
+HTTP/1.1 429 Too Many Requests
+Content-Type: application/json
+
+{
+  "error": {
+    "message": "Quota exceeded"
+  }
+}
+```

 ## Extending to handle Safety

 Configuring Safety can be a little involved so it is instructive to go through an example.
@@ -172,7 +172,7 @@ spec:
 - name: llama-stack
 image: localhost/llama-stack-run-k8s:latest
 imagePullPolicy: IfNotPresent
-command: ["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]
+command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/app/config.yaml"]
 ports:
 - containerPort: 5000
 volumeMounts:
@ -18,11 +18,11 @@ The `llamastack/distribution-watsonx` distribution consists of the following pro
|
||||||
| agents | `inline::meta-reference` |
|
| agents | `inline::meta-reference` |
|
||||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||||
| eval | `inline::meta-reference` |
|
| eval | `inline::meta-reference` |
|
||||||
| inference | `remote::watsonx` |
|
| inference | `remote::watsonx`, `inline::sentence-transformers` |
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
| vector_io | `inline::faiss` |
|
| vector_io | `inline::faiss` |
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ docker run \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ./run.yaml:/root/my-run.yaml \
|
-v ./run.yaml:/root/my-run.yaml \
|
||||||
llamastack/distribution-watsonx \
|
llamastack/distribution-watsonx \
|
||||||
--yaml-config /root/my-run.yaml \
|
--config /root/my-run.yaml \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT \
|
||||||
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||||
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
|
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
|
||||||
|
|
|
@ -19,7 +19,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
|
||||||
| safety | `remote::bedrock` |
|
| safety | `remote::bedrock` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ docker run \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ./run.yaml:/root/my-run.yaml \
|
-v ./run.yaml:/root/my-run.yaml \
|
||||||
llamastack/distribution-cerebras \
|
llamastack/distribution-cerebras \
|
||||||
--yaml-config /root/my-run.yaml \
|
--config /root/my-run.yaml \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT \
|
||||||
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
|
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
|
||||||
```
|
```
|
||||||
|
|
|
@ -23,7 +23,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
@ -155,7 +155,7 @@ docker run \
|
||||||
-v $HOME/.llama:/root/.llama \
|
-v $HOME/.llama:/root/.llama \
|
||||||
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||||
llamastack/distribution-dell \
|
llamastack/distribution-dell \
|
||||||
--yaml-config /root/my-run.yaml \
|
--config /root/my-run.yaml \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT \
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||||
--env DEH_URL=$DEH_URL \
|
--env DEH_URL=$DEH_URL \
|
||||||
|
|
|
@ -22,7 +22,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ The `llamastack/distribution-groq` distribution consists of the following provid
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
|
||||||
| vector_io | `inline::faiss` |
|
| vector_io | `inline::faiss` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
||||||
| API | Provider(s) |
|
| API | Provider(s) |
|
||||||
|-----|-------------|
|
|-----|-------------|
|
||||||
| agents | `inline::meta-reference` |
|
| agents | `inline::meta-reference` |
|
||||||
| datasetio | `inline::localfs` |
|
| datasetio | `inline::localfs`, `remote::nvidia` |
|
||||||
| eval | `remote::nvidia` |
|
| eval | `remote::nvidia` |
|
||||||
| inference | `remote::nvidia` |
|
| inference | `remote::nvidia` |
|
||||||
| post_training | `remote::nvidia` |
|
| post_training | `remote::nvidia` |
|
||||||
|
@ -143,7 +143,7 @@ docker run \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ./run.yaml:/root/my-run.yaml \
|
-v ./run.yaml:/root/my-run.yaml \
|
||||||
llamastack/distribution-nvidia \
|
llamastack/distribution-nvidia \
|
||||||
--yaml-config /root/my-run.yaml \
|
--config /root/my-run.yaml \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT \
|
||||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||||
```
|
```
|
||||||
|
|
|
@ -19,10 +19,11 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
|
||||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||||
| eval | `inline::meta-reference` |
|
| eval | `inline::meta-reference` |
|
||||||
| inference | `remote::ollama` |
|
| inference | `remote::ollama` |
|
||||||
|
| post_training | `inline::huggingface` |
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,7 +98,7 @@ docker run \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
-v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
|
-v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
|
||||||
llamastack/distribution-ollama \
|
llamastack/distribution-ollama \
|
||||||
--yaml-config /root/my-run.yaml \
|
--config /root/my-run.yaml \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT \
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||||
|
|
|
@@ -22,7 +22,7 @@ The `llamastack/distribution-passthrough` distribution consists of the following
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::rag-runtime`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -21,7 +21,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@@ -233,7 +233,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \
   llamastack/distribution-remote-vllm \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1

@@ -255,7 +255,7 @@ docker run \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-remote-vllm \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \
@@ -16,10 +16,10 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 | API | Provider(s) |
 |-----|-------------|
 | agents | `inline::meta-reference` |
-| inference | `remote::sambanova` |
+| inference | `remote::sambanova`, `inline::sentence-transformers` |
-| safety | `inline::llama-guard` |
+| safety | `remote::sambanova` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@@ -28,53 +28,64 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 The following environment variables can be configured:

 - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
-- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``)
+- `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``)

 ### Models

 The following models are available by default:

-- `Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
-- `Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
-- `Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
-- `Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
-- `Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
-- `Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
-- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
-- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
-- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
-- `Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `sambanova/Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `sambanova/Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
+- `sambanova/Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
+- `sambanova/Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `sambanova/Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
+- `sambanova/Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
+- `sambanova/Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
+- `sambanova/Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `sambanova/Llama-4-Maverick-17B-128E-Instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
+- `sambanova/Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`

 ### Prerequisite: API Keys

-Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/).
+Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup).

 ## Running Llama Stack with SambaNova

 You can do this via Conda (build code) or Docker which has a pre-built image.

 ### Via Docker

-This method allows you to get started quickly without having to build the distribution code.
-
 ```bash
 LLAMA_STACK_PORT=8321
+llama stack build --template sambanova --image-type container
 docker run \
   -it \
-  --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-  llamastack/distribution-sambanova \
+  -v ~/.llama:/root/.llama \
+  distribution-sambanova \
   --port $LLAMA_STACK_PORT \
   --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
 ```

+### Via Venv
+
+```bash
+llama stack build --template sambanova --image-type venv
+llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \
+  --port $LLAMA_STACK_PORT \
+  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
+```
+
 ### Via Conda

 ```bash
 llama stack build --template sambanova --image-type conda
-llama stack run ./run.yaml \
+llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
 ```
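Once the server is up via any of the methods above, a quick smoke test is a single chat completion against one of the default models. This is a minimal sketch, not part of the distribution docs: it assumes the server listens on the default port 8321, that the `llama-stack-client` Python package is installed, and that the model identifier below matches the list above.

```python
from llama_stack_client import LlamaStackClient

# Assumes the SambaNova distribution is already running on localhost:8321.
client = LlamaStackClient(base_url="http://localhost:8321")

# Model identifier taken from the default model list above.
response = client.inference.chat_completion(
    model_id="sambanova/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(response.completion_message.content)
```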
@@ -23,7 +23,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@@ -117,7 +117,7 @@ docker run \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-tgi \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \
@@ -22,7 +22,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -42,7 +42,7 @@ powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | ie
 Setup your virtual environment.

 ```bash
-uv venv --python 3.10
+uv sync --python 3.10
 source .venv/bin/activate
 ```
 ## Step 2: Run Llama Stack
@@ -445,7 +445,6 @@ from llama_stack_client import LlamaStackClient
 from llama_stack_client import Agent, AgentEventLogger
 from llama_stack_client.types import Document
 import uuid
-from termcolor import cprint

 client = LlamaStackClient(base_url="http://localhost:8321")

@@ -463,7 +462,6 @@ urls = [
     "memory_optimizations.rst",
     "chat.rst",
     "llama3.rst",
-    "datasets.rst",
     "qat_finetune.rst",
     "lora_finetune.rst",
 ]
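For context, the `urls` above feed the RAG ingestion step of this walkthrough. A hedged sketch of how such filenames are typically turned into `Document` objects — the base URL is an assumption, pointing at the torchtune docs those names come from:

```python
from llama_stack_client.types import Document

urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst", "qat_finetune.rst", "lora_finetune.rst"]

documents = [
    Document(
        document_id=f"num-{i}",
        # Assumed source location for the torchtune .rst files listed above.
        content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
        mime_type="text/plain",
        metadata={},
    )
    for i, url in enumerate(urls)
]
print(f"prepared {len(documents)} documents for ingestion")
```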
@@ -10,7 +10,7 @@ Llama Stack supports external providers that live outside of the main codebase.
 To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:

 ```yaml
-external_providers_dir: /etc/llama-stack/providers.d/
+external_providers_dir: ~/.llama/providers.d/
 ```

 ## Directory Structure

@@ -53,7 +53,9 @@ Here's a list of known external providers that you can use with Llama Stack:
 | Name | Description | API | Type | Repository |
 |------|-------------|-----|------|------------|
 | KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
+| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Inline **and** Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
 | RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
+| TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) |

 ### Remote Provider Specification

@@ -180,7 +182,7 @@ dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
 3. Create the provider specification:

 ```yaml
-# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
+# ~/.llama/providers.d/remote/inference/custom_ollama.yaml
 adapter:
   adapter_type: custom_ollama
   pip_packages: ["ollama", "aiohttp"]

@@ -199,7 +201,7 @@ uv pip install -e .
 5. Configure Llama Stack to use external providers:

 ```yaml
-external_providers_dir: /etc/llama-stack/providers.d/
+external_providers_dir: ~/.llama/providers.d/
 ```

 The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
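After pointing `external_providers_dir` at your specs and restarting the stack, you can confirm the adapter was picked up from the client side. A hedged sketch, assuming a locally running server and that the installed `llama-stack-client` exposes `providers.list()` as in recent releases:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# List every provider the running stack knows about and look for the
# external adapter registered above.
providers = client.providers.list()
for provider in providers:
    print(provider.api, provider.provider_id, provider.provider_type)

assert any(p.provider_type == "remote::custom_ollama" for p in providers)
```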
@@ -30,6 +30,18 @@ Runs inference with an LLM.
 ## Post Training
 Fine-tunes a model.

+#### Post Training Providers
+The following providers are available for Post Training:
+
+```{toctree}
+:maxdepth: 1
+
+external
+post_training/huggingface
+post_training/torchtune
+post_training/nvidia_nemo
+```
+
 ## Safety
 Applies safety policies to the output at a Systems (not only model) level.
docs/source/providers/post_training/huggingface.md (new file, 122 lines)
@@ -0,0 +1,122 @@
---
orphan: true
---
# HuggingFace SFTTrainer

[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine tuning on a variety of models using many datasets

## Features

- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support, CPU support, and MPS support (MacOS Metal Performance Shaders)

## Usage

To use the HF SFTTrainer in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Kick off a SFT job using the Llama Stack post_training API.

## Setup

You can access the HuggingFace trainer via the `ollama` distribution:

```bash
llama stack build --template ollama --image-type venv
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
```

## Run Training

You can access the provider and the `supervised_fine_tune` method via the post_training API:

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(  # this config is also currently mandatory but should not be
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)


# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
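Beyond polling a single job as above, the same `post_training` client surface can be used to enumerate jobs and stop one early. A minimal sketch, assuming the `job.list()` and `job.cancel()` helpers present in recent `llama-stack-client` releases:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Enumerate known post-training jobs.
for job in client.post_training.job.list():
    print(job.job_uuid)

# Stop a run that is no longer needed (use the job_uuid printed by the example above).
client.post_training.job.cancel(job_uuid="test-job-to-cancel")
```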
docs/source/providers/post_training/nvidia_nemo.md (new file, 163 lines)
|
@ -0,0 +1,163 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
# NVIDIA NEMO
|
||||||
|
|
||||||
|
[NVIDIA NEMO](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Enterprise-grade fine-tuning capabilities
|
||||||
|
- Support for LoRA and SFT fine-tuning
|
||||||
|
- Integration with NVIDIA's NeMo Customizer service
|
||||||
|
- Support for various NVIDIA-optimized models
|
||||||
|
- Efficient training with NVIDIA hardware acceleration
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use NVIDIA NEMO in your Llama Stack project, follow these steps:
|
||||||
|
|
||||||
|
1. Configure your Llama Stack project to use this provider.
|
||||||
|
2. Set up your NVIDIA API credentials.
|
||||||
|
3. Kick off a fine-tuning job using the Llama Stack post_training API.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
You'll need to set the following environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export NVIDIA_API_KEY="your-api-key"
|
||||||
|
export NVIDIA_DATASET_NAMESPACE="default"
|
||||||
|
export NVIDIA_CUSTOMIZER_URL="your-customizer-url"
|
||||||
|
export NVIDIA_PROJECT_ID="your-project-id"
|
||||||
|
export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run Training
|
||||||
|
|
||||||
|
You can access the provider and the `supervised_fine_tune` method via the post_training API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from llama_stack_client.types import (
|
||||||
|
post_training_supervised_fine_tune_params,
|
||||||
|
algorithm_config_param,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_http_client():
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
return LlamaStackClient(base_url="http://localhost:8321")
|
||||||
|
|
||||||
|
|
||||||
|
client = create_http_client()
|
||||||
|
|
||||||
|
# Example Dataset
|
||||||
|
client.datasets.register(
|
||||||
|
purpose="post-training/messages",
|
||||||
|
source={
|
||||||
|
"type": "uri",
|
||||||
|
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
||||||
|
},
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
)
|
||||||
|
|
||||||
|
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
|
||||||
|
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
|
||||||
|
batch_size=8, # Default batch size for NEMO
|
||||||
|
data_format="instruct",
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
shuffle=True,
|
||||||
|
),
|
||||||
|
n_epochs=50, # Default epochs for NEMO
|
||||||
|
optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
|
||||||
|
lr=0.0001, # Default learning rate
|
||||||
|
weight_decay=0.01, # NEMO-specific parameter
|
||||||
|
),
|
||||||
|
# NEMO-specific parameters
|
||||||
|
log_every_n_steps=None,
|
||||||
|
val_check_interval=0.25,
|
||||||
|
sequence_packing_enabled=False,
|
||||||
|
hidden_dropout=None,
|
||||||
|
attention_dropout=None,
|
||||||
|
ffn_dropout=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
|
||||||
|
alpha=16, # Default alpha for NEMO
|
||||||
|
type="LoRA",
|
||||||
|
)
|
||||||
|
|
||||||
|
job_uuid = f"test-job{uuid.uuid4()}"
|
||||||
|
|
||||||
|
# Example Model - must be a supported NEMO model
|
||||||
|
training_model = "meta/llama-3.1-8b-instruct"
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
response = client.post_training.supervised_fine_tune(
|
||||||
|
job_uuid=job_uuid,
|
||||||
|
logger_config={},
|
||||||
|
model=training_model,
|
||||||
|
hyperparam_search_config={},
|
||||||
|
training_config=training_config,
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
checkpoint_dir="output",
|
||||||
|
)
|
||||||
|
print("Job: ", job_uuid)
|
||||||
|
|
||||||
|
# Wait for the job to complete!
|
||||||
|
while True:
|
||||||
|
status = client.post_training.job.status(job_uuid=job_uuid)
|
||||||
|
if not status:
|
||||||
|
print("Job not found")
|
||||||
|
break
|
||||||
|
|
||||||
|
print(status)
|
||||||
|
if status.status == "completed":
|
||||||
|
break
|
||||||
|
|
||||||
|
print("Waiting for job to complete...")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print("Job completed in", end_time - start_time, "seconds!")
|
||||||
|
|
||||||
|
print("Artifacts:")
|
||||||
|
print(client.post_training.job.artifacts(job_uuid=job_uuid))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
Currently supports the following models:
|
||||||
|
- meta/llama-3.1-8b-instruct
|
||||||
|
- meta/llama-3.2-1b-instruct
|
||||||
|
|
||||||
|
## Supported Parameters
|
||||||
|
|
||||||
|
### TrainingConfig
|
||||||
|
- n_epochs (default: 50)
|
||||||
|
- data_config
|
||||||
|
- optimizer_config
|
||||||
|
- log_every_n_steps
|
||||||
|
- val_check_interval (default: 0.25)
|
||||||
|
- sequence_packing_enabled (default: False)
|
||||||
|
- hidden_dropout (0.0-1.0)
|
||||||
|
- attention_dropout (0.0-1.0)
|
||||||
|
- ffn_dropout (0.0-1.0)
|
||||||
|
|
||||||
|
### DataConfig
|
||||||
|
- dataset_id
|
||||||
|
- batch_size (default: 8)
|
||||||
|
|
||||||
|
### OptimizerConfig
|
||||||
|
- lr (default: 0.0001)
|
||||||
|
- weight_decay (default: 0.01)
|
||||||
|
|
||||||
|
### LoRA Config
|
||||||
|
- alpha (default: 16)
|
||||||
|
- type (must be "LoRA")
|
||||||
|
|
||||||
|
Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning.
|
docs/source/providers/post_training/torchtune.md (new file, 125 lines)
|
@ -0,0 +1,125 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
# TorchTune
|
||||||
|
|
||||||
|
[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Simple access through the post_training API
|
||||||
|
- Fully integrated with Llama Stack
|
||||||
|
- GPU support and single device capabilities.
|
||||||
|
- Support for LoRA
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use TorchTune in your Llama Stack project, follow these steps:
|
||||||
|
|
||||||
|
1. Configure your Llama Stack project to use this provider.
|
||||||
|
2. Kick off a fine-tuning job using the Llama Stack post_training API.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
You can access the TorchTune trainer by writing your own yaml pointing to the provider:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
post_training:
|
||||||
|
- provider_id: torchtune
|
||||||
|
provider_type: inline::torchtune
|
||||||
|
config: {}
|
||||||
|
```
|
||||||
|
|
||||||
|
you can then build and run your own stack with this provider.
|
||||||
|
|
||||||
|
## Run Training
|
||||||
|
|
||||||
|
You can access the provider and the `supervised_fine_tune` method via the post_training API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from llama_stack_client.types import (
|
||||||
|
post_training_supervised_fine_tune_params,
|
||||||
|
algorithm_config_param,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_http_client():
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
return LlamaStackClient(base_url="http://localhost:8321")
|
||||||
|
|
||||||
|
|
||||||
|
client = create_http_client()
|
||||||
|
|
||||||
|
# Example Dataset
|
||||||
|
client.datasets.register(
|
||||||
|
purpose="post-training/messages",
|
||||||
|
source={
|
||||||
|
"type": "uri",
|
||||||
|
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
||||||
|
},
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
)
|
||||||
|
|
||||||
|
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
|
||||||
|
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
|
||||||
|
batch_size=32,
|
||||||
|
data_format="instruct",
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
shuffle=True,
|
||||||
|
),
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
max_steps_per_epoch=0,
|
||||||
|
max_validation_steps=1,
|
||||||
|
n_epochs=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
|
||||||
|
alpha=1,
|
||||||
|
apply_lora_to_mlp=True,
|
||||||
|
apply_lora_to_output=False,
|
||||||
|
lora_attn_modules=["q_proj"],
|
||||||
|
rank=1,
|
||||||
|
type="LoRA",
|
||||||
|
)
|
||||||
|
|
||||||
|
job_uuid = f"test-job{uuid.uuid4()}"
|
||||||
|
|
||||||
|
# Example Model
|
||||||
|
training_model = "meta-llama/Llama-2-7b-hf"
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
response = client.post_training.supervised_fine_tune(
|
||||||
|
job_uuid=job_uuid,
|
||||||
|
logger_config={},
|
||||||
|
model=training_model,
|
||||||
|
hyperparam_search_config={},
|
||||||
|
training_config=training_config,
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
checkpoint_dir="output",
|
||||||
|
)
|
||||||
|
print("Job: ", job_uuid)
|
||||||
|
|
||||||
|
# Wait for the job to complete!
|
||||||
|
while True:
|
||||||
|
status = client.post_training.job.status(job_uuid=job_uuid)
|
||||||
|
if not status:
|
||||||
|
print("Job not found")
|
||||||
|
break
|
||||||
|
|
||||||
|
print(status)
|
||||||
|
if status.status == "completed":
|
||||||
|
break
|
||||||
|
|
||||||
|
print("Waiting for job to complete...")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print("Job completed in", end_time - start_time, "seconds!")
|
||||||
|
|
||||||
|
print("Artifacts:")
|
||||||
|
print(client.post_training.job.artifacts(job_uuid=job_uuid))
|
||||||
|
```
|
docs/source/providers/vector_io/milvus.md (new file, 107 lines)
|
@ -0,0 +1,107 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
# Milvus
|
||||||
|
|
||||||
|
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
|
||||||
|
allows you to store and query vectors directly within a Milvus database.
|
||||||
|
That means you're not limited to storing vectors in memory or in a separate service.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Easy to use
|
||||||
|
- Fully integrated with Llama Stack
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use Milvus in your Llama Stack project, follow these steps:
|
||||||
|
|
||||||
|
1. Install the necessary dependencies.
|
||||||
|
2. Configure your Llama Stack project to use Milvus.
|
||||||
|
3. Start storing and querying vectors.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
You can install Milvus using pymilvus:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install pymilvus
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
In Llama Stack, Milvus can be configured in two ways:
|
||||||
|
- **Inline (Local) Configuration** - Uses Milvus-Lite for local storage
|
||||||
|
- **Remote Configuration** - Connects to a remote Milvus server
|
||||||
|
|
||||||
|
### Inline (Local) Configuration
|
||||||
|
|
||||||
|
The simplest method is local configuration, which requires setting `db_path`, a path for locally storing Milvus-Lite files:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
vector_io:
|
||||||
|
- provider_id: milvus
|
||||||
|
provider_type: inline::milvus
|
||||||
|
config:
|
||||||
|
db_path: ~/.llama/distributions/together/milvus_store.db
|
||||||
|
```
|
||||||
|
|
||||||
|
### Remote Configuration
|
||||||
|
|
||||||
|
Remote configuration is suitable for larger data storage requirements:
|
||||||
|
|
||||||
|
#### Standard Remote Connection
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
vector_io:
|
||||||
|
- provider_id: milvus
|
||||||
|
provider_type: remote::milvus
|
||||||
|
config:
|
||||||
|
uri: "http://<host>:<port>"
|
||||||
|
token: "<user>:<password>"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### TLS-Enabled Remote Connection (One-way TLS)
|
||||||
|
|
||||||
|
For connections to Milvus instances with one-way TLS enabled:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
vector_io:
|
||||||
|
- provider_id: milvus
|
||||||
|
provider_type: remote::milvus
|
||||||
|
config:
|
||||||
|
uri: "https://<host>:<port>"
|
||||||
|
token: "<user>:<password>"
|
||||||
|
secure: True
|
||||||
|
server_pem_path: "/path/to/server.pem"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Mutual TLS (mTLS) Remote Connection
|
||||||
|
|
||||||
|
For connections to Milvus instances with mutual TLS (mTLS) enabled:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
vector_io:
|
||||||
|
- provider_id: milvus
|
||||||
|
provider_type: remote::milvus
|
||||||
|
config:
|
||||||
|
uri: "https://<host>:<port>"
|
||||||
|
token: "<user>:<password>"
|
||||||
|
secure: True
|
||||||
|
ca_pem_path: "/path/to/ca.pem"
|
||||||
|
client_pem_path: "/path/to/client.pem"
|
||||||
|
client_key_path: "/path/to/client.key"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Key Parameters for TLS Configuration
|
||||||
|
|
||||||
|
- **`secure`**: Enables TLS encryption when set to `true`. Defaults to `false`.
|
||||||
|
- **`server_pem_path`**: Path to the **server certificate** for verifying the server’s identity (used in one-way TLS).
|
||||||
|
- **`ca_pem_path`**: Path to the **Certificate Authority (CA) certificate** for validating the server certificate (required in mTLS).
|
||||||
|
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
|
||||||
|
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
|
||||||
|
|
||||||
|
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
|
|
@@ -1,31 +0,0 @@
----
-orphan: true
----
-# Milvus
-
-[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
-allows you to store and query vectors directly within a Milvus database.
-That means you're not limited to storing vectors in memory or in a separate service.
-
-## Features
-
-- Easy to use
-- Fully integrated with Llama Stack
-
-## Usage
-
-To use Milvus in your Llama Stack project, follow these steps:
-
-1. Install the necessary dependencies.
-2. Configure your Llama Stack project to use Milvus.
-3. Start storing and querying vectors.
-
-## Installation
-
-You can install Milvus using pymilvus:
-
-```bash
-pip install pymilvus
-```
-## Documentation
-See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
@@ -66,6 +66,25 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
 2. Configure your Llama Stack project to use SQLite-Vec.
 3. Start storing and querying vectors.

+## Supported Search Modes
+
+The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
+
+When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
+`RAGQueryConfig`. For example:
+
+```python
+from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
+
+query_config = RAGQueryConfig(max_chunks=6, mode="vector")
+
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="what is torchtune",
+    query_config=query_config,
+)
+```
+
 ## Installation

 You can install SQLite-Vec using pip:
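For keyword-based retrieval the only change is the `mode` value; the query surface is identical. A small sketch along the lines of the example above — client setup and the `vector_db_id` value are assumptions here, standing in for a previously registered sqlite-vec database:

```python
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
vector_db_id = "my-documents"  # assumed: an already-registered sqlite-vec vector database

# Same call as the vector example above, but using full-text (keyword) search.
keyword_config = RAGQueryConfig(max_chunks=6, mode="keyword")
results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="what is torchtune",
    query_config=keyword_config,
)
print(results.content)
```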
@@ -253,8 +253,6 @@ llama-stack-client toolgroups list
 +---------------------------+------------------+------+---------------+
 | identifier                | provider_id      | args | mcp_endpoint  |
 +===========================+==================+======+===============+
-| builtin::code_interpreter | code-interpreter | None | None          |
-+---------------------------+------------------+------+---------------+
 | builtin::rag              | rag-runtime      | None | None          |
 +---------------------------+------------------+------+---------------+
 | builtin::websearch        | tavily-search    | None | None          |
@@ -86,11 +86,11 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
 llama stack build --template ollama --image-type conda
 ```
 **Expected Output:**
-```
+```bash
 ...
-Build Successful! Next steps:
-1. Set the environment variables: LLAMA_STACK_PORT, OLLAMA_URL, INFERENCE_MODEL, SAFETY_MODEL
-2. `llama stack run /Users/<username>/.llama/distributions/llamastack-ollama/ollama-run.yaml
+Build Successful!
+You can find the newly-built template here: ~/.llama/distributions/ollama/ollama-run.yaml
+You can run the new Llama Stack Distro via: llama stack run ~/.llama/distributions/ollama/ollama-run.yaml --image-type conda
 ```

 3. **Set the ENV variables by exporting them to the terminal**:
install.sh (new executable file, 206 lines)
|
@ -0,0 +1,206 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
set -Eeuo pipefail
|
||||||
|
|
||||||
|
PORT=8321
|
||||||
|
OLLAMA_PORT=11434
|
||||||
|
MODEL_ALIAS="llama3.2:3b"
|
||||||
|
SERVER_IMAGE="llamastack/distribution-ollama:0.2.2"
|
||||||
|
WAIT_TIMEOUT=300
|
||||||
|
|
||||||
|
log(){ printf "\e[1;32m%s\e[0m\n" "$*"; }
|
||||||
|
die(){ printf "\e[1;31m❌ %s\e[0m\n" "$*" >&2; exit 1; }
|
||||||
|
|
||||||
|
wait_for_service() {
|
||||||
|
local url="$1"
|
||||||
|
local pattern="$2"
|
||||||
|
local timeout="$3"
|
||||||
|
local name="$4"
|
||||||
|
local start ts
|
||||||
|
log "⏳ Waiting for ${name}…"
|
||||||
|
start=$(date +%s)
|
||||||
|
while true; do
|
||||||
|
if curl --retry 5 --retry-delay 1 --retry-max-time "$timeout" --retry-all-errors --silent --fail "$url" 2>/dev/null | grep -q "$pattern"; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
ts=$(date +%s)
|
||||||
|
if (( ts - start >= timeout )); then
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
printf '.'
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
cat << EOF
|
||||||
|
📚 Llama-Stack Deployment Script
|
||||||
|
|
||||||
|
Description:
|
||||||
|
This script sets up and deploys Llama-Stack with Ollama integration in containers.
|
||||||
|
It handles both Docker and Podman runtimes and includes automatic platform detection.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
$(basename "$0") [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
-p, --port PORT Server port for Llama-Stack (default: ${PORT})
|
||||||
|
-o, --ollama-port PORT Ollama service port (default: ${OLLAMA_PORT})
|
||||||
|
-m, --model MODEL Model alias to use (default: ${MODEL_ALIAS})
|
||||||
|
-i, --image IMAGE Server image (default: ${SERVER_IMAGE})
|
||||||
|
-t, --timeout SECONDS Service wait timeout in seconds (default: ${WAIT_TIMEOUT})
|
||||||
|
-h, --help Show this help message
|
||||||
|
|
||||||
|
For more information:
|
||||||
|
Documentation: https://llama-stack.readthedocs.io/
|
||||||
|
GitHub: https://github.com/meta-llama/llama-stack
|
||||||
|
|
||||||
|
Report issues:
|
||||||
|
https://github.com/meta-llama/llama-stack/issues
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
-h|--help)
|
||||||
|
usage
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
-p|--port)
|
||||||
|
PORT="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-o|--ollama-port)
|
||||||
|
OLLAMA_PORT="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-m|--model)
|
||||||
|
MODEL_ALIAS="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-i|--image)
|
||||||
|
SERVER_IMAGE="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-t|--timeout)
|
||||||
|
WAIT_TIMEOUT="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
die "Unknown option: $1"
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if command -v docker &> /dev/null; then
|
||||||
|
ENGINE="docker"
|
||||||
|
elif command -v podman &> /dev/null; then
|
||||||
|
ENGINE="podman"
|
||||||
|
else
|
||||||
|
die "Docker or Podman is required. Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Explicitly set the platform for the host architecture
|
||||||
|
HOST_ARCH="$(uname -m)"
|
||||||
|
if [ "$HOST_ARCH" = "arm64" ]; then
|
||||||
|
if [ "$ENGINE" = "docker" ]; then
|
||||||
|
PLATFORM_OPTS=( --platform linux/amd64 )
|
||||||
|
else
|
||||||
|
PLATFORM_OPTS=( --os linux --arch amd64 )
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
PLATFORM_OPTS=()
|
||||||
|
fi
|
||||||
|
|
||||||
|
# macOS + Podman: ensure VM is running before we try to launch containers
|
||||||
|
# If you need GPU passthrough under Podman on macOS, init the VM with libkrun:
|
||||||
|
# CONTAINERS_MACHINE_PROVIDER=libkrun podman machine init
|
||||||
|
if [ "$ENGINE" = "podman" ] && [ "$(uname -s)" = "Darwin" ]; then
|
||||||
|
if ! podman info &>/dev/null; then
|
||||||
|
log "⌛️ Initializing Podman VM…"
|
||||||
|
podman machine init &>/dev/null || true
|
||||||
|
podman machine start &>/dev/null || true
|
||||||
|
|
||||||
|
log "⌛️ Waiting for Podman API…"
|
||||||
|
until podman info &>/dev/null; do
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
log "✅ Podman VM is up"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Clean up any leftovers from earlier runs
|
||||||
|
for name in ollama-server llama-stack; do
|
||||||
|
ids=$($ENGINE ps -aq --filter "name=^${name}$")
|
||||||
|
if [ -n "$ids" ]; then
|
||||||
|
log "⚠️ Found existing container(s) for '${name}', removing…"
|
||||||
|
$ENGINE rm -f "$ids" > /dev/null 2>&1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# 0. Create a shared network
|
||||||
|
###############################################################################
|
||||||
|
if ! $ENGINE network inspect llama-net >/dev/null 2>&1; then
|
||||||
|
log "🌐 Creating network…"
|
||||||
|
$ENGINE network create llama-net >/dev/null 2>&1
|
||||||
|
fi
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# 1. Ollama
|
||||||
|
###############################################################################
|
||||||
|
log "🦙 Starting Ollama…"
|
||||||
|
$ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \
|
||||||
|
--network llama-net \
|
||||||
|
-p "${OLLAMA_PORT}:${OLLAMA_PORT}" \
|
||||||
|
ollama/ollama > /dev/null 2>&1
|
||||||
|
|
||||||
|
if ! wait_for_service "http://localhost:${OLLAMA_PORT}/" "Ollama" "$WAIT_TIMEOUT" "Ollama daemon"; then
|
||||||
|
log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
|
||||||
|
$ENGINE logs --tail 200 ollama-server
|
||||||
|
die "Ollama startup failed"
|
||||||
|
fi
|
||||||
|
|
||||||
|
log "📦 Ensuring model is pulled: ${MODEL_ALIAS}…"
|
||||||
|
if ! $ENGINE exec ollama-server ollama pull "${MODEL_ALIAS}" > /dev/null 2>&1; then
|
||||||
|
log "❌ Failed to pull model ${MODEL_ALIAS}; dumping container logs:"
|
||||||
|
$ENGINE logs --tail 200 ollama-server
|
||||||
|
die "Model pull failed"
|
||||||
|
fi
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# 2. Llama‑Stack
|
||||||
|
###############################################################################
|
||||||
|
cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
|
||||||
|
--network llama-net \
|
||||||
|
-p "${PORT}:${PORT}" \
|
||||||
|
"${SERVER_IMAGE}" --port "${PORT}" \
|
||||||
|
--env INFERENCE_MODEL="${MODEL_ALIAS}" \
|
||||||
|
--env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" )
|
||||||
|
|
||||||
|
log "🦙 Starting Llama‑Stack…"
|
||||||
|
$ENGINE "${cmd[@]}" > /dev/null 2>&1
|
||||||
|
|
||||||
|
if ! wait_for_service "http://127.0.0.1:${PORT}/v1/health" "OK" "$WAIT_TIMEOUT" "Llama-Stack API"; then
|
||||||
|
log "❌ Llama-Stack did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
|
||||||
|
$ENGINE logs --tail 200 llama-stack
|
||||||
|
die "Llama-Stack startup failed"
|
||||||
|
fi
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Done
|
||||||
|
###############################################################################
|
||||||
|
log ""
|
||||||
|
log "🎉 Llama‑Stack is ready!"
|
||||||
|
log "👉 API endpoint: http://localhost:${PORT}"
|
||||||
|
log "📖 Documentation: https://llama-stack.readthedocs.io/en/latest/references/index.html"
|
||||||
|
log "💻 To access the llama‑stack CLI, exec into the container:"
|
||||||
|
log " $ENGINE exec -ti llama-stack bash"
|
||||||
|
log ""
|
kvant_build_local.sh (new executable file, 6 lines)
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

export USE_COPY_NOT_MOUNT=true
export LLAMA_STACK_DIR=.

uvx --from . llama stack build --template kvant --image-type container --image-name kvant
kvant_start_local.sh (new executable file, 17 lines)
@@ -0,0 +1,17 @@
#!/usr/bin/env bash

export LLAMA_STACK_PORT=8321
# VLLM_API_TOKEN= env file
# KEYCLOAK_CLIENT_SECRET= env file


docker run -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v $(pwd)/data:/root/.llama \
  --mount type=bind,source="$(pwd)"/llama_stack/templates/kvant/run.yaml,target=/root/.llama/config.yaml,readonly \
  --entrypoint python \
  --env-file ./.env \
  distribution-kvant:dev \
  -m llama_stack.distribution.server.server --config /root/.llama/config.yaml \
  --port $LLAMA_STACK_PORT \
@ -4,24 +4,16 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||||
Annotated,
|
|
||||||
Any,
|
|
||||||
AsyncIterator,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Protocol,
|
|
||||||
Union,
|
|
||||||
runtime_checkable,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||||
|
from llama_stack.apis.common.responses import Order, PaginatedResponse
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
|
@ -38,6 +30,23 @@ from llama_stack.apis.safety import SafetyViolation
|
||||||
from llama_stack.apis.tools import ToolDef
|
from llama_stack.apis.tools import ToolDef
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
from .openai_responses import (
|
||||||
|
ListOpenAIResponseInputItem,
|
||||||
|
ListOpenAIResponseObject,
|
||||||
|
OpenAIResponseInput,
|
||||||
|
OpenAIResponseInputTool,
|
||||||
|
OpenAIResponseObject,
|
||||||
|
OpenAIResponseObjectStream,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: use enum.StrEnum when we drop support for python 3.10
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from enum import StrEnum
|
||||||
|
else:
|
||||||
|
|
||||||
|
class StrEnum(str, Enum):
|
||||||
|
"""Backport of StrEnum for Python 3.10 and below."""
|
||||||
|
|
||||||
|
|
||||||
class Attachment(BaseModel):
|
class Attachment(BaseModel):
|
||||||
"""An attachment to an agent turn.
|
"""An attachment to an agent turn.
|
||||||
|
@ -72,11 +81,11 @@ class StepCommon(BaseModel):
|
||||||
|
|
||||||
turn_id: str
|
turn_id: str
|
||||||
step_id: str
|
step_id: str
|
||||||
started_at: Optional[datetime] = None
|
started_at: datetime | None = None
|
||||||
completed_at: Optional[datetime] = None
|
completed_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
class StepType(Enum):
|
class StepType(StrEnum):
|
||||||
"""Type of the step in an agent turn.
|
"""Type of the step in an agent turn.
|
||||||
|
|
||||||
:cvar inference: The step is an inference step that calls an LLM.
|
:cvar inference: The step is an inference step that calls an LLM.
|
||||||
|
@ -100,7 +109,7 @@ class InferenceStep(StepCommon):
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
step_type: Literal[StepType.inference] = StepType.inference
|
||||||
model_response: CompletionMessage
|
model_response: CompletionMessage
|
||||||
|
|
||||||
|
|
||||||
|
@ -112,9 +121,9 @@ class ToolExecutionStep(StepCommon):
|
||||||
:param tool_responses: The tool responses from the tool calls.
|
:param tool_responses: The tool responses from the tool calls.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
|
||||||
tool_calls: List[ToolCall]
|
tool_calls: list[ToolCall]
|
||||||
tool_responses: List[ToolResponse]
|
tool_responses: list[ToolResponse]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -124,8 +133,8 @@ class ShieldCallStep(StepCommon):
|
||||||
:param violation: The violation from the shield call.
|
:param violation: The violation from the shield call.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
step_type: Literal[StepType.shield_call] = StepType.shield_call
|
||||||
violation: Optional[SafetyViolation]
|
violation: SafetyViolation | None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -136,19 +145,14 @@ class MemoryRetrievalStep(StepCommon):
|
||||||
:param inserted_context: The context retrieved from the vector databases.
|
:param inserted_context: The context retrieved from the vector databases.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
|
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
|
||||||
# TODO: should this be List[str]?
|
# TODO: should this be List[str]?
|
||||||
vector_db_ids: str
|
vector_db_ids: str
|
||||||
inserted_context: InterleavedContent
|
inserted_context: InterleavedContent
|
||||||
|
|
||||||
|
|
||||||
Step = Annotated[
|
Step = Annotated[
|
||||||
Union[
|
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
|
||||||
InferenceStep,
|
|
||||||
ToolExecutionStep,
|
|
||||||
ShieldCallStep,
|
|
||||||
MemoryRetrievalStep,
|
|
||||||
],
|
|
||||||
Field(discriminator="step_type"),
|
Field(discriminator="step_type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -159,18 +163,13 @@ class Turn(BaseModel):
|
||||||
|
|
||||||
turn_id: str
|
turn_id: str
|
||||||
session_id: str
|
session_id: str
|
||||||
input_messages: List[
|
input_messages: list[UserMessage | ToolResponseMessage]
|
||||||
Union[
|
steps: list[Step]
|
||||||
UserMessage,
|
|
||||||
ToolResponseMessage,
|
|
||||||
]
|
|
||||||
]
|
|
||||||
steps: List[Step]
|
|
||||||
output_message: CompletionMessage
|
output_message: CompletionMessage
|
||||||
output_attachments: Optional[List[Attachment]] = Field(default_factory=list)
|
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
|
||||||
|
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
completed_at: Optional[datetime] = None
|
completed_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@@ -179,34 +178,31 @@ class Session(BaseModel):
 
     session_id: str
     session_name: str
-    turns: List[Turn]
+    turns: list[Turn]
     started_at: datetime
 
 
 class AgentToolGroupWithArgs(BaseModel):
     name: str
-    args: Dict[str, Any]
+    args: dict[str, Any]
 
 
-AgentToolGroup = Union[
-    str,
-    AgentToolGroupWithArgs,
-]
+AgentToolGroup = str | AgentToolGroupWithArgs
 register_schema(AgentToolGroup, name="AgentTool")
 
 
 class AgentConfigCommon(BaseModel):
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
 
-    input_shields: Optional[List[str]] = Field(default_factory=list)
-    output_shields: Optional[List[str]] = Field(default_factory=list)
-    toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
-    client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
-    tool_config: Optional[ToolConfig] = Field(default=None)
+    input_shields: list[str] | None = Field(default_factory=lambda: [])
+    output_shields: list[str] | None = Field(default_factory=lambda: [])
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
+    client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
+    tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
+    tool_config: ToolConfig | None = Field(default=None)
 
-    max_infer_iters: Optional[int] = 10
+    max_infer_iters: int | None = 10
 
     def model_post_init(self, __context):
         if self.tool_config:
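The shield, toolgroup and client-tool fields above move from `Field(default_factory=list)` to `Field(default_factory=lambda: [])`; either way, every `AgentConfigCommon` instance gets its own fresh list rather than a shared mutable default. A minimal standalone sketch of the behaviour the factory guarantees (illustrative only, not part of the diff):

from pydantic import BaseModel, Field


class ExampleConfig(BaseModel):
    # The factory is called once per instance, so lists are never shared between instances.
    input_shields: list[str] | None = Field(default_factory=lambda: [])


a = ExampleConfig()
b = ExampleConfig()
a.input_shields.append("llama-guard")
assert b.input_shields == []  # b still has its own, empty list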
@@ -236,9 +232,9 @@ class AgentConfig(AgentConfigCommon):
 
     model: str
     instructions: str
-    name: Optional[str] = None
-    enable_session_persistence: Optional[bool] = False
-    response_format: Optional[ResponseFormat] = None
+    name: str | None = None
+    enable_session_persistence: bool | None = False
+    response_format: ResponseFormat | None = None
 
 
 @json_schema_type
@@ -248,21 +244,11 @@ class Agent(BaseModel):
     created_at: datetime
 
 
-@json_schema_type
-class ListAgentsResponse(BaseModel):
-    data: List[Agent]
-
-
-@json_schema_type
-class ListAgentSessionsResponse(BaseModel):
-    data: List[Session]
-
-
 class AgentConfigOverridablePerTurn(AgentConfigCommon):
-    instructions: Optional[str] = None
+    instructions: str | None = None
 
 
-class AgentTurnResponseEventType(Enum):
+class AgentTurnResponseEventType(StrEnum):
     step_start = "step_start"
     step_complete = "step_complete"
     step_progress = "step_progress"
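Moving `AgentTurnResponseEventType` from `Enum` to `StrEnum` (Python 3.11+) is what lets the payload classes below use the member itself, `Literal[AgentTurnResponseEventType.step_start]`, instead of `.value`: a `StrEnum` member is a `str` and compares equal to its string value. A small standalone illustration:

from enum import Enum, StrEnum


class OldStyle(Enum):
    step_start = "step_start"


class NewStyle(StrEnum):
    step_start = "step_start"


assert OldStyle.step_start != "step_start"   # a plain Enum member is not a str
assert NewStyle.step_start == "step_start"   # a StrEnum member compares equal to its value
assert isinstance(NewStyle.step_start, str)  # and serializes as a plain string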
@@ -274,15 +260,15 @@ class AgentTurnResponseEventType(Enum):
 
 @json_schema_type
 class AgentTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
+    event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
     step_type: StepType
     step_id: str
-    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
 
 
 @json_schema_type
 class AgentTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
+    event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
     step_type: StepType
     step_id: str
     step_details: Step
@@ -292,7 +278,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
 class AgentTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
-    event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
+    event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
     step_type: StepType
     step_id: str
 
@@ -301,33 +287,29 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
 
 @json_schema_type
 class AgentTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
+    event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
     turn_id: str
 
 
 @json_schema_type
 class AgentTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
+    event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
     turn: Turn
 
 
 @json_schema_type
 class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
-        AgentTurnResponseEventType.turn_awaiting_input.value
-    )
+    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
     turn: Turn
 
 
 AgentTurnResponseEventPayload = Annotated[
-    Union[
-        AgentTurnResponseStepStartPayload,
-        AgentTurnResponseStepProgressPayload,
-        AgentTurnResponseStepCompletePayload,
-        AgentTurnResponseTurnStartPayload,
-        AgentTurnResponseTurnCompletePayload,
-        AgentTurnResponseTurnAwaitingInputPayload,
-    ],
+    AgentTurnResponseStepStartPayload
+    | AgentTurnResponseStepProgressPayload
+    | AgentTurnResponseStepCompletePayload
+    | AgentTurnResponseTurnStartPayload
+    | AgentTurnResponseTurnCompletePayload
+    | AgentTurnResponseTurnAwaitingInputPayload,
     Field(discriminator="event_type"),
 ]
 register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
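Because every payload carries a distinct `event_type`, pydantic can dispatch on the discriminator instead of trying each union member in turn. A hedged sketch of validating a raw event dict, assuming the classes above are in scope and pydantic v2's `TypeAdapter` is used:

from pydantic import TypeAdapter

# Hypothetical wire payload for a turn_start event.
raw = {"event_type": "turn_start", "turn_id": "turn-123"}

adapter = TypeAdapter(AgentTurnResponseEventPayload)
payload = adapter.validate_python(raw)
assert isinstance(payload, AgentTurnResponseTurnStartPayload)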
@@ -356,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     # TODO: figure out how we can simplify this and make why
     # ToolResponseMessage needs to be here (it is function call
     # execution from outside the system)
-    messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
+    messages: list[UserMessage | ToolResponseMessage]
 
-    documents: Optional[List[Document]] = None
-    toolgroups: Optional[List[AgentToolGroup]] = None
+    documents: list[Document] | None = None
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
 
-    stream: Optional[bool] = False
-    tool_config: Optional[ToolConfig] = None
+    stream: bool | None = False
+    tool_config: ToolConfig | None = None
 
 
 @json_schema_type
|
@ -375,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
|
||||||
agent_id: str
|
agent_id: str
|
||||||
session_id: str
|
session_id: str
|
||||||
turn_id: str
|
turn_id: str
|
||||||
tool_responses: List[ToolResponse]
|
tool_responses: list[ToolResponse]
|
||||||
stream: Optional[bool] = False
|
stream: bool | None = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -422,17 +399,12 @@ class Agents(Protocol):
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
messages: List[
|
messages: list[UserMessage | ToolResponseMessage],
|
||||||
Union[
|
stream: bool | None = False,
|
||||||
UserMessage,
|
documents: list[Document] | None = None,
|
||||||
ToolResponseMessage,
|
toolgroups: list[AgentToolGroup] | None = None,
|
||||||
]
|
tool_config: ToolConfig | None = None,
|
||||||
],
|
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||||
stream: Optional[bool] = False,
|
|
||||||
documents: Optional[List[Document]] = None,
|
|
||||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
|
||||||
tool_config: Optional[ToolConfig] = None,
|
|
||||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
|
||||||
"""Create a new turn for an agent.
|
"""Create a new turn for an agent.
|
||||||
|
|
||||||
:param agent_id: The ID of the agent to create the turn for.
|
:param agent_id: The ID of the agent to create the turn for.
|
||||||
|
@@ -443,8 +415,9 @@ class Agents(Protocol):
         :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
         :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
         :returns: If stream=False, returns a Turn object.
-                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
+                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
         """
+        ...
 
     @webmethod(
         route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
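A hedged client-side sketch of the two return shapes, assuming an `Agents` implementation is in hand and that `UserMessage` lives in `llama_stack.apis.inference` (IDs and prompts are illustrative):

from llama_stack.apis.inference import UserMessage  # assumed import path


async def run_turn(agents: Agents, agent_id: str, session_id: str) -> None:
    # stream=False: a single Turn is returned once all steps have completed.
    turn = await agents.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[UserMessage(content="What is the capital of France?")],
        stream=False,
    )
    print(turn.output_message)

    # stream=True: an async iterator of AgentTurnResponseStreamChunk events.
    chunks = await agents.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[UserMessage(content="And of Germany?")],
        stream=True,
    )
    async for chunk in chunks:
        print(chunk.event.payload.event_type)  # assumes the chunk.event.payload shape modeled above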
@@ -456,9 +429,9 @@ class Agents(Protocol):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponse],
-        stream: Optional[bool] = False,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        tool_responses: list[ToolResponse],
+        stream: bool | None = False,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Resume an agent turn with executed tool call responses.
 
         When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
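A sketch of the resume flow for client-side tool calls; the `ToolResponse` field names (`call_id`, `tool_name`, `content`) are assumed from the tools API, and the IDs and tool output are illustrative:

async def submit_tool_results(agents: Agents, turn: Turn) -> Turn:
    # The turn reached awaiting_input; run the client-side tool ourselves,
    # then hand the output back so the server can finish the turn.
    responses = [
        ToolResponse(call_id="call-1", tool_name="get_weather", content='{"temp_c": 21}'),
    ]
    return await agents.resume_agent_turn(
        agent_id="agent-123",
        session_id=turn.session_id,
        turn_id=turn.turn_id,
        tool_responses=responses,
        stream=False,
    )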
@@ -531,13 +504,14 @@ class Agents(Protocol):
         self,
         session_id: str,
         agent_id: str,
-        turn_ids: Optional[List[str]] = None,
+        turn_ids: list[str] | None = None,
     ) -> Session:
         """Retrieve an agent session by its ID.
 
         :param session_id: The ID of the session to get.
         :param agent_id: The ID of the agent to get the session for.
         :param turn_ids: (Optional) List of turn IDs to filter the session by.
+        :returns: A Session.
         """
         ...
 
@@ -547,7 +521,7 @@ class Agents(Protocol):
         session_id: str,
         agent_id: str,
     ) -> None:
-        """Delete an agent session by its ID.
+        """Delete an agent session by its ID and its associated turns.
 
         :param session_id: The ID of the session to delete.
         :param agent_id: The ID of the agent to delete the session for.
@@ -559,17 +533,19 @@ class Agents(Protocol):
         self,
         agent_id: str,
     ) -> None:
-        """Delete an agent by its ID.
+        """Delete an agent by its ID and its associated sessions and turns.
 
         :param agent_id: The ID of the agent to delete.
         """
         ...
 
     @webmethod(route="/agents", method="GET")
-    async def list_agents(self) -> ListAgentsResponse:
+    async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
         """List all agents.
 
-        :returns: A ListAgentsResponse.
+        :param start_index: The index to start the pagination from.
+        :param limit: The number of agents to return.
+        :returns: A PaginatedResponse.
         """
         ...
 
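`list_agents` (and `list_agent_sessions` below) now return the generic `PaginatedResponse` of raw dicts plus a `has_more` flag, rather than dedicated list-response models. A hedged sketch of walking the pages:

from typing import Any


async def all_agents(agents: Agents, page_size: int = 50) -> list[dict[str, Any]]:
    items: list[dict[str, Any]] = []
    start = 0
    while True:
        page = await agents.list_agents(start_index=start, limit=page_size)
        items.extend(page.data)
        if not page.has_more:
            return items
        start += len(page.data)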
@@ -586,10 +562,94 @@ class Agents(Protocol):
     async def list_agent_sessions(
         self,
         agent_id: str,
-    ) -> ListAgentSessionsResponse:
+        start_index: int | None = None,
+        limit: int | None = None,
+    ) -> PaginatedResponse:
         """List all session(s) of a given agent.
 
         :param agent_id: The ID of the agent to list sessions for.
-        :returns: A ListAgentSessionsResponse.
+        :param start_index: The index to start the pagination from.
+        :param limit: The number of sessions to return.
+        :returns: A PaginatedResponse.
         """
         ...
 
+    # We situate the OpenAI Responses API in the Agents API just like we did things
+    # for Inference. The Responses API, in its intent, serves the same purpose as
+    # the Agents API above -- it is essentially a lightweight "agentic loop" with
+    # integrated tool calling.
+    #
+    # Both of these APIs are inherently stateful.
+
+    @webmethod(route="/openai/v1/responses/{response_id}", method="GET")
+    async def get_openai_response(
+        self,
+        response_id: str,
+    ) -> OpenAIResponseObject:
+        """Retrieve an OpenAI response by its ID.
+
+        :param response_id: The ID of the OpenAI response to retrieve.
+        :returns: An OpenAIResponseObject.
+        """
+        ...
+
+    @webmethod(route="/openai/v1/responses", method="POST")
+    async def create_openai_response(
+        self,
+        input: str | list[OpenAIResponseInput],
+        model: str,
+        instructions: str | None = None,
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
+    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
+        """Create a new OpenAI response.
+
+        :param input: Input message(s) to create the response.
+        :param model: The underlying LLM used for completions.
+        :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
+        :returns: An OpenAIResponseObject.
+        """
+        ...
+
+    @webmethod(route="/openai/v1/responses", method="GET")
+    async def list_openai_responses(
+        self,
+        after: str | None = None,
+        limit: int | None = 50,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIResponseObject:
+        """List all OpenAI responses.
+
+        :param after: The ID of the last response to return.
+        :param limit: The number of responses to return.
+        :param model: The model to filter responses by.
+        :param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
+        :returns: A ListOpenAIResponseObject.
+        """
+        ...
+
+    @webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET")
+    async def list_openai_response_input_items(
+        self,
+        response_id: str,
+        after: str | None = None,
+        before: str | None = None,
+        include: list[str] | None = None,
+        limit: int | None = 20,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIResponseInputItem:
+        """List input items for a given OpenAI response.
+
+        :param response_id: The ID of the response to retrieve input items for.
+        :param after: An item ID to list items after, used for pagination.
+        :param before: An item ID to list items before, used for pagination.
+        :param include: Additional fields to include in the response.
+        :param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
+        :param order: The order to return the input items in. Default is desc.
+        :returns: An ListOpenAIResponseInputItem.
+        """
+        ...
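A hedged sketch of the new Responses surface: create a response, fork a follow-up from it via `previous_response_id`, and inspect its input items (the model id and prompts are illustrative):

async def demo_responses(agents: Agents) -> None:
    first = await agents.create_openai_response(
        input="Summarize the Agents API in one sentence.",
        model="llama-3.3-70b-instruct",  # illustrative model id
        store=True,
    )
    print(first.status, first.output)

    # Continue from the stored response; the server threads the conversation state.
    follow_up = await agents.create_openai_response(
        input="Now compare it to the Turn-based API above.",
        model="llama-3.3-70b-instruct",
        previous_response_id=first.id,
    )

    items = await agents.list_openai_response_input_items(response_id=follow_up.id)
    print(len(items.data))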
llama_stack/apis/agents/openai_responses.py (new file, 279 lines)
@@ -0,0 +1,279 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||||
|
|
||||||
|
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
|
||||||
|
# take their YAML and generate this file automatically. Their YAML is available.
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseError(BaseModel):
|
||||||
|
code: str
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseInputMessageContentText(BaseModel):
|
||||||
|
text: str
|
||||||
|
type: Literal["input_text"] = "input_text"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseInputMessageContentImage(BaseModel):
|
||||||
|
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
|
||||||
|
type: Literal["input_image"] = "input_image"
|
||||||
|
# TODO: handle file_id
|
||||||
|
image_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: handle file content types
|
||||||
|
OpenAIResponseInputMessageContent = Annotated[
|
||||||
|
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseOutputMessageContentOutputText(BaseModel):
|
||||||
|
text: str
|
||||||
|
type: Literal["output_text"] = "output_text"
|
||||||
|
|
||||||
|
|
||||||
|
OpenAIResponseOutputMessageContent = Annotated[
|
||||||
|
OpenAIResponseOutputMessageContentOutputText,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseMessage(BaseModel):
|
||||||
|
"""
|
||||||
|
Corresponds to the various Message types in the Responses API.
|
||||||
|
They are all under one type because the Responses API gives them all
|
||||||
|
the same "type" value, and there is no way to tell them apart in certain
|
||||||
|
scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
|
||||||
|
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
|
||||||
|
type: Literal["message"] = "message"
|
||||||
|
|
||||||
|
# The fields below are not used in all scenarios, but are required in others.
|
||||||
|
id: str | None = None
|
||||||
|
status: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
||||||
|
id: str
|
||||||
|
status: str
|
||||||
|
type: Literal["web_search_call"] = "web_search_call"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
||||||
|
call_id: str
|
||||||
|
name: str
|
||||||
|
arguments: str
|
||||||
|
type: Literal["function_call"] = "function_call"
|
||||||
|
id: str | None = None
|
||||||
|
status: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseOutputMessageMCPCall(BaseModel):
|
||||||
|
id: str
|
||||||
|
type: Literal["mcp_call"] = "mcp_call"
|
||||||
|
arguments: str
|
||||||
|
name: str
|
||||||
|
server_label: str
|
||||||
|
error: str | None = None
|
||||||
|
output: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MCPListToolsTool(BaseModel):
|
||||||
|
input_schema: dict[str, Any]
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
|
||||||
|
id: str
|
||||||
|
type: Literal["mcp_list_tools"] = "mcp_list_tools"
|
||||||
|
server_label: str
|
||||||
|
tools: list[MCPListToolsTool]
|
||||||
|
|
||||||
|
|
||||||
|
OpenAIResponseOutput = Annotated[
|
||||||
|
OpenAIResponseMessage
|
||||||
|
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||||
|
| OpenAIResponseOutputMessageFunctionToolCall
|
||||||
|
| OpenAIResponseOutputMessageMCPCall
|
||||||
|
| OpenAIResponseOutputMessageMCPListTools,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseObject(BaseModel):
|
||||||
|
created_at: int
|
||||||
|
error: OpenAIResponseError | None = None
|
||||||
|
id: str
|
||||||
|
model: str
|
||||||
|
object: Literal["response"] = "response"
|
||||||
|
output: list[OpenAIResponseOutput]
|
||||||
|
parallel_tool_calls: bool = False
|
||||||
|
previous_response_id: str | None = None
|
||||||
|
status: str
|
||||||
|
temperature: float | None = None
|
||||||
|
top_p: float | None = None
|
||||||
|
truncation: str | None = None
|
||||||
|
user: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
||||||
|
response: OpenAIResponseObject
|
||||||
|
type: Literal["response.created"] = "response.created"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
||||||
|
content_index: int
|
||||||
|
delta: str
|
||||||
|
item_id: str
|
||||||
|
output_index: int
|
||||||
|
sequence_number: int
|
||||||
|
type: Literal["response.output_text.delta"] = "response.output_text.delta"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||||
|
response: OpenAIResponseObject
|
||||||
|
type: Literal["response.completed"] = "response.completed"
|
||||||
|
|
||||||
|
|
||||||
|
OpenAIResponseObjectStream = Annotated[
|
||||||
|
OpenAIResponseObjectStreamResponseCreated
|
||||||
|
| OpenAIResponseObjectStreamResponseOutputTextDelta
|
||||||
|
| OpenAIResponseObjectStreamResponseCompleted,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
|
||||||
|
"""
|
||||||
|
This represents the output of a function call that gets passed back to the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
call_id: str
|
||||||
|
output: str
|
||||||
|
type: Literal["function_call_output"] = "function_call_output"
|
||||||
|
id: str | None = None
|
||||||
|
status: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
OpenAIResponseInput = Annotated[
|
||||||
|
# Responses API allows output messages to be passed in as input
|
||||||
|
OpenAIResponseOutputMessageWebSearchToolCall
|
||||||
|
| OpenAIResponseOutputMessageFunctionToolCall
|
||||||
|
| OpenAIResponseInputFunctionToolCallOutput
|
||||||
|
|
|
||||||
|
# Fallback to the generic message type as a last resort
|
||||||
|
OpenAIResponseMessage,
|
||||||
|
Field(union_mode="left_to_right"),
|
||||||
|
]
|
||||||
|
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseInputToolWebSearch(BaseModel):
|
||||||
|
type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
|
||||||
|
# TODO: actually use search_context_size somewhere...
|
||||||
|
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
|
||||||
|
# TODO: add user_location
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseInputToolFunction(BaseModel):
|
||||||
|
type: Literal["function"] = "function"
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
parameters: dict[str, Any] | None
|
||||||
|
strict: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileSearchRankingOptions(BaseModel):
|
||||||
|
ranker: str | None = None
|
||||||
|
score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||||
|
type: Literal["file_search"] = "file_search"
|
||||||
|
vector_store_id: list[str]
|
||||||
|
ranking_options: FileSearchRankingOptions | None = None
|
||||||
|
# TODO: add filters
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalFilter(BaseModel):
|
||||||
|
always: list[str] | None = None
|
||||||
|
never: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AllowedToolsFilter(BaseModel):
|
||||||
|
tool_names: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseInputToolMCP(BaseModel):
|
||||||
|
type: Literal["mcp"] = "mcp"
|
||||||
|
server_label: str
|
||||||
|
server_url: str
|
||||||
|
headers: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
|
||||||
|
allowed_tools: list[str] | AllowedToolsFilter | None = None
|
||||||
|
|
||||||
|
|
||||||
|
OpenAIResponseInputTool = Annotated[
|
||||||
|
OpenAIResponseInputToolWebSearch
|
||||||
|
| OpenAIResponseInputToolFileSearch
|
||||||
|
| OpenAIResponseInputToolFunction
|
||||||
|
| OpenAIResponseInputToolMCP,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||||
|
|
||||||
|
|
||||||
|
class ListOpenAIResponseInputItem(BaseModel):
|
||||||
|
data: list[OpenAIResponseInput]
|
||||||
|
object: Literal["list"] = "list"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
||||||
|
input: list[OpenAIResponseInput]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ListOpenAIResponseObject(BaseModel):
|
||||||
|
data: list[OpenAIResponseObjectWithInput]
|
||||||
|
has_more: bool
|
||||||
|
first_id: str
|
||||||
|
last_id: str
|
||||||
|
object: Literal["list"] = "list"
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Optional, Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_stack.apis.common.job_types import Job
|
from llama_stack.apis.common.job_types import Job
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
@ -34,22 +34,45 @@ class BatchInference(Protocol):
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
content_batch: List[InterleavedContent],
|
content_batch: list[InterleavedContent],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
) -> Job: ...
|
) -> Job:
|
||||||
|
"""Generate completions for a batch of content.
|
||||||
|
|
||||||
|
:param model: The model to use for the completion.
|
||||||
|
:param content_batch: The content to complete.
|
||||||
|
:param sampling_params: The sampling parameters to use for the completion.
|
||||||
|
:param response_format: The response format to use for the completion.
|
||||||
|
:param logprobs: The logprobs to use for the completion.
|
||||||
|
:returns: A job for the completion.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages_batch: List[List[Message]],
|
messages_batch: list[list[Message]],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
# zero-shot tool definitions as input to the model
|
# zero-shot tool definitions as input to the model
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
tool_prompt_format: ToolPromptFormat | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
) -> Job: ...
|
) -> Job:
|
||||||
|
"""Generate chat completions for a batch of messages.
|
||||||
|
|
||||||
|
:param model: The model to use for the chat completion.
|
||||||
|
:param messages_batch: The messages to complete.
|
||||||
|
:param sampling_params: The sampling parameters to use for the completion.
|
||||||
|
:param tools: The tools to use for the chat completion.
|
||||||
|
:param tool_choice: The tool choice to use for the chat completion.
|
||||||
|
:param tool_prompt_format: The tool prompt format to use for the chat completion.
|
||||||
|
:param response_format: The response format to use for the chat completion.
|
||||||
|
:param logprobs: The logprobs to use for the chat completion.
|
||||||
|
:returns: A job for the chat completion.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
@ -13,8 +13,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
class CommonBenchmarkFields(BaseModel):
|
class CommonBenchmarkFields(BaseModel):
|
||||||
dataset_id: str
|
dataset_id: str
|
||||||
scoring_functions: List[str]
|
scoring_functions: list[str]
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Metadata for this evaluation task",
|
description="Metadata for this evaluation task",
|
||||||
)
|
)
|
||||||
|
@ -22,45 +22,66 @@ class CommonBenchmarkFields(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Benchmark(CommonBenchmarkFields, Resource):
|
class Benchmark(CommonBenchmarkFields, Resource):
|
||||||
type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value
|
type: Literal[ResourceType.benchmark] = ResourceType.benchmark
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def benchmark_id(self) -> str:
|
def benchmark_id(self) -> str:
|
||||||
return self.identifier
|
return self.identifier
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_benchmark_id(self) -> str:
|
def provider_benchmark_id(self) -> str | None:
|
||||||
return self.provider_resource_id
|
return self.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkInput(CommonBenchmarkFields, BaseModel):
|
class BenchmarkInput(CommonBenchmarkFields, BaseModel):
|
||||||
benchmark_id: str
|
benchmark_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: str | None = None
|
||||||
provider_benchmark_id: Optional[str] = None
|
provider_benchmark_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListBenchmarksResponse(BaseModel):
|
class ListBenchmarksResponse(BaseModel):
|
||||||
data: List[Benchmark]
|
data: list[Benchmark]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class Benchmarks(Protocol):
|
class Benchmarks(Protocol):
|
||||||
@webmethod(route="/eval/benchmarks", method="GET")
|
@webmethod(route="/eval/benchmarks", method="GET")
|
||||||
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
|
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||||
|
"""List all benchmarks.
|
||||||
|
|
||||||
|
:returns: A ListBenchmarksResponse.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
|
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
|
||||||
async def get_benchmark(
|
async def get_benchmark(
|
||||||
self,
|
self,
|
||||||
benchmark_id: str,
|
benchmark_id: str,
|
||||||
) -> Benchmark: ...
|
) -> Benchmark:
|
||||||
|
"""Get a benchmark by its ID.
|
||||||
|
|
||||||
|
:param benchmark_id: The ID of the benchmark to get.
|
||||||
|
:returns: A Benchmark.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/eval/benchmarks", method="POST")
|
@webmethod(route="/eval/benchmarks", method="POST")
|
||||||
async def register_benchmark(
|
async def register_benchmark(
|
||||||
self,
|
self,
|
||||||
benchmark_id: str,
|
benchmark_id: str,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
scoring_functions: List[str],
|
scoring_functions: list[str],
|
||||||
provider_benchmark_id: Optional[str] = None,
|
provider_benchmark_id: str | None = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: str | None = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
"""Register a benchmark.
|
||||||
|
|
||||||
|
:param benchmark_id: The ID of the benchmark to register.
|
||||||
|
:param dataset_id: The ID of the dataset to use for the benchmark.
|
||||||
|
:param scoring_functions: The scoring functions to use for the benchmark.
|
||||||
|
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
|
||||||
|
:param provider_id: The ID of the provider to use for the benchmark.
|
||||||
|
:param metadata: The metadata to use for the benchmark.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, List, Literal, Optional, Union
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
@ -26,9 +26,9 @@ class _URLOrData(BaseModel):
|
||||||
:param data: base64 encoded image data as string
|
:param data: base64 encoded image data as string
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url: Optional[URL] = None
|
url: URL | None = None
|
||||||
# data is a base64 encoded string, hint with contentEncoding=base64
|
# data is a base64 encoded string, hint with contentEncoding=base64
|
||||||
data: Optional[str] = Field(contentEncoding="base64", default=None)
|
data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -64,13 +64,13 @@ class TextContentItem(BaseModel):
|
||||||
|
|
||||||
# other modalities can be added here
|
# other modalities can be added here
|
||||||
InterleavedContentItem = Annotated[
|
InterleavedContentItem = Annotated[
|
||||||
Union[ImageContentItem, TextContentItem],
|
ImageContentItem | TextContentItem,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(InterleavedContentItem, name="InterleavedContentItem")
|
register_schema(InterleavedContentItem, name="InterleavedContentItem")
|
||||||
|
|
||||||
# accept a single "str" as a special case since it is common
|
# accept a single "str" as a special case since it is common
|
||||||
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
|
InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
|
||||||
register_schema(InterleavedContent, name="InterleavedContent")
|
register_schema(InterleavedContent, name="InterleavedContent")
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,13 +100,13 @@ class ToolCallDelta(BaseModel):
|
||||||
# you either send an in-progress tool call so the client can stream a long
|
# you either send an in-progress tool call so the client can stream a long
|
||||||
# code generation or you send the final parsed tool call at the end of the
|
# code generation or you send the final parsed tool call at the end of the
|
||||||
# stream
|
# stream
|
||||||
tool_call: Union[str, ToolCall]
|
tool_call: str | ToolCall
|
||||||
parse_status: ToolCallParseStatus
|
parse_status: ToolCallParseStatus
|
||||||
|
|
||||||
|
|
||||||
# streaming completions send a stream of ContentDeltas
|
# streaming completions send a stream of ContentDeltas
|
||||||
ContentDelta = Annotated[
|
ContentDelta = Annotated[
|
||||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
TextDelta | ImageDelta | ToolCallDelta,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(ContentDelta, name="ContentDelta")
|
register_schema(ContentDelta, name="ContentDelta")
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RestAPIMethod(Enum):
|
|
||||||
GET = "GET"
|
|
||||||
POST = "POST"
|
|
||||||
PUT = "PUT"
|
|
||||||
DELETE = "DELETE"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RestAPIExecutionConfig(BaseModel):
|
|
||||||
url: URL
|
|
||||||
method: RestAPIMethod
|
|
||||||
params: Optional[Dict[str, Any]] = None
|
|
||||||
headers: Optional[Dict[str, Any]] = None
|
|
||||||
body: Optional[Dict[str, Any]] = None
|
|
|
@@ -4,13 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List
+from enum import Enum
+from typing import Any
 
 from pydantic import BaseModel
 
 from llama_stack.schema_utils import json_schema_type
 
 
+class Order(Enum):
+    asc = "asc"
+    desc = "desc"
+
+
 @json_schema_type
 class PaginatedResponse(BaseModel):
     """A generic paginated response that follows a simple format.
@@ -19,5 +25,5 @@ class PaginatedResponse(BaseModel):
     :param has_more: Whether there are more items available after this set
     """
 
-    data: List[Dict[str, Any]]
+    data: list[dict[str, Any]]
     has_more: bool
 
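A minimal sketch of how a provider might slice rows into this envelope (the helper name is hypothetical):

from typing import Any


def paginate(rows: list[dict[str, Any]], start_index: int | None, limit: int | None) -> PaginatedResponse:
    start = start_index or 0
    end = len(rows) if limit is None or limit < 0 else start + limit
    page = rows[start:end]
    return PaginatedResponse(data=page, has_more=end < len(rows))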
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -27,4 +26,4 @@ class Checkpoint(BaseModel):
|
||||||
epoch: int
|
epoch: int
|
||||||
post_training_job_id: str
|
post_training_job_id: str
|
||||||
path: str
|
path: str
|
||||||
training_metrics: Optional[PostTrainingMetric] = None
|
training_metrics: PostTrainingMetric | None = None
|
||||||
|
|
|
@ -4,10 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Literal, Union
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||||
|
|
||||||
|
@ -73,18 +72,16 @@ class DialogType(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
ParamType = Annotated[
|
ParamType = Annotated[
|
||||||
Union[
|
StringType
|
||||||
StringType,
|
| NumberType
|
||||||
NumberType,
|
| BooleanType
|
||||||
BooleanType,
|
| ArrayType
|
||||||
ArrayType,
|
| ObjectType
|
||||||
ObjectType,
|
| JsonType
|
||||||
JsonType,
|
| UnionType
|
||||||
UnionType,
|
| ChatCompletionInputType
|
||||||
ChatCompletionInputType,
|
| CompletionInputType
|
||||||
CompletionInputType,
|
| AgentTurnInputType,
|
||||||
AgentTurnInputType,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(ParamType, name="ParamType")
|
register_schema(ParamType, name="ParamType")
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.apis.datasets import Dataset
|
from llama_stack.apis.datasets import Dataset
|
||||||
|
@@ -24,8 +24,8 @@ class DatasetIO(Protocol):
     async def iterrows(
         self,
         dataset_id: str,
-        start_index: Optional[int] = None,
-        limit: Optional[int] = None,
+        start_index: int | None = None,
+        limit: int | None = None,
     ) -> PaginatedResponse:
         """Get a paginated list of rows from a dataset.
 
@@ -34,14 +34,21 @@ class DatasetIO(Protocol):
         - limit: Number of items to return. If None or -1, returns all items.
 
         The response includes:
-        - data: List of items for the current page
-        - has_more: Whether there are more items available after this set
+        - data: List of items for the current page.
+        - has_more: Whether there are more items available after this set.
 
         :param dataset_id: The ID of the dataset to get the rows from.
         :param start_index: Index into dataset for the first row to get. Get all rows if None.
         :param limit: The number of rows to get.
+        :returns: A PaginatedResponse.
         """
         ...
 
     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
+        """Append rows to a dataset.
+
+        :param dataset_id: The ID of the dataset to append the rows to.
+        :param rows: The rows to append to the dataset.
+        """
+        ...
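A hedged sketch of consuming `iterrows` with the pagination described above (the dataset id is illustrative):

from typing import Any


async def first_hundred_rows(datasetio: DatasetIO) -> list[dict[str, Any]]:
    rows: list[dict[str, Any]] = []
    start = 0
    while len(rows) < 100:
        page = await datasetio.iterrows("my-eval-dataset", start_index=start, limit=25)
        rows.extend(page.data)
        if not page.has_more:
            break
        start += len(page.data)
    return rows[:100]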
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
|
from typing import Annotated, Any, Literal, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
@ -81,11 +81,11 @@ class RowsDataSource(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal["rows"] = "rows"
|
type: Literal["rows"] = "rows"
|
||||||
rows: List[Dict[str, Any]]
|
rows: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
DataSource = Annotated[
|
DataSource = Annotated[
|
||||||
Union[URIDataSource, RowsDataSource],
|
URIDataSource | RowsDataSource,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(DataSource, name="DataSource")
|
register_schema(DataSource, name="DataSource")
|
||||||
|
@ -98,7 +98,7 @@ class CommonDatasetFields(BaseModel):
|
||||||
|
|
||||||
purpose: DatasetPurpose
|
purpose: DatasetPurpose
|
||||||
source: DataSource
|
source: DataSource
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional metadata for this dataset",
|
description="Any additional metadata for this dataset",
|
||||||
)
|
)
|
||||||
|
@ -106,14 +106,14 @@ class CommonDatasetFields(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Dataset(CommonDatasetFields, Resource):
|
class Dataset(CommonDatasetFields, Resource):
|
||||||
type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
|
type: Literal[ResourceType.dataset] = ResourceType.dataset
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dataset_id(self) -> str:
|
def dataset_id(self) -> str:
|
||||||
return self.identifier
|
return self.identifier
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_dataset_id(self) -> str:
|
def provider_dataset_id(self) -> str | None:
|
||||||
return self.provider_resource_id
|
return self.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,7 +122,7 @@ class DatasetInput(CommonDatasetFields, BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ListDatasetsResponse(BaseModel):
|
class ListDatasetsResponse(BaseModel):
|
||||||
data: List[Dataset]
|
data: list[Dataset]
|
||||||
|
|
||||||
|
|
||||||
class Datasets(Protocol):
|
class Datasets(Protocol):
|
||||||
|
@ -131,13 +131,14 @@ class Datasets(Protocol):
|
||||||
self,
|
self,
|
||||||
purpose: DatasetPurpose,
|
purpose: DatasetPurpose,
|
||||||
source: DataSource,
|
source: DataSource,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
dataset_id: Optional[str] = None,
|
dataset_id: str | None = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""
|
"""
|
||||||
Register a new dataset.
|
Register a new dataset.
|
||||||
|
|
||||||
:param purpose: The purpose of the dataset. One of
|
:param purpose: The purpose of the dataset.
|
||||||
|
One of:
|
||||||
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
|
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
|
||||||
{
|
{
|
||||||
"messages": [
|
"messages": [
|
||||||
|
@ -188,8 +189,9 @@ class Datasets(Protocol):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
:param metadata: The metadata for the dataset.
|
:param metadata: The metadata for the dataset.
|
||||||
- E.g. {"description": "My dataset"}
|
- E.g. {"description": "My dataset"}.
|
||||||
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
|
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
|
||||||
|
:returns: A Dataset.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -197,13 +199,29 @@ class Datasets(Protocol):
|
||||||
async def get_dataset(
|
async def get_dataset(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
) -> Dataset: ...
|
) -> Dataset:
|
||||||
|
"""Get a dataset by its ID.
|
||||||
|
|
||||||
|
:param dataset_id: The ID of the dataset to get.
|
||||||
|
:returns: A Dataset.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/datasets", method="GET")
|
@webmethod(route="/datasets", method="GET")
|
||||||
async def list_datasets(self) -> ListDatasetsResponse: ...
|
async def list_datasets(self) -> ListDatasetsResponse:
|
||||||
|
"""List all datasets.
|
||||||
|
|
||||||
|
:returns: A ListDatasetsResponse.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
|
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
|
||||||
async def unregister_dataset(
|
async def unregister_dataset(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
"""Unregister a dataset by its ID.
|
||||||
|
|
||||||
|
:param dataset_id: The ID of the dataset to unregister.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -54,4 +53,4 @@ class Error(BaseModel):
|
||||||
status: int
|
status: int
|
||||||
title: str
|
title: str
|
||||||
detail: str
|
detail: str
|
||||||
instance: Optional[str] = None
|
instance: str | None = None
|
||||||
|
|
|
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job
@@ -29,7 +28,7 @@ class ModelCandidate(BaseModel):
     type: Literal["model"] = "model"
     model: str
     sampling_params: SamplingParams
-    system_message: Optional[SystemMessage] = None
+    system_message: SystemMessage | None = None


 @json_schema_type
@@ -43,7 +42,7 @@ class AgentCandidate(BaseModel):
     config: AgentConfig


-EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
+EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
 register_schema(EvalCandidate, name="EvalCandidate")


@@ -57,11 +56,11 @@ class BenchmarkConfig(BaseModel):
     """

     eval_candidate: EvalCandidate
-    scoring_params: Dict[str, ScoringFnParams] = Field(
+    scoring_params: dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
-    num_examples: Optional[int] = Field(
+    num_examples: int | None = Field(
         description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
         default=None,
     )
@@ -76,9 +75,9 @@ class EvaluateResponse(BaseModel):
     :param scores: The scores from the evaluation.
     """

-    generations: List[Dict[str, Any]]
+    generations: list[dict[str, Any]]
     # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
+    scores: dict[str, ScoringResult]


 class Eval(Protocol):
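A hedged sketch of building an eval request with the updated annotations; the import paths assume the usual `llama_stack.apis` layout and the model id is a placeholder:

```python
from llama_stack.apis.eval import BenchmarkConfig, ModelCandidate
from llama_stack.apis.inference import SamplingParams

config = BenchmarkConfig(
    eval_candidate=ModelCandidate(
        model="meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
        sampling_params=SamplingParams(),          # greedy strategy by default
    ),
    scoring_params={},  # dict[str, ScoringFnParams]; empty means per-function defaults
    num_examples=5,     # int | None - evaluate a small subset while testing
)
print(config.model_dump_json(indent=2))
```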
@@ -94,15 +93,16 @@ class Eval(Protocol):

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param benchmark_config: The configuration for the benchmark.
-        :return: The job that was created to run the evaluation.
+        :returns: The job that was created to run the evaluation.
         """
+        ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
     async def evaluate_rows(
         self,
         benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         """Evaluate a list of rows on a benchmark.
@@ -111,8 +111,9 @@ class Eval(Protocol):
         :param input_rows: The rows to evaluate.
         :param scoring_functions: The scoring functions to use for the evaluation.
         :param benchmark_config: The configuration for the benchmark.
-        :return: EvaluateResponse object containing generations and scores
+        :returns: EvaluateResponse object containing generations and scores.
         """
+        ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
     async def job_status(self, benchmark_id: str, job_id: str) -> Job:
@@ -120,7 +121,7 @@ class Eval(Protocol):

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param job_id: The ID of the job to get the status of.
-        :return: The status of the evaluationjob.
+        :returns: The status of the evaluation job.
         """
         ...

@@ -139,5 +140,6 @@ class Eval(Protocol):

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param job_id: The ID of the job to get the result of.
-        :return: The result of the job.
+        :returns: The result of the job.
         """
+        ...
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -42,7 +42,7 @@ class ListBucketResponse(BaseModel):
     :param data: List of FileResponse entries
     """

-    data: List[BucketResponse]
+    data: list[BucketResponse]


 @json_schema_type
@@ -74,7 +74,7 @@ class ListFileResponse(BaseModel):
     :param data: List of FileResponse entries
     """

-    data: List[FileResponse]
+    data: list[FileResponse]


 @runtime_checkable
@@ -91,10 +91,11 @@ class Files(Protocol):
         """
         Create a new upload session for a file identified by a bucket and key.

-        :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
-        :param mime_type: MIME type of the file
-        :param size: File size in bytes
+        :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
+        :param mime_type: MIME type of the file.
+        :param size: File size in bytes.
+        :returns: A FileUploadResponse.
         """
         ...

@@ -102,12 +103,13 @@ class Files(Protocol):
     async def upload_content_to_session(
         self,
         upload_id: str,
-    ) -> Optional[FileResponse]:
+    ) -> FileResponse | None:
         """
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.

-        :param upload_id: ID of the upload session
+        :param upload_id: ID of the upload session.
+        :returns: A FileResponse or None if the upload is not complete.
         """
         ...

@@ -117,9 +119,10 @@ class Files(Protocol):
         upload_id: str,
     ) -> FileUploadResponse:
         """
-        Returns information about an existsing upload session
+        Returns information about an existsing upload session.

-        :param upload_id: ID of the upload session
+        :param upload_id: ID of the upload session.
+        :returns: A FileUploadResponse.
         """
         ...

@@ -130,6 +133,9 @@ class Files(Protocol):
     ) -> ListBucketResponse:
         """
         List all buckets.
+
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :returns: A ListBucketResponse.
         """
         ...

@@ -141,7 +147,8 @@
         """
         List all files in a bucket.

-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :returns: A ListFileResponse.
         """
         ...

@@ -154,8 +161,9 @@
         """
         Get a file info identified by a bucket and key.

-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
+        :returns: A FileResponse.
         """
         ...

@@ -168,7 +176,7 @@
         """
         Delete a file identified by a bucket and key.

-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
         """
         ...
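A hedged sketch of the upload flow the docstrings above describe. Only `upload_content_to_session` is named in this diff; the import path and the `bucket`/`key` attribute names on `FileResponse` are assumptions for illustration:

```python
from llama_stack.apis.files import Files  # import path assumed


async def finish_upload(files_impl: Files, upload_id: str) -> None:
    # The session would have been created earlier via the session-creation
    # endpoint documented above; here we only push/poll its content.
    maybe_file = await files_impl.upload_content_to_session(upload_id)
    if maybe_file is None:
        # Per the new docstring, None means the upload is not complete yet.
        print(f"session {upload_id}: upload not complete")
    else:
        # bucket/key attributes assumed from the bucket-and-key wording above.
        print(f"stored {maybe_file.key} in bucket {maybe_file.bucket}")
```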
@@ -4,23 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import sys
+from collections.abc import AsyncIterator
 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    AsyncIterator,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field, field_validator
-from typing_extensions import Annotated, TypedDict
+from typing_extensions import TypedDict

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.responses import Order
 from llama_stack.apis.models import Model
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (
@@ -38,6 +37,16 @@ register_schema(ToolCall)
 register_schema(ToolParamDefinition)
 register_schema(ToolDefinition)

+# TODO: use enum.StrEnum when we drop support for python 3.10
+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+
+        pass
+

 @json_schema_type
 class GreedySamplingStrategy(BaseModel):
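A standalone sketch of why the backport works: subclassing `(str, Enum)` makes each member also an instance of `str`, which is what the `StrEnum`-based `ResponseFormatType` further down relies on for string comparison and value-based serialization.

```python
import sys
from enum import Enum

if sys.version_info >= (3, 11):
    from enum import StrEnum
else:

    class StrEnum(str, Enum):
        """Backport of StrEnum for Python 3.10 and below."""

        pass


class Color(StrEnum):
    red = "red"
    blue = "blue"


# Members are also str instances, so plain-string comparison works the same
# on 3.10 (backport) and 3.11+ (stdlib StrEnum).
assert isinstance(Color.red, str)
assert Color.red == "red"
print(Color.blue.value)  # "blue"
```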
@@ -47,8 +56,8 @@ class GreedySamplingStrategy(BaseModel):
 @json_schema_type
 class TopPSamplingStrategy(BaseModel):
     type: Literal["top_p"] = "top_p"
-    temperature: Optional[float] = Field(..., gt=0.0)
-    top_p: Optional[float] = 0.95
+    temperature: float | None = Field(..., gt=0.0)
+    top_p: float | None = 0.95


 @json_schema_type
@@ -58,7 +67,7 @@ class TopKSamplingStrategy(BaseModel):


 SamplingStrategy = Annotated[
-    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
     Field(discriminator="type"),
 ]
 register_schema(SamplingStrategy, name="SamplingStrategy")
@@ -79,9 +88,9 @@ class SamplingParams(BaseModel):

     strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)

-    max_tokens: Optional[int] = 0
-    repetition_penalty: Optional[float] = 1.0
-    stop: Optional[List[str]] = None
+    max_tokens: int | None = 0
+    repetition_penalty: float | None = 1.0
+    stop: list[str] | None = None


 class LogProbConfig(BaseModel):
@@ -90,7 +99,7 @@ class LogProbConfig(BaseModel):
     :param top_k: How many tokens (for each position) to return log probabilities for.
     """

-    top_k: Optional[int] = 0
+    top_k: int | None = 0


 class QuantizationType(Enum):
@@ -125,11 +134,11 @@ class Int4QuantizationConfig(BaseModel):
     """

     type: Literal["int4_mixed"] = "int4_mixed"
-    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
+    scheme: str | None = "int4_weight_int8_dynamic_activation"


 QuantizationConfig = Annotated[
-    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
+    Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
     Field(discriminator="type"),
 ]

@@ -145,7 +154,7 @@ class UserMessage(BaseModel):

     role: Literal["user"] = "user"
     content: InterleavedContent
-    context: Optional[InterleavedContent] = None
+    context: InterleavedContent | None = None


 @json_schema_type
@@ -190,16 +199,11 @@ class CompletionMessage(BaseModel):
     role: Literal["assistant"] = "assistant"
     content: InterleavedContent
     stop_reason: StopReason
-    tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
+    tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])


 Message = Annotated[
-    Union[
-        UserMessage,
-        SystemMessage,
-        ToolResponseMessage,
-        CompletionMessage,
-    ],
+    UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
     Field(discriminator="role"),
 ]
 register_schema(Message, name="Message")
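A hedged sketch of how the discriminated `Message` union behaves after the change: with `Field(discriminator="role")`, pydantic picks the concrete message class from the `"role"` key. The import path assumes `llama_stack.apis.inference` re-exports these names.

```python
from pydantic import TypeAdapter

from llama_stack.apis.inference import Message, UserMessage

adapter = TypeAdapter(Message)
msg = adapter.validate_python({"role": "user", "content": "Hello!"})
# The "user" discriminator routes validation to UserMessage.
assert isinstance(msg, UserMessage)
print(type(msg).__name__, msg.content)
```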
@@ -208,9 +212,9 @@ register_schema(Message, name="Message")
 @json_schema_type
 class ToolResponse(BaseModel):
     call_id: str
-    tool_name: Union[BuiltinTool, str]
+    tool_name: BuiltinTool | str
     content: InterleavedContent
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: dict[str, Any] | None = None

     @field_validator("tool_name", mode="before")
     @classmethod
@@ -243,7 +247,7 @@ class TokenLogProbs(BaseModel):
     :param logprobs_by_token: Dictionary mapping tokens to their log probabilities
     """

-    logprobs_by_token: Dict[str, float]
+    logprobs_by_token: dict[str, float]


 class ChatCompletionResponseEventType(Enum):
@@ -271,11 +275,11 @@ class ChatCompletionResponseEvent(BaseModel):

     event_type: ChatCompletionResponseEventType
     delta: ContentDelta
-    logprobs: Optional[List[TokenLogProbs]] = None
-    stop_reason: Optional[StopReason] = None
+    logprobs: list[TokenLogProbs] | None = None
+    stop_reason: StopReason | None = None


-class ResponseFormatType(Enum):
+class ResponseFormatType(StrEnum):
     """Types of formats for structured (guided) decoding.

     :cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model.
@@ -294,8 +298,8 @@ class JsonSchemaResponseFormat(BaseModel):
     :param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
     """

-    type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
-    json_schema: Dict[str, Any]
+    type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
+    json_schema: dict[str, Any]


 @json_schema_type
@@ -306,12 +310,12 @@ class GrammarResponseFormat(BaseModel):
     :param bnf: The BNF grammar specification the response should conform to
     """

-    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
-    bnf: Dict[str, Any]
+    type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
+    bnf: dict[str, Any]


 ResponseFormat = Annotated[
-    Union[JsonSchemaResponseFormat, GrammarResponseFormat],
+    JsonSchemaResponseFormat | GrammarResponseFormat,
     Field(discriminator="type"),
 ]
 register_schema(ResponseFormat, name="ResponseFormat")
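A hedged sketch of why the `Literal[...value]` to `Literal[...member]` change is wire-compatible once `ResponseFormatType` is a `StrEnum`: the field now holds the enum member, but JSON-mode serialization still emits the plain string. The import path is assumed.

```python
from llama_stack.apis.inference import JsonSchemaResponseFormat

fmt = JsonSchemaResponseFormat(
    json_schema={
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }
)
# "type" defaults to the StrEnum member and dumps as the string "json_schema".
print(fmt.model_dump(mode="json"))
```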
@@ -321,10 +325,10 @@ register_schema(ResponseFormat, name="ResponseFormat")
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedContent
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None


 @json_schema_type
@@ -338,7 +342,7 @@ class CompletionResponse(MetricResponseMixin):

     content: str
     stop_reason: StopReason
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None


 @json_schema_type
@@ -351,8 +355,8 @@ class CompletionResponseStreamChunk(MetricResponseMixin):
     """

     delta: str
-    stop_reason: Optional[StopReason] = None
-    logprobs: Optional[List[TokenLogProbs]] = None
+    stop_reason: StopReason | None = None
+    logprobs: list[TokenLogProbs] | None = None


 class SystemMessageBehavior(Enum):
@@ -383,9 +387,9 @@ class ToolConfig(BaseModel):
         '{{function_definitions}}' to indicate where the function definitions should be inserted.
     """

-    tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
-    system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
+    tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None)
+    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)

     def model_post_init(self, __context: Any) -> None:
         if isinstance(self.tool_choice, str):
@@ -399,15 +403,15 @@ class ToolConfig(BaseModel):
 @json_schema_type
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Message]
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
+    messages: list[Message]
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
+    tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
+    tool_config: ToolConfig | None = Field(default_factory=ToolConfig)

-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None


 @json_schema_type
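A quick standalone sketch of the `default_factory=lambda: []` change above: it behaves exactly like `default_factory=list`, giving each instance its own fresh list, so the swap is stylistic (it is friendlier to type checkers) rather than behavioral. Illustrated with a stand-in model rather than the real request class.

```python
from pydantic import BaseModel, Field


class ToyRequest(BaseModel):
    tools: list[str] | None = Field(default_factory=lambda: [])


a = ToyRequest()
b = ToyRequest()
a.tools.append("web_search")
print(a.tools, b.tools)  # ['web_search'] [] - separate lists per instance
```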
@@ -429,7 +433,7 @@ class ChatCompletionResponse(MetricResponseMixin):
     """

     completion_message: CompletionMessage
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None


 @json_schema_type
@@ -439,7 +443,7 @@ class EmbeddingsResponse(BaseModel):
     :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
     """

-    embeddings: List[List[float]]
+    embeddings: list[list[float]]


 @json_schema_type
@@ -451,7 +455,7 @@ class OpenAIChatCompletionContentPartTextParam(BaseModel):
 @json_schema_type
 class OpenAIImageURL(BaseModel):
     url: str
-    detail: Optional[str] = None
+    detail: str | None = None


 @json_schema_type
@@ -461,16 +465,13 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):


 OpenAIChatCompletionContentPartParam = Annotated[
-    Union[
-        OpenAIChatCompletionContentPartTextParam,
-        OpenAIChatCompletionContentPartImageParam,
-    ],
+    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")


-OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
+OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]


 @json_schema_type
@@ -484,7 +485,7 @@ class OpenAIUserMessageParam(BaseModel):

     role: Literal["user"] = "user"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None


 @json_schema_type
@@ -498,21 +499,21 @@ class OpenAISystemMessageParam(BaseModel):

     role: Literal["system"] = "system"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None


 @json_schema_type
 class OpenAIChatCompletionToolCallFunction(BaseModel):
-    name: Optional[str] = None
-    arguments: Optional[str] = None
+    name: str | None = None
+    arguments: str | None = None


 @json_schema_type
 class OpenAIChatCompletionToolCall(BaseModel):
-    index: Optional[int] = None
-    id: Optional[str] = None
+    index: int | None = None
+    id: str | None = None
     type: Literal["function"] = "function"
-    function: Optional[OpenAIChatCompletionToolCallFunction] = None
+    function: OpenAIChatCompletionToolCallFunction | None = None


 @json_schema_type
@@ -526,9 +527,9 @@ class OpenAIAssistantMessageParam(BaseModel):
     """

     role: Literal["assistant"] = "assistant"
-    content: Optional[OpenAIChatCompletionMessageContent] = None
-    name: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+    content: OpenAIChatCompletionMessageContent | None = None
+    name: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None


 @json_schema_type
@@ -556,17 +557,15 @@ class OpenAIDeveloperMessageParam(BaseModel):

     role: Literal["developer"] = "developer"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None


 OpenAIMessageParam = Annotated[
-    Union[
-        OpenAIUserMessageParam,
-        OpenAISystemMessageParam,
-        OpenAIAssistantMessageParam,
-        OpenAIToolMessageParam,
-        OpenAIDeveloperMessageParam,
-    ],
+    OpenAIUserMessageParam
+    | OpenAISystemMessageParam
+    | OpenAIAssistantMessageParam
+    | OpenAIToolMessageParam
+    | OpenAIDeveloperMessageParam,
     Field(discriminator="role"),
 ]
 register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@@ -580,14 +579,14 @@ class OpenAIResponseFormatText(BaseModel):
 @json_schema_type
 class OpenAIJSONSchema(TypedDict, total=False):
     name: str
-    description: Optional[str] = None
-    strict: Optional[bool] = None
+    description: str | None
+    strict: bool | None

     # Pydantic BaseModel cannot be used with a schema param, since it already
     # has one. And, we don't want to alias here because then have to handle
     # that alias when converting to OpenAI params. So, to support schema,
     # we use a TypedDict.
-    schema: Optional[Dict[str, Any]] = None
+    schema: dict[str, Any] | None


 @json_schema_type
@@ -602,11 +601,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):


 OpenAIResponseFormatParam = Annotated[
-    Union[
-        OpenAIResponseFormatText,
-        OpenAIResponseFormatJSONSchema,
-        OpenAIResponseFormatJSONObject,
-    ],
+    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -622,7 +617,7 @@ class OpenAITopLogProb(BaseModel):
     """

     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float


@@ -637,9 +632,9 @@ class OpenAITokenLogProb(BaseModel):
     """

     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float
-    top_logprobs: List[OpenAITopLogProb]
+    top_logprobs: list[OpenAITopLogProb]


 @json_schema_type
@@ -650,8 +645,8 @@ class OpenAIChoiceLogprobs(BaseModel):
     :param refusal: (Optional) The log probabilities for the tokens in the message
     """

-    content: Optional[List[OpenAITokenLogProb]] = None
-    refusal: Optional[List[OpenAITokenLogProb]] = None
+    content: list[OpenAITokenLogProb] | None = None
+    refusal: list[OpenAITokenLogProb] | None = None


 @json_schema_type
@@ -664,10 +659,10 @@ class OpenAIChoiceDelta(BaseModel):
     :param tool_calls: (Optional) The tool calls of the delta
     """

-    content: Optional[str] = None
-    refusal: Optional[str] = None
-    role: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+    content: str | None = None
+    refusal: str | None = None
+    role: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None


 @json_schema_type
@@ -683,7 +678,7 @@ class OpenAIChunkChoice(BaseModel):
     delta: OpenAIChoiceDelta
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None


 @json_schema_type
@@ -699,7 +694,7 @@ class OpenAIChoice(BaseModel):
     message: OpenAIMessageParam
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None


 @json_schema_type
@@ -714,7 +709,7 @@ class OpenAIChatCompletion(BaseModel):
     """

     id: str
-    choices: List[OpenAIChoice]
+    choices: list[OpenAIChoice]
     object: Literal["chat.completion"] = "chat.completion"
     created: int
     model: str
@@ -732,7 +727,7 @@ class OpenAIChatCompletionChunk(BaseModel):
     """

     id: str
-    choices: List[OpenAIChunkChoice]
+    choices: list[OpenAIChunkChoice]
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int
     model: str
@@ -748,10 +743,10 @@ class OpenAICompletionLogprobs(BaseModel):
     :top_logprobs: (Optional) The top log probabilities for the tokens
     """

-    text_offset: Optional[List[int]] = None
-    token_logprobs: Optional[List[float]] = None
-    tokens: Optional[List[str]] = None
-    top_logprobs: Optional[List[Dict[str, float]]] = None
+    text_offset: list[int] | None = None
+    token_logprobs: list[float] | None = None
+    tokens: list[str] | None = None
+    top_logprobs: list[dict[str, float]] | None = None


 @json_schema_type
@@ -767,7 +762,7 @@ class OpenAICompletionChoice(BaseModel):
     finish_reason: str
     text: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None


 @json_schema_type
@@ -782,12 +777,54 @@ class OpenAICompletion(BaseModel):
     """

     id: str
-    choices: List[OpenAICompletionChoice]
+    choices: list[OpenAICompletionChoice]
     created: int
     model: str
     object: Literal["text_completion"] = "text_completion"


+@json_schema_type
+class OpenAIEmbeddingData(BaseModel):
+    """A single embedding data object from an OpenAI-compatible embeddings response.
+
+    :param object: The object type, which will be "embedding"
+    :param embedding: The embedding vector as a list of floats (when encoding_format="float") or as a base64-encoded string (when encoding_format="base64")
+    :param index: The index of the embedding in the input list
+    """
+
+    object: Literal["embedding"] = "embedding"
+    embedding: list[float] | str
+    index: int
+
+
+@json_schema_type
+class OpenAIEmbeddingUsage(BaseModel):
+    """Usage information for an OpenAI-compatible embeddings response.
+
+    :param prompt_tokens: The number of tokens in the input
+    :param total_tokens: The total number of tokens used
+    """
+
+    prompt_tokens: int
+    total_tokens: int
+
+
+@json_schema_type
+class OpenAIEmbeddingsResponse(BaseModel):
+    """Response from an OpenAI-compatible embeddings request.
+
+    :param object: The object type, which will be "list"
+    :param data: List of embedding data objects
+    :param model: The model that was used to generate the embeddings
+    :param usage: Usage information
+    """
+
+    object: Literal["list"] = "list"
+    data: list[OpenAIEmbeddingData]
+    model: str
+    usage: OpenAIEmbeddingUsage
+
+
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...

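A hedged sketch of the response shape the new embeddings models describe, constructed by hand just to show the field layout; the import path and the model id are assumptions.

```python
from llama_stack.apis.inference import (
    OpenAIEmbeddingData,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
)

response = OpenAIEmbeddingsResponse(
    data=[OpenAIEmbeddingData(embedding=[0.12, -0.04, 0.33], index=0)],
    model="all-MiniLM-L6-v2",  # placeholder embedding model id
    usage=OpenAIEmbeddingUsage(prompt_tokens=5, total_tokens=5),
)
# "object" defaults to "list", matching the OpenAI embeddings wire format.
print(response.model_dump()["object"])
```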
@@ -818,23 +855,35 @@ class EmbeddingTaskType(Enum):

 @json_schema_type
 class BatchCompletionResponse(BaseModel):
-    batch: List[CompletionResponse]
+    batch: list[CompletionResponse]


 @json_schema_type
 class BatchChatCompletionResponse(BaseModel):
-    batch: List[ChatCompletionResponse]
+    batch: list[ChatCompletionResponse]


+class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
+    input_messages: list[OpenAIMessageParam]
+
+
+@json_schema_type
+class ListOpenAIChatCompletionResponse(BaseModel):
+    data: list[OpenAICompletionWithInputMessages]
+    has_more: bool
+    first_id: str
+    last_id: str
+    object: Literal["list"] = "list"
+
+
 @runtime_checkable
 @trace_protocol
-class Inference(Protocol):
-    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
-
-    This API provides the raw interface to the underlying models. Two kinds of models are supported:
-    - LLM models: these models generate "raw" and "chat" (conversational) completions.
-    - Embedding models: these models generate embeddings to be used for semantic search.
+class InferenceProvider(Protocol):
+    """
+    This protocol defines the interface that should be implemented by all inference providers.
     """

+    API_NAMESPACE: str = "Inference"
+
     model_store: ModelStore | None = None

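A hedged sketch of what the rename means for provider code: implementations now target the narrower `InferenceProvider` protocol. The skeleton below is a toy, with only one method stubbed; import paths for `EmbeddingsResponse` and `InferenceProvider` are assumed, while `InterleavedContentItem` uses the path shown in the imports hunk above.

```python
from llama_stack.apis.common.content_types import InterleavedContentItem
from llama_stack.apis.inference import EmbeddingsResponse, InferenceProvider


class ZeroEmbeddingsProvider:
    """Toy, non-functional provider; only the embeddings method is sketched."""

    async def embeddings(
        self,
        model_id: str,
        contents: list[str] | list[InterleavedContentItem],
        text_truncation=None,
        output_dimension: int | None = None,
        task_type=None,
    ) -> EmbeddingsResponse:
        # Map every input to a fixed 8-dimensional zero vector.
        return EmbeddingsResponse(embeddings=[[0.0] * 8 for _ in contents])


# The protocol also carries a class-level namespace marker added in this diff.
print(InferenceProvider.API_NAMESPACE)  # "Inference"
```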
@@ -843,21 +892,21 @@ class Inference(Protocol):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+    ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
         """Generate a completion for the given content using the specified model.

         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param content: The content to generate a completion for
-        :param sampling_params: (Optional) Parameters to control the sampling strategy
-        :param response_format: (Optional) Grammar specification for guided (structured) decoding
+        :param content: The content to generate a completion for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
         :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
         :returns: If stream=False, returns a CompletionResponse with the full completion.
-                  If stream=True, returns an SSE event stream of CompletionResponseStreamChunk
+                  If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
         """
         ...

@@ -865,33 +914,42 @@ class Inference(Protocol):
     async def batch_completion(
         self,
         model_id: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchCompletionResponse:
+        """Generate completions for a batch of content using the specified model.
+
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param content_batch: The content to generate completions for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: A BatchCompletionResponse with the full completions.
+        """
         raise NotImplementedError("Batch completion is not implemented")

     @webmethod(route="/inference/chat-completion", method="POST")
     async def chat_completion(
         self,
         model_id: str,
-        messages: List[Message],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        messages: list[Message],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
         """Generate a chat completion for the given messages using the specified model.

         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation
-        :param sampling_params: Parameters to control the sampling strategy
-        :param tools: (Optional) List of tool definitions available to the model
+        :param messages: List of messages in the conversation.
+        :param sampling_params: Parameters to control the sampling strategy.
+        :param tools: (Optional) List of tool definitions available to the model.
         :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
             .. deprecated::
                Use tool_config instead.
@@ -908,7 +966,7 @@ class Inference(Protocol):
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
         :param tool_config: (Optional) Configuration for tool use.
         :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
-                  If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
+                  If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
         """
         ...

@@ -916,23 +974,34 @@ class Inference(Protocol):
     async def batch_chat_completion(
         self,
         model_id: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_config: Optional[ToolConfig] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_config: ToolConfig | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
+        """Generate chat completions for a batch of messages using the specified model.
+
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param messages_batch: The messages to generate completions for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param tools: (Optional) List of tool definitions available to the model.
+        :param tool_config: (Optional) Configuration for tool use.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: A BatchChatCompletionResponse with the full completions.
+        """
         raise NotImplementedError("Batch chat completion is not implemented")

     @webmethod(route="/inference/embeddings", method="POST")
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: list[str] | list[InterleavedContentItem],
+        text_truncation: TextTruncation | None = TextTruncation.none,
+        output_dimension: int | None = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         """Generate embeddings for content pieces using the specified model.

@@ -941,7 +1010,7 @@ class Inference(Protocol):
         :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
         :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
         :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
-        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
+        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
         """
         ...

@@ -950,45 +1019,46 @@ class Inference(Protocol):
         self,
         # Standard OpenAI completion parameters
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
         # vLLM-specific parameters
-        guided_choice: Optional[List[str]] = None,
-        prompt_logprobs: Optional[int] = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.

         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param prompt: The prompt to generate a completion for
-        :param best_of: (Optional) The number of completions to generate
-        :param echo: (Optional) Whether to echo the prompt
-        :param frequency_penalty: (Optional) The penalty for repeated tokens
-        :param logit_bias: (Optional) The logit bias to use
-        :param logprobs: (Optional) The log probabilities to use
-        :param max_tokens: (Optional) The maximum number of tokens to generate
-        :param n: (Optional) The number of completions to generate
-        :param presence_penalty: (Optional) The penalty for repeated tokens
-        :param seed: (Optional) The seed to use
-        :param stop: (Optional) The stop tokens to use
-        :param stream: (Optional) Whether to stream the response
-        :param stream_options: (Optional) The stream options to use
-        :param temperature: (Optional) The temperature to use
-        :param top_p: (Optional) The top p to use
-        :param user: (Optional) The user to use
+        :param prompt: The prompt to generate a completion for.
+        :param best_of: (Optional) The number of completions to generate.
+        :param echo: (Optional) Whether to echo the prompt.
+        :param frequency_penalty: (Optional) The penalty for repeated tokens.
+        :param logit_bias: (Optional) The logit bias to use.
+        :param logprobs: (Optional) The log probabilities to use.
+        :param max_tokens: (Optional) The maximum number of tokens to generate.
+        :param n: (Optional) The number of completions to generate.
+        :param presence_penalty: (Optional) The penalty for repeated tokens.
+        :param seed: (Optional) The seed to use.
+        :param stop: (Optional) The stop tokens to use.
+        :param stream: (Optional) Whether to stream the response.
+        :param stream_options: (Optional) The stream options to use.
+        :param temperature: (Optional) The temperature to use.
+        :param top_p: (Optional) The top p to use.
+        :param user: (Optional) The user to use.
+        :returns: An OpenAICompletion.
         """
         ...

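A short sketch of the four prompt shapes the rewritten annotation `str | list[str] | list[int] | list[list[int]]` allows, mirroring the OpenAI completions API; `client_or_impl` stands in for anything exposing `openai_completion`, and the token ids are arbitrary.

```python
async def demo_prompts(client_or_impl, model: str) -> None:
    # 1. a single string
    await client_or_impl.openai_completion(model=model, prompt="Say hi")
    # 2. a batch of strings
    await client_or_impl.openai_completion(model=model, prompt=["Say hi", "Say bye"])
    # 3. a single pre-tokenized prompt (arbitrary example ids)
    await client_or_impl.openai_completion(model=model, prompt=[9906, 0])
    # 4. a batch of pre-tokenized prompts
    await client_or_impl.openai_completion(model=model, prompt=[[9906, 0], [9906, 11]])
```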
@ -996,53 +1066,110 @@ class Inference(Protocol):
|
||||||
async def openai_chat_completion(
|
async def openai_chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[OpenAIMessageParam],
|
messages: list[OpenAIMessageParam],
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: float | None = None,
|
||||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
function_call: str | dict[str, Any] | None = None,
|
||||||
functions: Optional[List[Dict[str, Any]]] = None,
|
functions: list[dict[str, Any]] | None = None,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
logit_bias: dict[str, float] | None = None,
|
||||||
logprobs: Optional[bool] = None,
|
logprobs: bool | None = None,
|
||||||
max_completion_tokens: Optional[int] = None,
|
max_completion_tokens: int | None = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: int | None = None,
|
||||||
n: Optional[int] = None,
|
n: int | None = None,
|
||||||
parallel_tool_calls: Optional[bool] = None,
|
parallel_tool_calls: bool | None = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: float | None = None,
|
||||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
response_format: OpenAIResponseFormatParam | None = None,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: str | list[str] | None = None,
|
||||||
stream: Optional[bool] = None,
|
stream: bool | None = None,
|
||||||
stream_options: Optional[Dict[str, Any]] = None,
|
stream_options: dict[str, Any] | None = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: float | None = None,
|
||||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
top_logprobs: Optional[int] = None,
|
top_logprobs: int | None = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: float | None = None,
|
||||||
user: Optional[str] = None,
|
user: str | None = None,
|
||||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||||
|
|
||||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||||
:param messages: List of messages in the conversation
|
:param messages: List of messages in the conversation.
|
||||||
:param frequency_penalty: (Optional) The penalty for repeated tokens
|
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||||
:param function_call: (Optional) The function call to use
|
:param function_call: (Optional) The function call to use.
|
||||||
:param functions: (Optional) List of functions to use
|
:param functions: (Optional) List of functions to use.
|
||||||
:param logit_bias: (Optional) The logit bias to use
|
:param logit_bias: (Optional) The logit bias to use.
|
||||||
:param logprobs: (Optional) The log probabilities to use
|
:param logprobs: (Optional) The log probabilities to use.
|
||||||
:param max_completion_tokens: (Optional) The maximum number of tokens to generate
|
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
|
||||||
:param max_tokens: (Optional) The maximum number of tokens to generate
|
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||||
:param n: (Optional) The number of completions to generate
|
:param n: (Optional) The number of completions to generate.
|
||||||
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls
|
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
|
||||||
:param presence_penalty: (Optional) The penalty for repeated tokens
|
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||||
:param response_format: (Optional) The response format to use
|
:param response_format: (Optional) The response format to use.
|
||||||
:param seed: (Optional) The seed to use
|
:param seed: (Optional) The seed to use.
|
||||||
:param stop: (Optional) The stop tokens to use
|
:param stop: (Optional) The stop tokens to use.
|
||||||
:param stream: (Optional) Whether to stream the response
|
:param stream: (Optional) Whether to stream the response.
|
||||||
:param stream_options: (Optional) The stream options to use
|
:param stream_options: (Optional) The stream options to use.
|
||||||
:param temperature: (Optional) The temperature to use
|
:param temperature: (Optional) The temperature to use.
|
||||||
:param tool_choice: (Optional) The tool choice to use
|
:param tool_choice: (Optional) The tool choice to use.
|
||||||
:param tools: (Optional) The tools to use
|
:param tools: (Optional) The tools to use.
|
||||||
:param top_logprobs: (Optional) The top log probabilities to use
|
:param top_logprobs: (Optional) The top log probabilities to use.
|
||||||
:param top_p: (Optional) The top p to use
|
:param top_p: (Optional) The top p to use.
|
||||||
:param user: (Optional) The user to use
|
:param user: (Optional) The user to use.
|
||||||
|
:returns: An OpenAIChatCompletion.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
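Because the signature above mirrors the OpenAI chat-completions contract, any OpenAI-compatible client can exercise it once a Llama Stack server exposes the route. A minimal sketch, assuming a locally running server and an already-registered model (base URL, port, and model id below are illustrative, not taken from this diff):

# Hypothetical client-side call; the /openai/v1 base path matches the routes declared in this file.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/openai/v1",  # assumed local Llama Stack server
    api_key="none",                              # placeholder; auth depends on the deployment
)

response = client.chat.completions.create(
    model="llama3.2:3b",  # must be registered with the stack and visible via /models
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    temperature=0.2,
    stream=False,
)
print(response.choices[0].message.content)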
    @webmethod(route="/openai/v1/embeddings", method="POST")
    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        """Generate OpenAI-compatible embeddings for the given input using the specified model.

        :param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
        :param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
        :param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
        :param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
        :param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
        :returns: An OpenAIEmbeddingsResponse containing the embeddings.
        """
        ...

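A minimal request sketch for the new embeddings route declared above; the base URL and model id are assumptions, while the field names come straight from the signature, and the response is assumed to follow the OpenAI-style embeddings body:

# Hypothetical HTTP call against a locally running Llama Stack server.
import httpx

payload = {
    "model": "all-MiniLM-L6-v2",  # assumed embedding model registered with the stack
    "input": ["first sentence", "second sentence"],
    "encoding_format": "float",
}
resp = httpx.post("http://localhost:8321/openai/v1/embeddings", json=payload, timeout=60)
resp.raise_for_status()
print(len(resp.json()["data"]))  # expect one embedding entry per input string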
class Inference(InferenceProvider):
    """Llama Stack Inference API for generating completions, chat completions, and embeddings.

    This API provides the raw interface to the underlying models. Two kinds of models are supported:
    - LLM models: these models generate "raw" and "chat" (conversational) completions.
    - Embedding models: these models generate embeddings to be used for semantic search.
    """

    @webmethod(route="/openai/v1/chat/completions", method="GET")
    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 20,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        """List all chat completions.

        :param after: The ID of the last chat completion to return.
        :param limit: The maximum number of chat completions to return.
        :param model: The model to filter by.
        :param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
        :returns: A ListOpenAIChatCompletionResponse.
        """
        raise NotImplementedError("List chat completions is not implemented")

    @webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        """Describe a chat completion by its ID.

        :param completion_id: ID of the chat completion.
        :returns: A OpenAICompletionWithInputMessages.
        """
        raise NotImplementedError("Get chat completion is not implemented")

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol, runtime_checkable

from pydantic import BaseModel

@@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class RouteInfo(BaseModel):
    route: str
    method: str
    provider_types: list[str]


@json_schema_type
@@ -30,16 +30,31 @@ class VersionInfo(BaseModel):


class ListRoutesResponse(BaseModel):
    data: list[RouteInfo]


@runtime_checkable
class Inspect(Protocol):
    @webmethod(route="/inspect/routes", method="GET")
    async def list_routes(self) -> ListRoutesResponse:
        """List all routes.

        :returns: A ListRoutesResponse.
        """
        ...

    @webmethod(route="/health", method="GET")
    async def health(self) -> HealthInfo:
        """Get the health of the service.

        :returns: A HealthInfo.
        """
        ...

    @webmethod(route="/version", method="GET")
    async def version(self) -> VersionInfo:
        """Get the version of the service.

        :returns: A VersionInfo.
        """
        ...

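A quick sketch of probing the Inspect routes above over plain HTTP; the base URL is an assumption, and some deployments mount these routes under a version prefix, so adjust as needed:

# Hypothetical probes against a locally running Llama Stack server.
import httpx

BASE = "http://localhost:8321"  # assumed base URL

print(httpx.get(f"{BASE}/health").json())          # HealthInfo
print(httpx.get(f"{BASE}/version").json())         # VersionInfo
print(httpx.get(f"{BASE}/inspect/routes").json())  # ListRoutesResponse with RouteInfo entries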
@@ -5,7 +5,7 @@
# the root directory of this source tree.

from enum import Enum
from typing import Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel, ConfigDict, Field

@@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod


class CommonModelFields(BaseModel):
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this model",
    )

@@ -29,14 +29,14 @@ class ModelType(str, Enum):


@json_schema_type
class Model(CommonModelFields, Resource):
    type: Literal[ResourceType.model] = ResourceType.model

    @property
    def model_id(self) -> str:
        return self.identifier

    @property
    def provider_model_id(self) -> str | None:
        return self.provider_resource_id

    model_config = ConfigDict(protected_namespaces=())

@@ -46,14 +46,14 @@ class Model(CommonModelFields, Resource):


class ModelInput(CommonModelFields):
    model_id: str
    provider_id: str | None = None
    provider_model_id: str | None = None
    model_type: ModelType | None = ModelType.llm
    model_config = ConfigDict(protected_namespaces=())


class ListModelsResponse(BaseModel):
    data: list[Model]


@json_schema_type
@@ -73,36 +73,67 @@ class OpenAIModel(BaseModel):


class OpenAIListModelsResponse(BaseModel):
    data: list[OpenAIModel]


@runtime_checkable
@trace_protocol
class Models(Protocol):
    @webmethod(route="/models", method="GET")
    async def list_models(self) -> ListModelsResponse:
        """List all models.

        :returns: A ListModelsResponse.
        """
        ...

    @webmethod(route="/openai/v1/models", method="GET")
    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """List models using the OpenAI API.

        :returns: A OpenAIListModelsResponse.
        """
        ...

    @webmethod(route="/models/{model_id:path}", method="GET")
    async def get_model(
        self,
        model_id: str,
    ) -> Model:
        """Get a model by its identifier.

        :param model_id: The identifier of the model to get.
        :returns: A Model.
        """
        ...

    @webmethod(route="/models", method="POST")
    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register a model.

        :param model_id: The identifier of the model to register.
        :param provider_model_id: The identifier of the model in the provider.
        :param provider_id: The identifier of the provider.
        :param metadata: Any additional metadata for this model.
        :param model_type: The type of model to register.
        :returns: A Model.
        """
        ...

    @webmethod(route="/models/{model_id:path}", method="DELETE")
    async def unregister_model(
        self,
        model_id: str,
    ) -> None:
        """Unregister a model.

        :param model_id: The identifier of the model to unregister.
        """
        ...

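A hedged sketch (not part of the diff) of the Model resource defined above; the import paths follow the usual llama_stack.apis layout and the identifiers are illustrative:

from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType

m = Model(
    identifier="meta-llama/Llama-3.2-3B-Instruct",
    provider_id="ollama",                 # assumed provider id
    provider_resource_id="llama3.2:3b",   # provider-side name; may now also be None
    metadata={"context_length": 131072},  # illustrative metadata
)
assert m.model_id == m.identifier               # convenience property over Resource.identifier
assert m.provider_model_id == "llama3.2:3b"     # passes through provider_resource_id
assert m.type == ResourceType.model == "model"  # StrEnum members compare equal to their values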
@@ -6,10 +6,9 @@

from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal, Protocol

from pydantic import BaseModel, Field

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus

@@ -36,9 +35,9 @@ class DataConfig(BaseModel):
    batch_size: int
    shuffle: bool
    data_format: DatasetFormat
    validation_dataset_id: str | None = None
    packed: bool | None = False
    train_on_input: bool | None = False


@json_schema_type
@@ -51,10 +50,10 @@ class OptimizerConfig(BaseModel):

@json_schema_type
class EfficiencyConfig(BaseModel):
    enable_activation_checkpointing: bool | None = False
    enable_activation_offloading: bool | None = False
    memory_efficient_fsdp_wrap: bool | None = False
    fsdp_cpu_offload: bool | None = False


@json_schema_type
@@ -62,23 +61,23 @@ class TrainingConfig(BaseModel):
    n_epochs: int
    max_steps_per_epoch: int = 1
    gradient_accumulation_steps: int = 1
    max_validation_steps: int | None = 1
    data_config: DataConfig | None = None
    optimizer_config: OptimizerConfig | None = None
    efficiency_config: EfficiencyConfig | None = None
    dtype: str | None = "bf16"


@json_schema_type
class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: list[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool
    rank: int
    alpha: int
    use_dora: bool | None = False
    quantize_base: bool | None = False


@json_schema_type
@@ -88,7 +87,7 @@ class QATFinetuningConfig(BaseModel):
    group_size: int


AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")


@@ -97,7 +96,7 @@ class PostTrainingJobLogStream(BaseModel):
    """Stream of logs from a finetuning job."""

    job_uuid: str
    log_lines: list[str]


@json_schema_type
@@ -131,8 +130,8 @@ class PostTrainingRLHFRequest(BaseModel):
    training_config: TrainingConfig

    # TODO: define these
    hyperparam_search_config: dict[str, Any]
    logger_config: dict[str, Any]


class PostTrainingJob(BaseModel):

@@ -146,17 +145,17 @@ class PostTrainingJobStatusResponse(BaseModel):
    job_uuid: str
    status: JobStatus

    scheduled_at: datetime | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None

    resources_allocated: dict[str, Any] | None = None

    checkpoints: list[Checkpoint] = Field(default_factory=list)


class ListPostTrainingJobsResponse(BaseModel):
    data: list[PostTrainingJob]


@json_schema_type
@@ -164,7 +163,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
    """Artifacts of a finetuning job."""

    job_uuid: str
    checkpoints: list[Checkpoint] = Field(default_factory=list)

    # TODO(ashwin): metrics, evals

@@ -175,15 +174,27 @@ class PostTraining(Protocol):
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
        model: str | None = Field(
            default=None,
            description="Model descriptor for training if not in provider config`",
        ),
        checkpoint_dir: str | None = None,
        algorithm_config: AlgorithmConfig | None = None,
    ) -> PostTrainingJob:
        """Run supervised fine-tuning of a model.

        :param job_uuid: The UUID of the job to create.
        :param training_config: The training configuration.
        :param hyperparam_search_config: The hyperparam search configuration.
        :param logger_config: The logger configuration.
        :param model: The model to fine-tune.
        :param checkpoint_dir: The directory to save checkpoint(s) to.
        :param algorithm_config: The algorithm configuration.
        :returns: A PostTrainingJob.
        """
        ...

    @webmethod(route="/post-training/preference-optimize", method="POST")
    async def preference_optimize(
@@ -192,18 +203,51 @@ class PostTraining(Protocol):
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
    ) -> PostTrainingJob:
        """Run preference optimization of a model.

        :param job_uuid: The UUID of the job to create.
        :param finetuned_model: The model to fine-tune.
        :param algorithm_config: The algorithm configuration.
        :param training_config: The training configuration.
        :param hyperparam_search_config: The hyperparam search configuration.
        :param logger_config: The logger configuration.
        :returns: A PostTrainingJob.
        """
        ...

    @webmethod(route="/post-training/jobs", method="GET")
    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        """Get all training jobs.

        :returns: A ListPostTrainingJobsResponse.
        """
        ...

    @webmethod(route="/post-training/job/status", method="GET")
    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
        """Get the status of a training job.

        :param job_uuid: The UUID of the job to get the status of.
        :returns: A PostTrainingJobStatusResponse.
        """
        ...

    @webmethod(route="/post-training/job/cancel", method="POST")
    async def cancel_training_job(self, job_uuid: str) -> None:
        """Cancel a training job.

        :param job_uuid: The UUID of the job to cancel.
        """
        ...

    @webmethod(route="/post-training/job/artifacts", method="GET")
    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
        """Get the artifacts of a training job.

        :param job_uuid: The UUID of the job to get the artifacts of.
        :returns: A PostTrainingJobArtifactsResponse.
        """
        ...

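A hedged sketch of building the request objects a caller would hand to PostTraining.supervised_fine_tune; the module path is the usual location of these schemas in llama-stack but should be treated as an assumption, and the values are illustrative:

from llama_stack.apis.post_training import LoraFinetuningConfig, TrainingConfig

algo = LoraFinetuningConfig(
    lora_attn_modules=["q_proj", "v_proj"],  # assumed attention module names, for illustration
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)
cfg = TrainingConfig(n_epochs=1, max_steps_per_epoch=100, gradient_accumulation_steps=1)

# These would then be passed as algorithm_config=algo and training_config=cfg to
# supervised_fine_tune, together with a job_uuid and a registered model identifier.
print(algo.type, cfg.dtype)  # "LoRA" and the "bf16" default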
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from pydantic import BaseModel

@@ -17,12 +17,12 @@ class ProviderInfo(BaseModel):
    api: str
    provider_id: str
    provider_type: str
    config: dict[str, Any]
    health: HealthResponse


class ListProvidersResponse(BaseModel):
    data: list[ProviderInfo]


@runtime_checkable
@@ -32,7 +32,18 @@ class Providers(Protocol):
    """

    @webmethod(route="/providers", method="GET")
    async def list_providers(self) -> ListProvidersResponse:
        """List all available providers.

        :returns: A ListProvidersResponse containing information about all providers.
        """
        ...

    @webmethod(route="/providers/{provider_id}", method="GET")
    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
        """Get detailed information about a specific provider.

        :param provider_id: The ID of the provider to inspect.
        :returns: A ProviderInfo object containing the provider's details.
        """
        ...

@@ -4,12 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import sys
from enum import Enum

from pydantic import BaseModel, Field

# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
    from enum import StrEnum
else:

    class StrEnum(str, Enum):
        """Backport of StrEnum for Python 3.10 and below."""

        pass


class ResourceType(StrEnum):
    model = "model"
    shield = "shield"
    vector_db = "vector_db"

@@ -25,9 +36,9 @@ class Resource(BaseModel):

    identifier: str = Field(description="Unique identifier for this resource in llama stack")

    provider_resource_id: str | None = Field(
        default=None,
        description="Unique identifier for this resource in the provider",
    )

    provider_id: str = Field(description="ID of the provider that owns this resource")

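Illustrative behavior of the backported StrEnum above (the import path is an assumption): members compare equal to, and serialize as, their plain string values, which is what lets the Literal[ResourceType.model]-style annotations elsewhere in this diff keep validating raw strings on Python 3.10:

from llama_stack.apis.resource import ResourceType

assert ResourceType.model == "model"               # value-equality, unlike a plain Enum member
assert ResourceType.model.value == "model"
assert ResourceType("shield") is ResourceType.shield  # lookup by value still works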
@@ -5,7 +5,7 @@
# the root directory of this source tree.

from enum import Enum
from typing import Any, Protocol, runtime_checkable

from pydantic import BaseModel, Field

@@ -27,16 +27,16 @@ class SafetyViolation(BaseModel):
    violation_level: ViolationLevel

    # what message should you convey to the user
    user_message: str | None = None

    # additional metadata (including specific violation codes) more for
    # debugging, telemetry
    metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class RunShieldResponse(BaseModel):
    violation: SafetyViolation | None = None


class ShieldStore(Protocol):

@@ -52,6 +52,14 @@ class Safety(Protocol):
    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any],
    ) -> RunShieldResponse:
        """Run a shield.

        :param shield_id: The identifier of the shield to run.
        :param messages: The messages to run the shield on.
        :param params: The parameters of the shield.
        :returns: A RunShieldResponse.
        """
        ...

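A hedged usage sketch for the signature above; `safety` stands in for any Safety implementation, the shield identifier is an assumption, and the message type is assumed to come from the inference API module:

from llama_stack.apis.inference import UserMessage  # assumed import location
from llama_stack.apis.safety import Safety


async def moderate(safety: Safety, text: str) -> bool:
    """Return True when the configured shield flags the input (illustrative helper)."""
    response = await safety.run_shield(
        shield_id="llama-guard",   # assumed shield identifier
        messages=[UserMessage(content=text)],
        params={},                 # params is now a required argument rather than defaulting to None
    )
    return response.violation is not None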
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Protocol, runtime_checkable

from pydantic import BaseModel

@@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.schema_utils import json_schema_type, webmethod

# mapping of metric to value
ScoringResultRow = dict[str, Any]


@json_schema_type
@@ -24,15 +24,15 @@ class ScoringResult(BaseModel):
    :param aggregated_results: Map of metric name to aggregated value
    """

    score_rows: list[ScoringResultRow]
    # aggregated metrics to value
    aggregated_results: dict[str, Any]


@json_schema_type
class ScoreBatchResponse(BaseModel):
    dataset_id: str | None = None
    results: dict[str, ScoringResult]


@json_schema_type
@@ -44,7 +44,7 @@ class ScoreResponse(BaseModel):
    """

    # each key in the dict is a scoring function name
    results: dict[str, ScoringResult]


class ScoringFunctionStore(Protocol):

@@ -59,20 +59,28 @@ class Scoring(Protocol):
    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: dict[str, ScoringFnParams | None],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        """Score a batch of rows.

        :param dataset_id: The ID of the dataset to score.
        :param scoring_functions: The scoring functions to use for the scoring.
        :param save_results_dataset: Whether to save the results to a dataset.
        :returns: A ScoreBatchResponse.
        """
        ...

    @webmethod(route="/scoring/score", method="POST")
    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None],
    ) -> ScoreResponse:
        """Score a list of rows.

        :param input_rows: The rows to score.
        :param scoring_functions: The scoring functions to use for the scoring.
        :returns: A ScoreResponse object containing rows and aggregated results.
        """
        ...

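A hedged sketch of calling Scoring.score with the new dict[str, ScoringFnParams | None] shape; the scoring-function id and row keys are assumptions, and passing None means "use the function's registered defaults":

from llama_stack.apis.scoring import Scoring


async def score_examples(scoring: Scoring) -> None:
    result = await scoring.score(
        input_rows=[
            {"input_query": "2 + 2", "generated_answer": "4", "expected_answer": "4"},
        ],
        scoring_functions={"basic::equality": None},  # assumed registered scoring function id
    )
    for fn_name, fn_result in result.results.items():
        print(fn_name, fn_result.aggregated_results)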
@@ -4,37 +4,44 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# TODO: use enum.StrEnum when we drop support for python 3.10
import sys
from enum import Enum
from typing import (
    Annotated,
    Any,
    Literal,
    Protocol,
    runtime_checkable,
)

from pydantic import BaseModel, Field

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

if sys.version_info >= (3, 11):
    from enum import StrEnum
else:

    class StrEnum(str, Enum):
        """Backport of StrEnum for Python 3.10 and below."""

        pass


# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(StrEnum):
    llm_as_judge = "llm_as_judge"
    regex_parser = "regex_parser"
    basic = "basic"


@json_schema_type
class AggregationFunctionType(StrEnum):
    average = "average"
    weighted_average = "weighted_average"
    median = "median"

@@ -44,62 +51,58 @@ class AggregationFunctionType(StrEnum):

@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
    type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
    judge_model: str
    prompt_template: str | None = None
    judge_score_regexes: list[str] = Field(
        description="Regexes to extract the answer from generated response",
        default_factory=lambda: [],
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=lambda: [],
    )


@json_schema_type
class RegexParserScoringFnParams(BaseModel):
    type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
    parsing_regexes: list[str] = Field(
        description="Regex to extract the answer from generated response",
        default_factory=lambda: [],
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=lambda: [],
    )


@json_schema_type
class BasicScoringFnParams(BaseModel):
    type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=list,
    )


ScoringFnParams = Annotated[
    LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
    Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")


class CommonScoringFnFields(BaseModel):
    description: str | None = None
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this definition",
    )
    return_type: ParamType = Field(
        description="The return type of the deterministic function",
    )
    params: ScoringFnParams | None = Field(
        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
        default=None,
    )

@@ -107,34 +110,45 @@ class CommonScoringFnFields(BaseModel):

@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
    type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function

    @property
    def scoring_fn_id(self) -> str:
        return self.identifier

    @property
    def provider_scoring_fn_id(self) -> str | None:
        return self.provider_resource_id


class ScoringFnInput(CommonScoringFnFields, BaseModel):
    scoring_fn_id: str
    provider_id: str | None = None
    provider_scoring_fn_id: str | None = None


class ListScoringFunctionsResponse(BaseModel):
    data: list[ScoringFn]


@runtime_checkable
class ScoringFunctions(Protocol):
    @webmethod(route="/scoring-functions", method="GET")
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        """List all scoring functions.

        :returns: A ListScoringFunctionsResponse.
        """
        ...

    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
        """Get a scoring function by its ID.

        :param scoring_fn_id: The ID of the scoring function to get.
        :returns: A ScoringFn.
        """
        ...

    @webmethod(route="/scoring-functions", method="POST")
    async def register_scoring_function(
@@ -142,7 +156,17 @@ class ScoringFunctions(Protocol):
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        """Register a scoring function.

        :param scoring_fn_id: The ID of the scoring function to register.
        :param description: The description of the scoring function.
        :param return_type: The return type of the scoring function.
        :param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
        :param provider_id: The ID of the provider to use for the scoring function.
        :param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
        """
        ...

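An illustrative check (not from the diff; the import location is an assumption) of the params classes above: the "type" tag is now a real StrEnum member, and the regex/aggregation fields default to empty lists instead of being Optional:

from llama_stack.apis.scoring_functions import (
    LLMAsJudgeScoringFnParams,
    ScoringFnParamsType,
)

params = LLMAsJudgeScoringFnParams(
    judge_model="meta-llama/Llama-3.3-70B-Instruct",  # assumed judge model id
    judge_score_regexes=[r"Score:\s*(\d+)"],
)
assert params.type == ScoringFnParamsType.llm_as_judge == "llm_as_judge"
assert params.aggregation_functions == []  # default_factory yields a list, not None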
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel

@@ -14,48 +14,68 @@ from llama_stack.schema_utils import json_schema_type, webmethod


class CommonShieldFields(BaseModel):
    params: dict[str, Any] | None = None


@json_schema_type
class Shield(CommonShieldFields, Resource):
    """A safety shield resource that can be used to check content"""

    type: Literal[ResourceType.shield] = ResourceType.shield

    @property
    def shield_id(self) -> str:
        return self.identifier

    @property
    def provider_shield_id(self) -> str | None:
        return self.provider_resource_id


class ShieldInput(CommonShieldFields):
    shield_id: str
    provider_id: str | None = None
    provider_shield_id: str | None = None


class ListShieldsResponse(BaseModel):
    data: list[Shield]


@runtime_checkable
@trace_protocol
class Shields(Protocol):
    @webmethod(route="/shields", method="GET")
    async def list_shields(self) -> ListShieldsResponse:
        """List all shields.

        :returns: A ListShieldsResponse.
        """
        ...

    @webmethod(route="/shields/{identifier:path}", method="GET")
    async def get_shield(self, identifier: str) -> Shield:
        """Get a shield by its identifier.

        :param identifier: The identifier of the shield to get.
        :returns: A Shield.
        """
        ...

    @webmethod(route="/shields", method="POST")
    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        """Register a shield.

        :param shield_id: The identifier of the shield to register.
        :param provider_shield_id: The identifier of the shield in the provider.
        :param provider_id: The identifier of the provider.
        :param params: The parameters of the shield.
        :returns: A Shield.
        """
        ...

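A hedged sketch of registering and then fetching a shield through the Shields protocol above; the shield and provider identifiers, and the params payload, are assumptions:

from llama_stack.apis.shields import Shield, Shields


async def setup_shield(shields: Shields) -> Shield:
    shield = await shields.register_shield(
        shield_id="content-safety",           # assumed shield identifier
        provider_id="llama-guard",            # assumed provider id
        params={"excluded_categories": []},   # assumed provider-specific parameters
    )
    return await shields.get_shield(shield.identifier)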
@@ -5,7 +5,7 @@
# the root directory of this source tree.

from enum import Enum
from typing import Any, Protocol

from pydantic import BaseModel

@@ -28,24 +28,24 @@ class FilteringFunction(Enum):
class SyntheticDataGenerationRequest(BaseModel):
    """Request to generate synthetic data. A small batch of prompts and a filtering function"""

    dialogs: list[Message]
    filtering_function: FilteringFunction = FilteringFunction.none
    model: str | None = None


@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""

    synthetic_data: list[dict[str, Any]]
    statistics: dict[str, Any] | None = None


class SyntheticDataGeneration(Protocol):
    @webmethod(route="/synthetic-data-generation/generate")
    def synthetic_data_generate(
        self,
        dialogs: list[Message],
        filtering_function: FilteringFunction = FilteringFunction.none,
        model: str | None = None,
    ) -> SyntheticDataGenerationResponse: ...

Some files were not shown because too many files have changed in this diff.