mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge branch 'main' into chroma
This commit is contained in:
commit
11c71c958e
308 changed files with 26415 additions and 11807 deletions
60
.github/actions/run-and-record-tests/action.yml
vendored
60
.github/actions/run-and-record-tests/action.yml
vendored
|
@ -2,26 +2,28 @@ name: 'Run and Record Tests'
|
|||
description: 'Run integration tests and handle recording/artifact upload'
|
||||
|
||||
inputs:
|
||||
test-subdirs:
|
||||
description: 'Comma-separated list of test subdirectories to run'
|
||||
required: true
|
||||
test-pattern:
|
||||
description: 'Regex pattern to pass to pytest -k'
|
||||
required: false
|
||||
default: ''
|
||||
stack-config:
|
||||
description: 'Stack configuration to use'
|
||||
required: true
|
||||
provider:
|
||||
description: 'Provider to use for tests'
|
||||
required: true
|
||||
setup:
|
||||
description: 'Setup to use for tests (e.g., ollama, gpt, vllm)'
|
||||
required: false
|
||||
default: ''
|
||||
inference-mode:
|
||||
description: 'Inference mode (record or replay)'
|
||||
required: true
|
||||
run-vision-tests:
|
||||
description: 'Whether to run vision tests'
|
||||
suite:
|
||||
description: 'Test suite to use: base, responses, vision, etc.'
|
||||
required: false
|
||||
default: 'false'
|
||||
default: ''
|
||||
subdirs:
|
||||
description: 'Comma-separated list of test subdirectories to run; overrides suite'
|
||||
required: false
|
||||
default: ''
|
||||
pattern:
|
||||
description: 'Regex pattern to pass to pytest -k'
|
||||
required: false
|
||||
default: ''
|
||||
|
||||
runs:
|
||||
using: 'composite'
|
||||
|
@ -36,14 +38,23 @@ runs:
|
|||
- name: Run Integration Tests
|
||||
shell: bash
|
||||
run: |
|
||||
uv run --no-sync ./scripts/integration-tests.sh \
|
||||
--stack-config '${{ inputs.stack-config }}' \
|
||||
--provider '${{ inputs.provider }}' \
|
||||
--test-subdirs '${{ inputs.test-subdirs }}' \
|
||||
--test-pattern '${{ inputs.test-pattern }}' \
|
||||
--inference-mode '${{ inputs.inference-mode }}' \
|
||||
${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \
|
||||
| tee pytest-${{ inputs.inference-mode }}.log
|
||||
SCRIPT_ARGS="--stack-config ${{ inputs.stack-config }} --inference-mode ${{ inputs.inference-mode }}"
|
||||
|
||||
# Add optional arguments only if they are provided
|
||||
if [ -n '${{ inputs.setup }}' ]; then
|
||||
SCRIPT_ARGS="$SCRIPT_ARGS --setup ${{ inputs.setup }}"
|
||||
fi
|
||||
if [ -n '${{ inputs.suite }}' ]; then
|
||||
SCRIPT_ARGS="$SCRIPT_ARGS --suite ${{ inputs.suite }}"
|
||||
fi
|
||||
if [ -n '${{ inputs.subdirs }}' ]; then
|
||||
SCRIPT_ARGS="$SCRIPT_ARGS --subdirs ${{ inputs.subdirs }}"
|
||||
fi
|
||||
if [ -n '${{ inputs.pattern }}' ]; then
|
||||
SCRIPT_ARGS="$SCRIPT_ARGS --pattern ${{ inputs.pattern }}"
|
||||
fi
|
||||
|
||||
uv run --no-sync ./scripts/integration-tests.sh $SCRIPT_ARGS | tee pytest-${{ inputs.inference-mode }}.log
|
||||
|
||||
|
||||
- name: Commit and push recordings
|
||||
|
@ -57,12 +68,7 @@ runs:
|
|||
echo "New recordings detected, committing and pushing"
|
||||
git add tests/integration/recordings/
|
||||
|
||||
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then
|
||||
git commit -m "Recordings update from CI (vision)"
|
||||
else
|
||||
git commit -m "Recordings update from CI"
|
||||
fi
|
||||
|
||||
git commit -m "Recordings update from CI (suite: ${{ inputs.suite }})"
|
||||
git fetch origin ${{ github.ref_name }}
|
||||
git rebase origin/${{ github.ref_name }}
|
||||
echo "Rebased successfully"
|
||||
|
|
8
.github/actions/setup-ollama/action.yml
vendored
8
.github/actions/setup-ollama/action.yml
vendored
|
@ -1,17 +1,17 @@
|
|||
name: Setup Ollama
|
||||
description: Start Ollama
|
||||
inputs:
|
||||
run-vision-tests:
|
||||
description: 'Run vision tests: "true" or "false"'
|
||||
suite:
|
||||
description: 'Test suite to use: base, responses, vision, etc.'
|
||||
required: false
|
||||
default: 'false'
|
||||
default: ''
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Start Ollama
|
||||
shell: bash
|
||||
run: |
|
||||
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then
|
||||
if [ "${{ inputs.suite }}" == "vision" ]; then
|
||||
image="ollama-with-vision-model"
|
||||
else
|
||||
image="ollama-with-models"
|
||||
|
|
|
@ -8,14 +8,14 @@ inputs:
|
|||
client-version:
|
||||
description: 'Client version (latest or published)'
|
||||
required: true
|
||||
provider:
|
||||
description: 'Provider to setup (ollama or vllm)'
|
||||
required: true
|
||||
default: 'ollama'
|
||||
run-vision-tests:
|
||||
description: 'Whether to setup provider for vision tests'
|
||||
setup:
|
||||
description: 'Setup to configure (ollama, vllm, gpt, etc.)'
|
||||
required: false
|
||||
default: 'false'
|
||||
default: 'ollama'
|
||||
suite:
|
||||
description: 'Test suite to use: base, responses, vision, etc.'
|
||||
required: false
|
||||
default: ''
|
||||
inference-mode:
|
||||
description: 'Inference mode (record or replay)'
|
||||
required: true
|
||||
|
@ -30,13 +30,13 @@ runs:
|
|||
client-version: ${{ inputs.client-version }}
|
||||
|
||||
- name: Setup ollama
|
||||
if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }}
|
||||
if: ${{ (inputs.setup == 'ollama' || inputs.setup == 'ollama-vision') && inputs.inference-mode == 'record' }}
|
||||
uses: ./.github/actions/setup-ollama
|
||||
with:
|
||||
run-vision-tests: ${{ inputs.run-vision-tests }}
|
||||
suite: ${{ inputs.suite }}
|
||||
|
||||
- name: Setup vllm
|
||||
if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }}
|
||||
if: ${{ inputs.setup == 'vllm' && inputs.inference-mode == 'record' }}
|
||||
uses: ./.github/actions/setup-vllm
|
||||
|
||||
- name: Build Llama Stack
|
||||
|
|
3
.github/workflows/README.md
vendored
3
.github/workflows/README.md
vendored
|
@ -5,10 +5,11 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
|
|||
| Name | File | Purpose |
|
||||
| ---- | ---- | ------- |
|
||||
| Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md |
|
||||
| API Conformance Tests | [conformance.yml](conformance.yml) | Run the API Conformance test suite on the changes. |
|
||||
| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
|
||||
| Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication |
|
||||
| SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore |
|
||||
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode |
|
||||
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
|
||||
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
|
||||
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
|
||||
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
|
||||
|
|
57
.github/workflows/conformance.yml
vendored
Normal file
57
.github/workflows/conformance.yml
vendored
Normal file
|
@ -0,0 +1,57 @@
|
|||
# API Conformance Tests
|
||||
# This workflow ensures that API changes maintain backward compatibility and don't break existing integrations
|
||||
# It runs schema validation and OpenAPI diff checks to catch breaking changes early
|
||||
|
||||
name: API Conformance Tests
|
||||
|
||||
run-name: Run the API Conformance test suite on the changes.
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
types: [opened, synchronize, reopened]
|
||||
paths:
|
||||
- 'llama_stack/**'
|
||||
- '!llama_stack/ui/**'
|
||||
- 'tests/**'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/conformance.yml' # This workflow itself
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||
# Cancel in-progress runs when new commits are pushed to avoid wasting CI resources
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
# Job to check if API schema changes maintain backward compatibility
|
||||
check-schema-compatibility:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# Using specific version 4.1.7 because 5.0.0 fails when trying to run this locally using `act`
|
||||
# This ensures consistent behavior between local testing and CI
|
||||
- name: Checkout PR Code
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
|
||||
# Checkout the base branch to compare against (usually main)
|
||||
# This allows us to diff the current changes against the previous state
|
||||
- name: Checkout Base Branch
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.base.ref }}
|
||||
path: 'base'
|
||||
|
||||
# Install oasdiff: https://github.com/oasdiff/oasdiff, a tool for detecting breaking changes in OpenAPI specs.
|
||||
- name: Install oasdiff
|
||||
run: |
|
||||
curl -fsSL https://raw.githubusercontent.com/oasdiff/oasdiff/main/install.sh | sh
|
||||
|
||||
# Run oasdiff to detect breaking changes in the API specification
|
||||
# This step will fail if incompatible changes are detected, preventing breaking changes from being merged
|
||||
- name: Run OpenAPI Breaking Change Diff
|
||||
run: |
|
||||
oasdiff breaking --fail-on ERR base/docs/_static/llama-stack-spec.yaml docs/_static/llama-stack-spec.yaml --match-path '^/v1/openai/v1' \
|
||||
--match-path '^/v1/vector-io' \
|
||||
--match-path '^/v1/vector-dbs'
|
32
.github/workflows/integration-tests.yml
vendored
32
.github/workflows/integration-tests.yml
vendored
|
@ -1,6 +1,6 @@
|
|||
name: Integration Tests (Replay)
|
||||
|
||||
run-name: Run the integration test suite from tests/integration in replay mode
|
||||
run-name: Run the integration test suites from tests/integration in replay mode
|
||||
|
||||
on:
|
||||
push:
|
||||
|
@ -28,18 +28,10 @@ on:
|
|||
description: 'Test against both the latest and published versions'
|
||||
type: boolean
|
||||
default: false
|
||||
test-provider:
|
||||
description: 'Test against a specific provider'
|
||||
test-setup:
|
||||
description: 'Test against a specific setup'
|
||||
type: string
|
||||
default: 'ollama'
|
||||
test-subdirs:
|
||||
description: 'Comma-separated list of test subdirectories to run'
|
||||
type: string
|
||||
default: ''
|
||||
test-pattern:
|
||||
description: 'Regex pattern to pass to pytest -k'
|
||||
type: string
|
||||
default: ''
|
||||
|
||||
concurrency:
|
||||
# Skip concurrency for pushes to main - each commit should be tested independently
|
||||
|
@ -50,18 +42,18 @@ jobs:
|
|||
|
||||
run-replay-mode-tests:
|
||||
runs-on: ubuntu-latest
|
||||
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }}
|
||||
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.setup, matrix.python-version, matrix.client-version, matrix.suite) }}
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
client-type: [library, server]
|
||||
# Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama)
|
||||
provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }}
|
||||
# Use vllm on weekly schedule, otherwise use test-setup input (defaults to ollama)
|
||||
setup: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-setup || 'ollama')) }}
|
||||
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
|
||||
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
|
||||
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
|
||||
run-vision-tests: [true, false]
|
||||
suite: [base, vision]
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
@ -72,16 +64,14 @@ jobs:
|
|||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
client-version: ${{ matrix.client-version }}
|
||||
provider: ${{ matrix.provider }}
|
||||
run-vision-tests: ${{ matrix.run-vision-tests }}
|
||||
setup: ${{ matrix.setup }}
|
||||
suite: ${{ matrix.suite }}
|
||||
inference-mode: 'replay'
|
||||
|
||||
- name: Run tests
|
||||
uses: ./.github/actions/run-and-record-tests
|
||||
with:
|
||||
test-subdirs: ${{ inputs.test-subdirs }}
|
||||
test-pattern: ${{ inputs.test-pattern }}
|
||||
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
|
||||
provider: ${{ matrix.provider }}
|
||||
setup: ${{ matrix.setup }}
|
||||
inference-mode: 'replay'
|
||||
run-vision-tests: ${{ matrix.run-vision-tests }}
|
||||
suite: ${{ matrix.suite }}
|
||||
|
|
5
.github/workflows/pre-commit.yml
vendored
5
.github/workflows/pre-commit.yml
vendored
|
@ -28,7 +28,7 @@ jobs:
|
|||
fetch-depth: ${{ github.actor == 'dependabot[bot]' && 0 || 1 }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
|
||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: pip
|
||||
|
@ -37,7 +37,7 @@ jobs:
|
|||
.pre-commit-config.yaml
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
|
||||
uses: actions/setup-node@a0853c24544627f65ddf259abe73b1d18a591444 # v5.0.0
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'npm'
|
||||
|
@ -48,7 +48,6 @@ jobs:
|
|||
working-directory: llama_stack/ui
|
||||
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
continue-on-error: true
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
RUFF_OUTPUT_FORMAT: github
|
||||
|
|
2
.github/workflows/python-build-test.yml
vendored
2
.github/workflows/python-build-test.yml
vendored
|
@ -24,7 +24,7 @@ jobs:
|
|||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0
|
||||
uses: astral-sh/setup-uv@557e51de59eb14aaaba2ed9621916900a91d50c6 # v6.6.1
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
activate-environment: true
|
||||
|
|
42
.github/workflows/record-integration-tests.yml
vendored
42
.github/workflows/record-integration-tests.yml
vendored
|
@ -10,19 +10,19 @@ run-name: Run the integration test suite from tests/integration
|
|||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
test-subdirs:
|
||||
description: 'Comma-separated list of test subdirectories to run'
|
||||
type: string
|
||||
default: ''
|
||||
test-provider:
|
||||
description: 'Test against a specific provider'
|
||||
test-setup:
|
||||
description: 'Test against a specific setup'
|
||||
type: string
|
||||
default: 'ollama'
|
||||
run-vision-tests:
|
||||
description: 'Whether to run vision tests'
|
||||
type: boolean
|
||||
default: false
|
||||
test-pattern:
|
||||
suite:
|
||||
description: 'Test suite to use: base, responses, vision, etc.'
|
||||
type: string
|
||||
default: ''
|
||||
subdirs:
|
||||
description: 'Comma-separated list of test subdirectories to run; overrides suite'
|
||||
type: string
|
||||
default: ''
|
||||
pattern:
|
||||
description: 'Regex pattern to pass to pytest -k'
|
||||
type: string
|
||||
default: ''
|
||||
|
@ -38,11 +38,11 @@ jobs:
|
|||
- name: Echo workflow inputs
|
||||
run: |
|
||||
echo "::group::Workflow Inputs"
|
||||
echo "test-subdirs: ${{ inputs.test-subdirs }}"
|
||||
echo "test-provider: ${{ inputs.test-provider }}"
|
||||
echo "run-vision-tests: ${{ inputs.run-vision-tests }}"
|
||||
echo "test-pattern: ${{ inputs.test-pattern }}"
|
||||
echo "branch: ${{ github.ref_name }}"
|
||||
echo "test-setup: ${{ inputs.test-setup }}"
|
||||
echo "suite: ${{ inputs.suite }}"
|
||||
echo "subdirs: ${{ inputs.subdirs }}"
|
||||
echo "pattern: ${{ inputs.pattern }}"
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Checkout repository
|
||||
|
@ -55,16 +55,16 @@ jobs:
|
|||
with:
|
||||
python-version: "3.12" # Use single Python version for recording
|
||||
client-version: "latest"
|
||||
provider: ${{ inputs.test-provider || 'ollama' }}
|
||||
run-vision-tests: ${{ inputs.run-vision-tests }}
|
||||
setup: ${{ inputs.test-setup || 'ollama' }}
|
||||
suite: ${{ inputs.suite }}
|
||||
inference-mode: 'record'
|
||||
|
||||
- name: Run and record tests
|
||||
uses: ./.github/actions/run-and-record-tests
|
||||
with:
|
||||
test-pattern: ${{ inputs.test-pattern }}
|
||||
test-subdirs: ${{ inputs.test-subdirs }}
|
||||
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
|
||||
provider: ${{ inputs.test-provider || 'ollama' }}
|
||||
setup: ${{ inputs.test-setup || 'ollama' }}
|
||||
inference-mode: 'record'
|
||||
run-vision-tests: ${{ inputs.run-vision-tests }}
|
||||
suite: ${{ inputs.suite }}
|
||||
subdirs: ${{ inputs.subdirs }}
|
||||
pattern: ${{ inputs.pattern }}
|
||||
|
|
2
.github/workflows/stale_bot.yml
vendored
2
.github/workflows/stale_bot.yml
vendored
|
@ -24,7 +24,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Stale Action
|
||||
uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0
|
||||
uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
|
||||
with:
|
||||
stale-issue-label: 'stale'
|
||||
stale-issue-message: >
|
||||
|
|
2
.github/workflows/ui-unit-tests.yml
vendored
2
.github/workflows/ui-unit-tests.yml
vendored
|
@ -29,7 +29,7 @@ jobs:
|
|||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
|
||||
uses: actions/setup-node@a0853c24544627f65ddf259abe73b1d18a591444 # v5.0.0
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
cache: 'npm'
|
||||
|
|
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -26,5 +26,7 @@ venv/
|
|||
pytest-report.xml
|
||||
.coverage
|
||||
.python-version
|
||||
AGENTS.md
|
||||
server.log
|
||||
CLAUDE.md
|
||||
.claude/
|
||||
|
|
|
@ -86,7 +86,7 @@ repos:
|
|||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
|
||||
files: ^llama_stack/distributions/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
|
||||
- id: provider-codegen
|
||||
name: Provider Codegen
|
||||
additional_dependencies:
|
||||
|
|
98
CHANGELOG.md
98
CHANGELOG.md
|
@ -1,5 +1,103 @@
|
|||
# Changelog
|
||||
|
||||
# v0.2.20
|
||||
Published on: 2025-08-29T22:25:32Z
|
||||
|
||||
Here are some key changes that are coming as part of this release.
|
||||
|
||||
### Build and Environment
|
||||
|
||||
- Environment improvements: fixed env var replacement to preserve types.
|
||||
- Docker stability: fixed container startup failures for Fireworks AI provider.
|
||||
- Removed absolute paths in build for better portability.
|
||||
|
||||
### Features
|
||||
|
||||
- UI Enhancements: Implemented file upload and VectorDB creation/configuration directly in UI.
|
||||
- Vector Store Improvements: Added keyword, vector, and hybrid search inside vector store.
|
||||
- Added S3 authorization support for file providers.
|
||||
- SQL Store: Added inequality support to where clause.
|
||||
|
||||
### Documentation
|
||||
|
||||
- Fixed post-training docs.
|
||||
- Added Contributor Guidelines for creating Internal vs. External providers.
|
||||
|
||||
### Fixes
|
||||
|
||||
- Removed unsupported bfcl scoring function.
|
||||
- Multiple reliability and configuration fixes for providers and environment handling.
|
||||
|
||||
### Engineering / Chores
|
||||
|
||||
- Cleaner internal development setup with consistent paths.
|
||||
- Incremental improvements to provider integration and vector store behavior.
|
||||
|
||||
|
||||
### New Contributors
|
||||
- @omertuc made their first contribution in #3270
|
||||
- @r3v5 made their first contribution in vector store hybrid search
|
||||
|
||||
---
|
||||
|
||||
# v0.2.19
|
||||
Published on: 2025-08-26T22:06:55Z
|
||||
|
||||
## Highlights
|
||||
* feat: Add CORS configuration support for server by @skamenan7 in https://github.com/llamastack/llama-stack/pull/3201
|
||||
* feat(api): introduce /rerank by @ehhuang in https://github.com/llamastack/llama-stack/pull/2940
|
||||
* feat: Add S3 Files Provider by @mattf in https://github.com/llamastack/llama-stack/pull/3202
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.18
|
||||
Published on: 2025-08-20T01:09:27Z
|
||||
|
||||
## Highlights
|
||||
* Add moderations create API
|
||||
* Hybrid search in Milvus
|
||||
* Numerous Responses API improvements
|
||||
* Documentation updates
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.17
|
||||
Published on: 2025-08-05T01:51:14Z
|
||||
|
||||
## Highlights
|
||||
|
||||
* feat(tests): introduce inference record/replay to increase test reliability by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2941
|
||||
* fix(library_client): improve initialization error handling and prevent AttributeError by @mattf in https://github.com/meta-llama/llama-stack/pull/2944
|
||||
* fix: use OLLAMA_URL to activate Ollama provider in starter by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2963
|
||||
* feat(UI): adding MVP playground UI by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2828
|
||||
* Standardization of errors (@nathan-weinberg)
|
||||
* feat: Enable DPO training with HuggingFace inline provider by @Nehanth in https://github.com/meta-llama/llama-stack/pull/2825
|
||||
* chore: rename templates to distributions by @ashwinb in https://github.com/meta-llama/llama-stack/pull/3035
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.16
|
||||
Published on: 2025-07-28T23:35:23Z
|
||||
|
||||
## Highlights
|
||||
|
||||
* Automatic model registration for self-hosted providers (ollama and vllm currently). No need for `INFERENCE_MODEL` environment variables which need to be updated, etc.
|
||||
* Much simplified starter distribution. Most `ENABLE_` env variables are now gone. When you set `VLLM_URL`, the `vllm` provider is auto-enabled. Similar for `MILVUS_URL`, `PGVECTOR_DB`, etc. Check the [run.yaml](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/starter/run.yaml) for more details.
|
||||
* All tests migrated to pytest now (thanks @Elbehery)
|
||||
* DPO implementation in the post-training provider (thanks @Nehanth)
|
||||
* (Huge!) Support for external APIs and providers thereof (thanks @leseb, @cdoern and others). This is a really big deal -- you can now add more APIs completely out of tree and experiment with them before (optionally) wanting to contribute back.
|
||||
* `inline::vllm` provider is gone thank you very much
|
||||
* several improvements to OpenAI inference implementations and LiteLLM backend (thanks @mattf)
|
||||
* Chroma now supports Vector Store API (thanks @franciscojavierarceo).
|
||||
* Authorization improvements: Vector Store/File APIs now supports access control (thanks @franciscojavierarceo); Telemetry read APIs are gated according to logged-in user's roles.
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.15
|
||||
Published on: 2025-07-16T03:30:01Z
|
||||
|
||||
|
|
|
@ -34,13 +34,12 @@ This data enables data-driven architectural decisions and performance optimizati
|
|||
|
||||
**1. Deploy base k8s infrastructure:**
|
||||
```bash
|
||||
cd ../k8s
|
||||
cd ../../docs/source/distributions/k8s
|
||||
./apply.sh
|
||||
```
|
||||
|
||||
**2. Deploy benchmark components:**
|
||||
```bash
|
||||
cd ../k8s-benchmark
|
||||
./apply.sh
|
||||
```
|
||||
|
||||
|
@ -56,7 +55,6 @@ kubectl get pods
|
|||
|
||||
**Benchmark Llama Stack (default):**
|
||||
```bash
|
||||
cd docs/source/distributions/k8s-benchmark/
|
||||
./run-benchmark.sh
|
||||
```
|
||||
|
|
@ -14,7 +14,7 @@ import os
|
|||
import random
|
||||
import statistics
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
|
@ -57,17 +57,9 @@ class BenchmarkStats:
|
|||
success_rate = (self.success_count / self.total_requests) * 100
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"BENCHMARK RESULTS")
|
||||
print(f"{'='*60}")
|
||||
print(f"Total time: {total_time:.2f}s")
|
||||
print(f"Concurrent users: {self.concurrent_users}")
|
||||
print(f"Total requests: {self.total_requests}")
|
||||
print(f"Successful requests: {self.success_count}")
|
||||
print(f"Failed requests: {len(self.errors)}")
|
||||
print(f"Success rate: {success_rate:.1f}%")
|
||||
print(f"Requests per second: {self.success_count / total_time:.2f}")
|
||||
print("BENCHMARK RESULTS")
|
||||
|
||||
print(f"\nResponse Time Statistics:")
|
||||
print("\nResponse Time Statistics:")
|
||||
print(f" Mean: {statistics.mean(self.response_times):.3f}s")
|
||||
print(f" Median: {statistics.median(self.response_times):.3f}s")
|
||||
print(f" Min: {min(self.response_times):.3f}s")
|
||||
|
@ -78,14 +70,14 @@ class BenchmarkStats:
|
|||
|
||||
percentiles = [50, 90, 95, 99]
|
||||
sorted_times = sorted(self.response_times)
|
||||
print(f"\nPercentiles:")
|
||||
print("\nPercentiles:")
|
||||
for p in percentiles:
|
||||
idx = int(len(sorted_times) * p / 100) - 1
|
||||
idx = max(0, min(idx, len(sorted_times) - 1))
|
||||
print(f" P{p}: {sorted_times[idx]:.3f}s")
|
||||
|
||||
if self.ttft_times:
|
||||
print(f"\nTime to First Token (TTFT) Statistics:")
|
||||
print("\nTime to First Token (TTFT) Statistics:")
|
||||
print(f" Mean: {statistics.mean(self.ttft_times):.3f}s")
|
||||
print(f" Median: {statistics.median(self.ttft_times):.3f}s")
|
||||
print(f" Min: {min(self.ttft_times):.3f}s")
|
||||
|
@ -95,26 +87,35 @@ class BenchmarkStats:
|
|||
print(f" Std Dev: {statistics.stdev(self.ttft_times):.3f}s")
|
||||
|
||||
sorted_ttft = sorted(self.ttft_times)
|
||||
print(f"\nTTFT Percentiles:")
|
||||
print("\nTTFT Percentiles:")
|
||||
for p in percentiles:
|
||||
idx = int(len(sorted_ttft) * p / 100) - 1
|
||||
idx = max(0, min(idx, len(sorted_ttft) - 1))
|
||||
print(f" P{p}: {sorted_ttft[idx]:.3f}s")
|
||||
|
||||
if self.chunks_received:
|
||||
print(f"\nStreaming Statistics:")
|
||||
print("\nStreaming Statistics:")
|
||||
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
|
||||
print(f" Total chunks received: {sum(self.chunks_received)}")
|
||||
|
||||
print(f"{'=' * 60}")
|
||||
print(f"Total time: {total_time:.2f}s")
|
||||
print(f"Concurrent users: {self.concurrent_users}")
|
||||
print(f"Total requests: {self.total_requests}")
|
||||
print(f"Successful requests: {self.success_count}")
|
||||
print(f"Failed requests: {len(self.errors)}")
|
||||
print(f"Success rate: {success_rate:.1f}%")
|
||||
print(f"Requests per second: {self.success_count / total_time:.2f}")
|
||||
|
||||
if self.errors:
|
||||
print(f"\nErrors (showing first 5):")
|
||||
print("\nErrors (showing first 5):")
|
||||
for error in self.errors[:5]:
|
||||
print(f" {error}")
|
||||
|
||||
|
||||
class LlamaStackBenchmark:
|
||||
def __init__(self, base_url: str, model_id: str):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model_id = model_id
|
||||
self.headers = {"Content-Type": "application/json"}
|
||||
self.test_messages = [
|
||||
|
@ -125,20 +126,14 @@ class LlamaStackBenchmark:
|
|||
[
|
||||
{"role": "user", "content": "What is machine learning?"},
|
||||
{"role": "assistant", "content": "Machine learning is a subset of AI..."},
|
||||
{"role": "user", "content": "Can you give me a practical example?"}
|
||||
]
|
||||
{"role": "user", "content": "Can you give me a practical example?"},
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
async def make_async_streaming_request(self) -> Tuple[float, int, float | None, str | None]:
|
||||
async def make_async_streaming_request(self) -> tuple[float, int, float | None, str | None]:
|
||||
"""Make a single async streaming chat completion request."""
|
||||
messages = random.choice(self.test_messages)
|
||||
payload = {
|
||||
"model": self.model_id,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"max_tokens": 100
|
||||
}
|
||||
payload = {"model": self.model_id, "messages": messages, "stream": True, "max_tokens": 100}
|
||||
|
||||
start_time = time.time()
|
||||
chunks_received = 0
|
||||
|
@ -152,17 +147,17 @@ class LlamaStackBenchmark:
|
|||
f"{self.base_url}/chat/completions",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
async for line in response.content:
|
||||
if line:
|
||||
line_str = line.decode('utf-8').strip()
|
||||
if line_str.startswith('data: '):
|
||||
line_str = line.decode("utf-8").strip()
|
||||
if line_str.startswith("data: "):
|
||||
chunks_received += 1
|
||||
if ttft is None:
|
||||
ttft = time.time() - start_time
|
||||
if line_str == 'data: [DONE]':
|
||||
if line_str == "data: [DONE]":
|
||||
break
|
||||
|
||||
if chunks_received == 0:
|
||||
|
@ -179,7 +174,6 @@ class LlamaStackBenchmark:
|
|||
response_time = time.time() - start_time
|
||||
return response_time, chunks_received, ttft, error
|
||||
|
||||
|
||||
async def run_benchmark(self, duration: int, concurrent_users: int) -> BenchmarkStats:
|
||||
"""Run benchmark using async requests for specified duration."""
|
||||
stats = BenchmarkStats()
|
||||
|
@ -191,7 +185,7 @@ class LlamaStackBenchmark:
|
|||
print(f"Model: {self.model_id}")
|
||||
|
||||
connector = aiohttp.TCPConnector(limit=concurrent_users)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with aiohttp.ClientSession(connector=connector):
|
||||
|
||||
async def worker(worker_id: int):
|
||||
"""Worker that sends requests sequentially until canceled."""
|
||||
|
@ -215,7 +209,9 @@ class LlamaStackBenchmark:
|
|||
await asyncio.sleep(1) # Report every second
|
||||
if time.time() >= last_report_time + 10: # Report every 10 seconds
|
||||
elapsed = time.time() - stats.start_time
|
||||
print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
|
||||
print(
|
||||
f"Completed: {stats.total_requests} requests in {elapsed:.1f}s, RPS: {stats.total_requests / elapsed:.1f}"
|
||||
)
|
||||
last_report_time = time.time()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
@ -240,14 +236,16 @@ class LlamaStackBenchmark:
|
|||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Llama Stack Benchmark Tool")
|
||||
parser.add_argument("--base-url", default=os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1"),
|
||||
help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)")
|
||||
parser.add_argument("--model", default=os.getenv("INFERENCE_MODEL", "test-model"),
|
||||
help="Model ID to use for requests")
|
||||
parser.add_argument("--duration", type=int, default=60,
|
||||
help="Duration in seconds to run benchmark (default: 60)")
|
||||
parser.add_argument("--concurrent", type=int, default=10,
|
||||
help="Number of concurrent users (default: 10)")
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default=os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1"),
|
||||
help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", default=os.getenv("INFERENCE_MODEL", "test-model"), help="Model ID to use for requests"
|
||||
)
|
||||
parser.add_argument("--duration", type=int, default=60, help="Duration in seconds to run benchmark (default: 60)")
|
||||
parser.add_argument("--concurrent", type=int, default=10, help="Number of concurrent users (default: 10)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
@ -11,16 +11,18 @@ OpenAI-compatible mock server that returns:
|
|||
- Valid OpenAI-formatted chat completion responses with dynamic content
|
||||
"""
|
||||
|
||||
from flask import Flask, request, jsonify, Response
|
||||
import time
|
||||
import random
|
||||
import uuid
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from flask import Flask, Response, jsonify, request
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
# Models from environment variables
|
||||
def get_models():
|
||||
models_str = os.getenv("MOCK_MODELS", "meta-llama/Llama-3.2-3B-Instruct")
|
||||
|
@ -29,40 +31,72 @@ def get_models():
|
|||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": 1234567890,
|
||||
"owned_by": "vllm"
|
||||
}
|
||||
for model_id in model_ids
|
||||
]
|
||||
{"id": model_id, "object": "model", "created": 1234567890, "owned_by": "vllm"} for model_id in model_ids
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def generate_random_text(length=50):
|
||||
"""Generate random but coherent text for responses."""
|
||||
words = [
|
||||
"Hello", "there", "I'm", "an", "AI", "assistant", "ready", "to", "help", "you",
|
||||
"with", "your", "questions", "and", "tasks", "today", "Let", "me","know", "what",
|
||||
"you'd", "like", "to", "discuss", "or", "explore", "together", "I", "can", "assist",
|
||||
"with", "various", "topics", "including", "coding", "writing", "analysis", "and", "more"
|
||||
"Hello",
|
||||
"there",
|
||||
"I'm",
|
||||
"an",
|
||||
"AI",
|
||||
"assistant",
|
||||
"ready",
|
||||
"to",
|
||||
"help",
|
||||
"you",
|
||||
"with",
|
||||
"your",
|
||||
"questions",
|
||||
"and",
|
||||
"tasks",
|
||||
"today",
|
||||
"Let",
|
||||
"me",
|
||||
"know",
|
||||
"what",
|
||||
"you'd",
|
||||
"like",
|
||||
"to",
|
||||
"discuss",
|
||||
"or",
|
||||
"explore",
|
||||
"together",
|
||||
"I",
|
||||
"can",
|
||||
"assist",
|
||||
"with",
|
||||
"various",
|
||||
"topics",
|
||||
"including",
|
||||
"coding",
|
||||
"writing",
|
||||
"analysis",
|
||||
"and",
|
||||
"more",
|
||||
]
|
||||
return " ".join(random.choices(words, k=length))
|
||||
|
||||
@app.route('/v1/models', methods=['GET'])
|
||||
|
||||
@app.route("/v1/models", methods=["GET"])
|
||||
def list_models():
|
||||
models = get_models()
|
||||
print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}")
|
||||
return jsonify(models)
|
||||
|
||||
@app.route('/v1/chat/completions', methods=['POST'])
|
||||
|
||||
@app.route("/v1/chat/completions", methods=["POST"])
|
||||
def chat_completions():
|
||||
"""Return OpenAI-formatted chat completion responses."""
|
||||
data = request.get_json()
|
||||
default_model = get_models()['data'][0]['id']
|
||||
model = data.get('model', default_model)
|
||||
messages = data.get('messages', [])
|
||||
stream = data.get('stream', False)
|
||||
default_model = get_models()["data"][0]["id"]
|
||||
model = data.get("model", default_model)
|
||||
messages = data.get("messages", [])
|
||||
stream = data.get("stream", False)
|
||||
|
||||
print(f"[MOCK] Chat completion request - model: {model}, stream: {stream}")
|
||||
|
||||
|
@ -71,11 +105,12 @@ def chat_completions():
|
|||
else:
|
||||
return handle_non_streaming_completion(model, messages)
|
||||
|
||||
|
||||
def handle_non_streaming_completion(model, messages):
|
||||
response_text = generate_random_text(random.randint(20, 80))
|
||||
|
||||
# Calculate realistic token counts
|
||||
prompt_tokens = sum(len(str(msg.get('content', '')).split()) for msg in messages)
|
||||
prompt_tokens = sum(len(str(msg.get("content", "")).split()) for msg in messages)
|
||||
completion_tokens = len(response_text.split())
|
||||
|
||||
response = {
|
||||
|
@ -83,25 +118,17 @@ def handle_non_streaming_completion(model, messages):
|
|||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response_text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"choices": [{"index": 0, "message": {"role": "assistant", "content": response_text}, "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens
|
||||
}
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
def handle_streaming_completion(model, messages):
|
||||
def generate_stream():
|
||||
# Generate response text
|
||||
|
@ -114,12 +141,7 @@ def handle_streaming_completion(model, messages):
|
|||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": ""}
|
||||
}
|
||||
]
|
||||
"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}}],
|
||||
}
|
||||
yield f"data: {json.dumps(initial_chunk)}\n\n"
|
||||
|
||||
|
@ -130,12 +152,7 @@ def handle_streaming_completion(model, messages):
|
|||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": f"{word} " if i < len(words) - 1 else word}
|
||||
}
|
||||
]
|
||||
"choices": [{"index": 0, "delta": {"content": f"{word} " if i < len(words) - 1 else word}}],
|
||||
}
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
# Configurable delay to simulate realistic streaming
|
||||
|
@ -148,35 +165,30 @@ def handle_streaming_completion(model, messages):
|
|||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": ""},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
"choices": [{"index": 0, "delta": {"content": ""}, "finish_reason": "stop"}],
|
||||
}
|
||||
yield f"data: {json.dumps(final_chunk)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return Response(
|
||||
generate_stream(),
|
||||
mimetype='text/event-stream',
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
'Access-Control-Allow-Origin': '*',
|
||||
}
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
},
|
||||
)
|
||||
|
||||
@app.route('/health', methods=['GET'])
|
||||
|
||||
@app.route("/health", methods=["GET"])
|
||||
def health():
|
||||
return jsonify({"status": "healthy", "type": "openai-mock"})
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='OpenAI-compatible mock server')
|
||||
parser.add_argument('--port', type=int, default=8081,
|
||||
help='Port to run the server on (default: 8081)')
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="OpenAI-compatible mock server")
|
||||
parser.add_argument("--port", type=int, default=8081, help="Port to run the server on (default: 8081)")
|
||||
args = parser.parse_args()
|
||||
|
||||
port = args.port
|
||||
|
@ -187,4 +199,4 @@ if __name__ == '__main__':
|
|||
print("- OpenAI-formatted chat/completion responses with dynamic content")
|
||||
print("- Streaming support with valid SSE format")
|
||||
print(f"- Listening on: http://0.0.0.0:{port}")
|
||||
app.run(host='0.0.0.0', port=port, debug=False)
|
||||
app.run(host="0.0.0.0", port=port, debug=False)
|
|
@ -6,6 +6,7 @@ data:
|
|||
apis:
|
||||
- agents
|
||||
- inference
|
||||
- files
|
||||
- safety
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
|
@ -19,13 +20,6 @@ data:
|
|||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: vllm-safety
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
|
@ -41,6 +35,14 @@ data:
|
|||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
@ -111,9 +113,6 @@ data:
|
|||
- model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: vllm-inference
|
||||
model_type: llm
|
||||
- model_id: ${env.SAFETY_MODEL}
|
||||
provider_id: vllm-safety
|
||||
model_type: llm
|
||||
shields:
|
||||
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
|
||||
vector_dbs: []
|
|
@ -2,7 +2,10 @@ version: '2'
|
|||
image_name: kubernetes-benchmark-demo
|
||||
apis:
|
||||
- agents
|
||||
- files
|
||||
- inference
|
||||
- files
|
||||
- safety
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
|
@ -18,6 +21,14 @@ providers:
|
|||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
vector_io:
|
||||
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
|
@ -30,6 +41,19 @@ providers:
|
|||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -95,6 +119,8 @@ models:
|
|||
- model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: vllm-inference
|
||||
model_type: llm
|
||||
shields:
|
||||
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
101
docs/_static/css/my_theme.css
vendored
101
docs/_static/css/my_theme.css
vendored
|
@ -1,5 +1,106 @@
|
|||
@import url("theme.css");
|
||||
|
||||
/* Horizontal Navigation Bar */
|
||||
.horizontal-nav {
|
||||
background-color: #ffffff;
|
||||
border-bottom: 1px solid #e5e5e5;
|
||||
padding: 0;
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
z-index: 1050;
|
||||
height: 50px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .horizontal-nav {
|
||||
background-color: #1a1a1a;
|
||||
border-bottom: 1px solid #333;
|
||||
}
|
||||
|
||||
.horizontal-nav .nav-container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 0 20px;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.horizontal-nav .nav-brand {
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .horizontal-nav .nav-brand {
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
.horizontal-nav .nav-links {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 30px;
|
||||
list-style: none;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.horizontal-nav .nav-links a {
|
||||
color: #666;
|
||||
text-decoration: none;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.horizontal-nav .nav-links a:hover,
|
||||
.horizontal-nav .nav-links a.active {
|
||||
color: #333;
|
||||
background-color: #f5f5f5;
|
||||
}
|
||||
|
||||
.horizontal-nav .nav-links a.active {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .horizontal-nav .nav-links a {
|
||||
color: #ccc;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .horizontal-nav .nav-links a:hover,
|
||||
[data-theme="dark"] .horizontal-nav .nav-links a.active {
|
||||
color: #fff;
|
||||
background-color: #333;
|
||||
}
|
||||
|
||||
.horizontal-nav .nav-links .github-link {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
.horizontal-nav .nav-links .github-icon {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
fill: currentColor;
|
||||
}
|
||||
|
||||
/* Adjust main content to account for fixed nav */
|
||||
.wy-nav-side {
|
||||
top: 50px;
|
||||
height: calc(100vh - 50px);
|
||||
}
|
||||
|
||||
.wy-nav-content-wrap {
|
||||
margin-top: 50px;
|
||||
}
|
||||
|
||||
.wy-nav-content {
|
||||
max-width: 90%;
|
||||
}
|
||||
|
|
44
docs/_static/js/horizontal_nav.js
vendored
Normal file
44
docs/_static/js/horizontal_nav.js
vendored
Normal file
|
@ -0,0 +1,44 @@
|
|||
// Horizontal Navigation Bar for Llama Stack Documentation
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
// Create the horizontal navigation HTML
|
||||
const navHTML = `
|
||||
<nav class="horizontal-nav">
|
||||
<div class="nav-container">
|
||||
<a href="/" class="nav-brand">Llama Stack</a>
|
||||
<ul class="nav-links">
|
||||
<li><a href="/">Docs</a></li>
|
||||
<li><a href="/references/api_reference/">API Reference</a></li>
|
||||
<li><a href="https://github.com/meta-llama/llama-stack" target="_blank" class="github-link">
|
||||
<svg class="github-icon" viewBox="0 0 16 16" aria-hidden="true">
|
||||
<path d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"/>
|
||||
</svg>
|
||||
GitHub
|
||||
</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
</nav>
|
||||
`;
|
||||
|
||||
// Insert the navigation at the beginning of the body
|
||||
document.body.insertAdjacentHTML('afterbegin', navHTML);
|
||||
|
||||
// Update navigation links based on current page
|
||||
updateActiveNav();
|
||||
});
|
||||
|
||||
function updateActiveNav() {
|
||||
const currentPath = window.location.pathname;
|
||||
const navLinks = document.querySelectorAll('.horizontal-nav .nav-links a');
|
||||
|
||||
navLinks.forEach(link => {
|
||||
// Remove any existing active classes
|
||||
link.classList.remove('active');
|
||||
|
||||
// Add active class based on current path
|
||||
if (currentPath === '/' && link.getAttribute('href') === '/') {
|
||||
link.classList.add('active');
|
||||
} else if (currentPath.includes('/references/api_reference/') && link.getAttribute('href').includes('api_reference')) {
|
||||
link.classList.add('active');
|
||||
}
|
||||
});
|
||||
}
|
483
docs/_static/llama-stack-spec.html
vendored
483
docs/_static/llama-stack-spec.html
vendored
|
@ -633,6 +633,80 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/v1/prompts": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A ListPromptsResponse containing all prompts.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ListPromptsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Prompts"
|
||||
],
|
||||
"description": "List all prompts.",
|
||||
"parameters": []
|
||||
},
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "The created Prompt resource.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Prompt"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Prompts"
|
||||
],
|
||||
"description": "Create a new prompt.",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/CreatePromptRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/agents/{agent_id}": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
@ -901,6 +975,143 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"/v1/prompts/{prompt_id}": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A Prompt resource.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Prompt"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Prompts"
|
||||
],
|
||||
"description": "Get a prompt by its identifier and optional version.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "prompt_id",
|
||||
"in": "path",
|
||||
"description": "The identifier of the prompt to get.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "version",
|
||||
"in": "query",
|
||||
"description": "The version of the prompt to get (defaults to latest).",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "The updated Prompt resource with incremented version.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Prompt"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Prompts"
|
||||
],
|
||||
"description": "Update an existing prompt (increments version).",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "prompt_id",
|
||||
"in": "path",
|
||||
"description": "The identifier of the prompt to update.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/UpdatePromptRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
},
|
||||
"delete": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Prompts"
|
||||
],
|
||||
"description": "Delete a prompt.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "prompt_id",
|
||||
"in": "path",
|
||||
"description": "The identifier of the prompt to delete.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/inference/embeddings": {
|
||||
"post": {
|
||||
"responses": {
|
||||
|
@ -2836,6 +3047,49 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"/v1/prompts/{prompt_id}/versions": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A ListPromptsResponse containing all versions of the prompt.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ListPromptsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Prompts"
|
||||
],
|
||||
"description": "List all versions of a specific prompt.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "prompt_id",
|
||||
"in": "path",
|
||||
"description": "The identifier of the prompt to list versions for.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/providers": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
@ -4129,7 +4383,7 @@
|
|||
"tags": [
|
||||
"Files"
|
||||
],
|
||||
"description": "Upload a file that can be used across various endpoints.\nThe file upload should be a multipart form request with:\n- file: The File object (not file name) to be uploaded.\n- purpose: The intended purpose of the uploaded file.",
|
||||
"description": "Upload a file that can be used across various endpoints.\nThe file upload should be a multipart form request with:\n- file: The File object (not file name) to be uploaded.\n- purpose: The intended purpose of the uploaded file.\n- expires_after: Optional form values describing expiration for the file. Expected expires_after[anchor] = \"created_at\", expires_after[seconds] = <int>. Seconds must be between 3600 and 2592000 (1 hour to 30 days).",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
|
@ -4143,11 +4397,33 @@
|
|||
},
|
||||
"purpose": {
|
||||
"$ref": "#/components/schemas/OpenAIFilePurpose"
|
||||
},
|
||||
"expires_after_anchor": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
},
|
||||
"expires_after_seconds": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"file",
|
||||
"purpose"
|
||||
"purpose",
|
||||
"expires_after_anchor",
|
||||
"expires_after_seconds"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
@ -4985,6 +5261,59 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/v1/prompts/{prompt_id}/set-default-version": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "The prompt with the specified version now set as default.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Prompt"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Prompts"
|
||||
],
|
||||
"description": "Set which version of a prompt should be the default in get_prompt (latest).",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "prompt_id",
|
||||
"in": "path",
|
||||
"description": "The identifier of the prompt.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SetDefaultVersionRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/post-training/supervised-fine-tune": {
|
||||
"post": {
|
||||
"responses": {
|
||||
|
@ -9648,6 +9977,65 @@
|
|||
],
|
||||
"title": "OpenAIResponseObjectStreamResponseWebSearchCallSearching"
|
||||
},
|
||||
"CreatePromptRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The prompt text content with variable placeholders."
|
||||
},
|
||||
"variables": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of variable names that can be used in the prompt template."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt"
|
||||
],
|
||||
"title": "CreatePromptRequest"
|
||||
},
|
||||
"Prompt": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The system prompt text with variable placeholders. Variables are only supported when using the Responses API."
|
||||
},
|
||||
"version": {
|
||||
"type": "integer",
|
||||
"description": "Version (integer starting at 1, incremented on save)"
|
||||
},
|
||||
"prompt_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier formatted as 'pmpt_<48-digit-hash>'"
|
||||
},
|
||||
"variables": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of prompt variable names that can be used in the prompt template"
|
||||
},
|
||||
"is_default": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Boolean indicating whether this version is the default version for this prompt"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"version",
|
||||
"prompt_id",
|
||||
"variables",
|
||||
"is_default"
|
||||
],
|
||||
"title": "Prompt",
|
||||
"description": "A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack."
|
||||
},
|
||||
"OpenAIDeleteResponseObject": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -10274,7 +10662,8 @@
|
|||
"scoring_function",
|
||||
"benchmark",
|
||||
"tool",
|
||||
"tool_group"
|
||||
"tool_group",
|
||||
"prompt"
|
||||
],
|
||||
"const": "benchmark",
|
||||
"default": "benchmark",
|
||||
|
@ -10901,7 +11290,8 @@
|
|||
"scoring_function",
|
||||
"benchmark",
|
||||
"tool",
|
||||
"tool_group"
|
||||
"tool_group",
|
||||
"prompt"
|
||||
],
|
||||
"const": "dataset",
|
||||
"default": "dataset",
|
||||
|
@ -11051,7 +11441,8 @@
|
|||
"scoring_function",
|
||||
"benchmark",
|
||||
"tool",
|
||||
"tool_group"
|
||||
"tool_group",
|
||||
"prompt"
|
||||
],
|
||||
"const": "model",
|
||||
"default": "model",
|
||||
|
@ -11316,7 +11707,8 @@
|
|||
"scoring_function",
|
||||
"benchmark",
|
||||
"tool",
|
||||
"tool_group"
|
||||
"tool_group",
|
||||
"prompt"
|
||||
],
|
||||
"const": "scoring_function",
|
||||
"default": "scoring_function",
|
||||
|
@ -11424,7 +11816,8 @@
|
|||
"scoring_function",
|
||||
"benchmark",
|
||||
"tool",
|
||||
"tool_group"
|
||||
"tool_group",
|
||||
"prompt"
|
||||
],
|
||||
"const": "shield",
|
||||
"default": "shield",
|
||||
|
@ -11669,7 +12062,8 @@
|
|||
"scoring_function",
|
||||
"benchmark",
|
||||
"tool",
|
||||
"tool_group"
|
||||
"tool_group",
|
||||
"prompt"
|
||||
],
|
||||
"const": "tool",
|
||||
"default": "tool",
|
||||
|
@ -11751,7 +12145,8 @@
|
|||
"scoring_function",
|
||||
"benchmark",
|
||||
"tool",
|
||||
"tool_group"
|
||||
"tool_group",
|
||||
"prompt"
|
||||
],
|
||||
"const": "tool_group",
|
||||
"default": "tool_group",
|
||||
|
@ -12045,7 +12440,8 @@
|
|||
"scoring_function",
|
||||
"benchmark",
|
||||
"tool",
|
||||
"tool_group"
|
||||
"tool_group",
|
||||
"prompt"
|
||||
],
|
||||
"const": "vector_db",
|
||||
"default": "vector_db",
|
||||
|
@ -12860,6 +13256,23 @@
|
|||
"title": "OpenAIResponseObjectWithInput",
|
||||
"description": "OpenAI response object extended with input context information."
|
||||
},
|
||||
"ListPromptsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Prompt"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"data"
|
||||
],
|
||||
"title": "ListPromptsResponse",
|
||||
"description": "Response model to list prompts."
|
||||
},
|
||||
"ListProvidersResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -17106,6 +17519,20 @@
|
|||
"title": "ScoreBatchResponse",
|
||||
"description": "Response from batch scoring operations on datasets."
|
||||
},
|
||||
"SetDefaultVersionRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"version": {
|
||||
"type": "integer",
|
||||
"description": "The version to set as default."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"version"
|
||||
],
|
||||
"title": "SetDefaultVersionRequest"
|
||||
},
|
||||
"AlgorithmConfig": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
@ -17390,6 +17817,37 @@
|
|||
"title": "SyntheticDataGenerationResponse",
|
||||
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
|
||||
},
|
||||
"UpdatePromptRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The updated prompt text content."
|
||||
},
|
||||
"version": {
|
||||
"type": "integer",
|
||||
"description": "The current version of the prompt being updated."
|
||||
},
|
||||
"variables": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Updated list of variable names that can be used in the prompt template."
|
||||
},
|
||||
"set_as_default": {
|
||||
"type": "boolean",
|
||||
"description": "Set the new version as the default (default=True)."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt",
|
||||
"version",
|
||||
"set_as_default"
|
||||
],
|
||||
"title": "UpdatePromptRequest"
|
||||
},
|
||||
"VersionInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -17515,6 +17973,10 @@
|
|||
{
|
||||
"name": "PostTraining (Coming Soon)"
|
||||
},
|
||||
{
|
||||
"name": "Prompts",
|
||||
"x-displayName": "Protocol for prompt management operations."
|
||||
},
|
||||
{
|
||||
"name": "Providers",
|
||||
"x-displayName": "Providers API for inspecting, listing, and modifying providers and their configurations."
|
||||
|
@ -17565,6 +18027,7 @@
|
|||
"Inspect",
|
||||
"Models",
|
||||
"PostTraining (Coming Soon)",
|
||||
"Prompts",
|
||||
"Providers",
|
||||
"Safety",
|
||||
"Scoring",
|
||||
|
|
346
docs/_static/llama-stack-spec.yaml
vendored
346
docs/_static/llama-stack-spec.yaml
vendored
|
@ -427,6 +427,58 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/CreateOpenaiResponseRequest'
|
||||
required: true
|
||||
/v1/prompts:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A ListPromptsResponse containing all prompts.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListPromptsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Prompts
|
||||
description: List all prompts.
|
||||
parameters: []
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: The created Prompt resource.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Prompt'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Prompts
|
||||
description: Create a new prompt.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CreatePromptRequest'
|
||||
required: true
|
||||
/v1/agents/{agent_id}:
|
||||
get:
|
||||
responses:
|
||||
|
@ -616,6 +668,103 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/prompts/{prompt_id}:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A Prompt resource.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Prompt'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Prompts
|
||||
description: >-
|
||||
Get a prompt by its identifier and optional version.
|
||||
parameters:
|
||||
- name: prompt_id
|
||||
in: path
|
||||
description: The identifier of the prompt to get.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: version
|
||||
in: query
|
||||
description: >-
|
||||
The version of the prompt to get (defaults to latest).
|
||||
required: false
|
||||
schema:
|
||||
type: integer
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
The updated Prompt resource with incremented version.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Prompt'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Prompts
|
||||
description: >-
|
||||
Update an existing prompt (increments version).
|
||||
parameters:
|
||||
- name: prompt_id
|
||||
in: path
|
||||
description: The identifier of the prompt to update.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UpdatePromptRequest'
|
||||
required: true
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Prompts
|
||||
description: Delete a prompt.
|
||||
parameters:
|
||||
- name: prompt_id
|
||||
in: path
|
||||
description: The identifier of the prompt to delete.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/inference/embeddings:
|
||||
post:
|
||||
responses:
|
||||
|
@ -1983,6 +2132,37 @@ paths:
|
|||
required: false
|
||||
schema:
|
||||
$ref: '#/components/schemas/Order'
|
||||
/v1/prompts/{prompt_id}/versions:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A ListPromptsResponse containing all versions of the prompt.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListPromptsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Prompts
|
||||
description: List all versions of a specific prompt.
|
||||
parameters:
|
||||
- name: prompt_id
|
||||
in: path
|
||||
description: >-
|
||||
The identifier of the prompt to list versions for.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/providers:
|
||||
get:
|
||||
responses:
|
||||
|
@ -2933,6 +3113,10 @@ paths:
|
|||
- file: The File object (not file name) to be uploaded.
|
||||
|
||||
- purpose: The intended purpose of the uploaded file.
|
||||
|
||||
- expires_after: Optional form values describing expiration for the file.
|
||||
Expected expires_after[anchor] = "created_at", expires_after[seconds] = <int>.
|
||||
Seconds must be between 3600 and 2592000 (1 hour to 30 days).
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
|
@ -2945,9 +3129,19 @@ paths:
|
|||
format: binary
|
||||
purpose:
|
||||
$ref: '#/components/schemas/OpenAIFilePurpose'
|
||||
expires_after_anchor:
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
expires_after_seconds:
|
||||
oneOf:
|
||||
- type: integer
|
||||
- type: 'null'
|
||||
required:
|
||||
- file
|
||||
- purpose
|
||||
- expires_after_anchor
|
||||
- expires_after_seconds
|
||||
required: true
|
||||
/v1/openai/v1/models:
|
||||
get:
|
||||
|
@ -3532,6 +3726,43 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/ScoreBatchRequest'
|
||||
required: true
|
||||
/v1/prompts/{prompt_id}/set-default-version:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
The prompt with the specified version now set as default.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Prompt'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Prompts
|
||||
description: >-
|
||||
Set which version of a prompt should be the default in get_prompt (latest).
|
||||
parameters:
|
||||
- name: prompt_id
|
||||
in: path
|
||||
description: The identifier of the prompt.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SetDefaultVersionRequest'
|
||||
required: true
|
||||
/v1/post-training/supervised-fine-tune:
|
||||
post:
|
||||
responses:
|
||||
|
@ -7134,6 +7365,61 @@ components:
|
|||
- type
|
||||
title: >-
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallSearching
|
||||
CreatePromptRequest:
|
||||
type: object
|
||||
properties:
|
||||
prompt:
|
||||
type: string
|
||||
description: >-
|
||||
The prompt text content with variable placeholders.
|
||||
variables:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
List of variable names that can be used in the prompt template.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- prompt
|
||||
title: CreatePromptRequest
|
||||
Prompt:
|
||||
type: object
|
||||
properties:
|
||||
prompt:
|
||||
type: string
|
||||
description: >-
|
||||
The system prompt text with variable placeholders. Variables are only
|
||||
supported when using the Responses API.
|
||||
version:
|
||||
type: integer
|
||||
description: >-
|
||||
Version (integer starting at 1, incremented on save)
|
||||
prompt_id:
|
||||
type: string
|
||||
description: >-
|
||||
Unique identifier formatted as 'pmpt_<48-digit-hash>'
|
||||
variables:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
List of prompt variable names that can be used in the prompt template
|
||||
is_default:
|
||||
type: boolean
|
||||
default: false
|
||||
description: >-
|
||||
Boolean indicating whether this version is the default version for this
|
||||
prompt
|
||||
additionalProperties: false
|
||||
required:
|
||||
- version
|
||||
- prompt_id
|
||||
- variables
|
||||
- is_default
|
||||
title: Prompt
|
||||
description: >-
|
||||
A prompt resource representing a stored OpenAI Compatible prompt template
|
||||
in Llama Stack.
|
||||
OpenAIDeleteResponseObject:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -7607,6 +7893,7 @@ components:
|
|||
- benchmark
|
||||
- tool
|
||||
- tool_group
|
||||
- prompt
|
||||
const: benchmark
|
||||
default: benchmark
|
||||
description: The resource type, always benchmark
|
||||
|
@ -8093,6 +8380,7 @@ components:
|
|||
- benchmark
|
||||
- tool
|
||||
- tool_group
|
||||
- prompt
|
||||
const: dataset
|
||||
default: dataset
|
||||
description: >-
|
||||
|
@ -8205,6 +8493,7 @@ components:
|
|||
- benchmark
|
||||
- tool
|
||||
- tool_group
|
||||
- prompt
|
||||
const: model
|
||||
default: model
|
||||
description: >-
|
||||
|
@ -8396,6 +8685,7 @@ components:
|
|||
- benchmark
|
||||
- tool
|
||||
- tool_group
|
||||
- prompt
|
||||
const: scoring_function
|
||||
default: scoring_function
|
||||
description: >-
|
||||
|
@ -8472,6 +8762,7 @@ components:
|
|||
- benchmark
|
||||
- tool
|
||||
- tool_group
|
||||
- prompt
|
||||
const: shield
|
||||
default: shield
|
||||
description: The resource type, always shield
|
||||
|
@ -8651,6 +8942,7 @@ components:
|
|||
- benchmark
|
||||
- tool
|
||||
- tool_group
|
||||
- prompt
|
||||
const: tool
|
||||
default: tool
|
||||
description: Type of resource, always 'tool'
|
||||
|
@ -8709,6 +9001,7 @@ components:
|
|||
- benchmark
|
||||
- tool
|
||||
- tool_group
|
||||
- prompt
|
||||
const: tool_group
|
||||
default: tool_group
|
||||
description: Type of resource, always 'tool_group'
|
||||
|
@ -8937,6 +9230,7 @@ components:
|
|||
- benchmark
|
||||
- tool
|
||||
- tool_group
|
||||
- prompt
|
||||
const: vector_db
|
||||
default: vector_db
|
||||
description: >-
|
||||
|
@ -9563,6 +9857,18 @@ components:
|
|||
title: OpenAIResponseObjectWithInput
|
||||
description: >-
|
||||
OpenAI response object extended with input context information.
|
||||
ListPromptsResponse:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Prompt'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
title: ListPromptsResponse
|
||||
description: Response model to list prompts.
|
||||
ListProvidersResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -12708,6 +13014,16 @@ components:
|
|||
title: ScoreBatchResponse
|
||||
description: >-
|
||||
Response from batch scoring operations on datasets.
|
||||
SetDefaultVersionRequest:
|
||||
type: object
|
||||
properties:
|
||||
version:
|
||||
type: integer
|
||||
description: The version to set as default.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- version
|
||||
title: SetDefaultVersionRequest
|
||||
AlgorithmConfig:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/LoraFinetuningConfig'
|
||||
|
@ -12904,6 +13220,32 @@ components:
|
|||
description: >-
|
||||
Response from the synthetic data generation. Batch of (prompt, response, score)
|
||||
tuples that pass the threshold.
|
||||
UpdatePromptRequest:
|
||||
type: object
|
||||
properties:
|
||||
prompt:
|
||||
type: string
|
||||
description: The updated prompt text content.
|
||||
version:
|
||||
type: integer
|
||||
description: >-
|
||||
The current version of the prompt being updated.
|
||||
variables:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
Updated list of variable names that can be used in the prompt template.
|
||||
set_as_default:
|
||||
type: boolean
|
||||
description: >-
|
||||
Set the new version as the default (default=True).
|
||||
additionalProperties: false
|
||||
required:
|
||||
- prompt
|
||||
- version
|
||||
- set_as_default
|
||||
title: UpdatePromptRequest
|
||||
VersionInfo:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -13015,6 +13357,9 @@ tags:
|
|||
- name: Inspect
|
||||
- name: Models
|
||||
- name: PostTraining (Coming Soon)
|
||||
- name: Prompts
|
||||
x-displayName: >-
|
||||
Protocol for prompt management operations.
|
||||
- name: Providers
|
||||
x-displayName: >-
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
|
@ -13042,6 +13387,7 @@ x-tagGroups:
|
|||
- Inspect
|
||||
- Models
|
||||
- PostTraining (Coming Soon)
|
||||
- Prompts
|
||||
- Providers
|
||||
- Safety
|
||||
- Scoring
|
||||
|
|
|
@ -33,7 +33,7 @@ The list of open-benchmarks we currently support:
|
|||
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI)]: Benchmark designed to evaluate multimodal models.
|
||||
|
||||
|
||||
You can follow this [contributing guide](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack
|
||||
You can follow this [contributing guide](../references/evals_reference/index.md#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack
|
||||
|
||||
#### Run evaluation on open-benchmarks via CLI
|
||||
|
||||
|
|
|
@ -35,3 +35,6 @@ device: cpu
|
|||
|
||||
```
|
||||
|
||||
[Find more detailed information here!](huggingface.md)
|
||||
|
||||
|
||||
|
|
|
@ -22,3 +22,4 @@ checkpoint_format: meta
|
|||
|
||||
```
|
||||
|
||||
[Find more detailed information here!](torchtune.md)
|
||||
|
|
|
@ -88,7 +88,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
|
|||
- **API Resources**: Inspect Llama Stack API resources
|
||||
- This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `benchmarks`, `shields`).
|
||||
- Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resources.
|
||||
- Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources.
|
||||
- Please visit [Core Concepts](../../concepts/index.md) for more details about the resources.
|
||||
|
||||
### Starting the Llama Stack Playground
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics.
|
||||
|
||||
```{note}
|
||||
For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API.
|
||||
**Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](../providers/openai.md#chat-completions) directly, before progressing to Agents or Responses API.
|
||||
```
|
||||
|
||||
## Overview
|
||||
|
@ -173,7 +173,7 @@ Both APIs demonstrate distinct strengths that make them valuable on their own fo
|
|||
|
||||
## For More Information
|
||||
|
||||
- **LLS Agents API**: For detailed information on creating and managing agents, see the [Agents documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html)
|
||||
- **LLS Agents API**: For detailed information on creating and managing agents, see the [Agents documentation](agent.md)
|
||||
- **OpenAI Responses API**: For information on using the OpenAI-compatible responses API, see the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/responses)
|
||||
- **Chat Completions API**: For the default backend API used by Agents, see the [Chat Completions providers documentation](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions)
|
||||
- **Agent Execution Loop**: For understanding how agents process turns and steps in their execution, see the [Agent Execution Loop documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent_execution_loop.html)
|
||||
- **Chat Completions API**: For the default backend API used by Agents, see the [Chat Completions providers documentation](../providers/openai.md#chat-completions)
|
||||
- **Agent Execution Loop**: For understanding how agents process turns and steps in their execution, see the [Agent Execution Loop documentation](agent_execution_loop.md)
|
||||
|
|
|
@ -6,4 +6,4 @@ While there is a lot of flexibility to mix-and-match providers, often users will
|
|||
|
||||
**Locally Hosted Distro**: You may want to run Llama Stack on your own hardware. Typically though, you still need to use Inference via an external service. You can use providers like HuggingFace TGI, Fireworks, Together, etc. for this purpose. Or you may have access to GPUs and can run a [vLLM](https://github.com/vllm-project/vllm) or [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) instance. If you "just" have a regular desktop machine, you can use [Ollama](https://ollama.com/) for inference. To provide convenient quick access to these options, we provide a number of such pre-configured locally-hosted Distros.
|
||||
|
||||
**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/ios_sdk.html) and [Android](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/android_sdk.html)
|
||||
**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](../distributions/ondevice_distro/ios_sdk.md) and [Android](../distributions/ondevice_distro/android_sdk.md)
|
||||
|
|
|
@ -131,6 +131,7 @@ html_static_path = ["../_static"]
|
|||
def setup(app):
|
||||
app.add_css_file("css/my_theme.css")
|
||||
app.add_js_file("js/detect_theme.js")
|
||||
app.add_js_file("js/horizontal_nav.js")
|
||||
app.add_js_file("js/keyboard_shortcuts.js")
|
||||
|
||||
def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
|
||||
|
|
|
@ -35,5 +35,5 @@ testing/record-replay
|
|||
|
||||
### Benchmarking
|
||||
|
||||
```{include} ../../../docs/source/distributions/k8s-benchmark/README.md
|
||||
```{include} ../../../benchmarking/k8s-benchmark/README.md
|
||||
```
|
||||
|
|
|
@ -14,6 +14,13 @@ Here are some example PRs to help you get started:
|
|||
- [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355)
|
||||
- [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665)
|
||||
|
||||
## Guidelines for creating Internal or External Providers
|
||||
|
||||
|**Type** |Internal (In-tree) |External (out-of-tree)
|
||||
|---------|-------------------|---------------------|
|
||||
|**Description** |A provider that is directly in the Llama Stack code|A provider that is outside of the Llama stack core codebase but is still accessible and usable by Llama Stack.
|
||||
|**Benefits** |Ability to interact with the provider with minimal additional configurations or installations| Contributors do not have to add directly to the code to create providers accessible on Llama Stack. Keep provider-specific code separate from the core Llama Stack code.
|
||||
|
||||
## Inference Provider Patterns
|
||||
|
||||
When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers.
|
||||
|
|
|
@ -40,18 +40,15 @@ The system patches OpenAI and Ollama client methods to intercept calls before th
|
|||
|
||||
### Storage Architecture
|
||||
|
||||
Recordings use a two-tier storage system optimized for both speed and debuggability:
|
||||
Recordings are stored as JSON files in the recording directory. They are looked up by their request hash.
|
||||
|
||||
```
|
||||
recordings/
|
||||
├── index.sqlite # Fast lookup by request hash
|
||||
└── responses/
|
||||
├── abc123def456.json # Individual response files
|
||||
└── def789ghi012.json
|
||||
```
|
||||
|
||||
**SQLite index** enables O(log n) hash lookups and metadata queries without loading response bodies.
|
||||
|
||||
**JSON files** store complete request/response pairs in human-readable format for debugging.
|
||||
|
||||
## Recording Modes
|
||||
|
@ -166,8 +163,8 @@ This preserves type safety - when replayed, you get the same Pydantic objects wi
|
|||
Control recording behavior globally:
|
||||
|
||||
```bash
|
||||
export LLAMA_STACK_TEST_INFERENCE_MODE=replay
|
||||
export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings
|
||||
export LLAMA_STACK_TEST_INFERENCE_MODE=replay # this is the default
|
||||
export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings # default is tests/integration/recordings
|
||||
pytest tests/integration/
|
||||
```
|
||||
|
||||
|
|
|
@ -354,6 +354,47 @@ You can easily validate a request by running:
|
|||
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||
```
|
||||
|
||||
#### Kubernetes Authentication Provider
|
||||
|
||||
The server can be configured to use Kubernetes SelfSubjectReview API to validate tokens directly against the Kubernetes API server:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_config:
|
||||
type: "kubernetes"
|
||||
api_server_url: "https://kubernetes.default.svc"
|
||||
claims_mapping:
|
||||
username: "roles"
|
||||
groups: "roles"
|
||||
uid: "uid_attr"
|
||||
verify_tls: true
|
||||
tls_cafile: "/path/to/ca.crt"
|
||||
```
|
||||
|
||||
Configuration options:
|
||||
- `api_server_url`: The Kubernetes API server URL (e.g., https://kubernetes.default.svc:6443)
|
||||
- `verify_tls`: Whether to verify TLS certificates (default: true)
|
||||
- `tls_cafile`: Path to CA certificate file for TLS verification
|
||||
- `claims_mapping`: Mapping of Kubernetes user claims to access attributes
|
||||
|
||||
The provider validates tokens by sending a SelfSubjectReview request to the Kubernetes API server at `/apis/authentication.k8s.io/v1/selfsubjectreviews`. The provider extracts user information from the response:
|
||||
- Username from the `userInfo.username` field
|
||||
- Groups from the `userInfo.groups` field
|
||||
- UID from the `userInfo.uid` field
|
||||
|
||||
To obtain a token for testing:
|
||||
```bash
|
||||
kubectl create namespace llama-stack
|
||||
kubectl create serviceaccount llama-stack-auth -n llama-stack
|
||||
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
|
||||
```
|
||||
|
||||
You can validate a request by running:
|
||||
```bash
|
||||
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||
```
|
||||
|
||||
#### GitHub Token Provider
|
||||
Validates GitHub personal access tokens or OAuth tokens directly:
|
||||
```yaml
|
||||
|
|
|
@ -27,7 +27,7 @@ Then, you can access the APIs like `models` and `inference` on the client and ca
|
|||
response = client.models.list()
|
||||
```
|
||||
|
||||
If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html), you can also use the run.yaml configuration file directly:
|
||||
If you've created a [custom distribution](building_distro.md), you can also use the run.yaml configuration file directly:
|
||||
|
||||
```python
|
||||
client = LlamaStackAsLibraryClient(config_path)
|
||||
|
|
|
@ -22,17 +22,17 @@ else
|
|||
fi
|
||||
|
||||
if [ -z "${GITHUB_CLIENT_ID:-}" ]; then
|
||||
echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
|
||||
echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "${GITHUB_CLIENT_SECRET:-}" ]; then
|
||||
echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
|
||||
echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "${LLAMA_STACK_UI_URL:-}" ]; then
|
||||
echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
|
||||
echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
@ -1,137 +1,55 @@
|
|||
apiVersion: v1
|
||||
data:
|
||||
stack_run_config.yaml: |
|
||||
version: '2'
|
||||
image_name: kubernetes-demo
|
||||
apis:
|
||||
- agents
|
||||
- inference
|
||||
- safety
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: vllm-inference
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=http://localhost:8000/v1}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: vllm-safety
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
vector_io:
|
||||
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
responses_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
metadata_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
table_name: llamastack_kvstore
|
||||
inference_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
models:
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: sentence-transformers
|
||||
model_type: embedding
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: vllm-inference
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
|
||||
provider_id: vllm-safety
|
||||
model_type: llm
|
||||
shields:
|
||||
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
auth:
|
||||
provider_config:
|
||||
type: github_token
|
||||
stack_run_config.yaml: "version: '2'\nimage_name: kubernetes-demo\napis:\n- agents\n-
|
||||
inference\n- files\n- safety\n- telemetry\n- tool_runtime\n- vector_io\nproviders:\n
|
||||
\ inference:\n - provider_id: vllm-inference\n provider_type: remote::vllm\n
|
||||
\ config:\n url: ${env.VLLM_URL:=http://localhost:8000/v1}\n max_tokens:
|
||||
${env.VLLM_MAX_TOKENS:=4096}\n api_token: ${env.VLLM_API_TOKEN:=fake}\n tls_verify:
|
||||
${env.VLLM_TLS_VERIFY:=true}\n - provider_id: vllm-safety\n provider_type:
|
||||
remote::vllm\n config:\n url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}\n
|
||||
\ max_tokens: ${env.VLLM_MAX_TOKENS:=4096}\n api_token: ${env.VLLM_API_TOKEN:=fake}\n
|
||||
\ tls_verify: ${env.VLLM_TLS_VERIFY:=true}\n - provider_id: sentence-transformers\n
|
||||
\ provider_type: inline::sentence-transformers\n config: {}\n vector_io:\n
|
||||
\ - provider_id: ${env.ENABLE_CHROMADB:+chromadb}\n provider_type: remote::chromadb\n
|
||||
\ config:\n url: ${env.CHROMADB_URL:=}\n kvstore:\n type: postgres\n
|
||||
\ host: ${env.POSTGRES_HOST:=localhost}\n port: ${env.POSTGRES_PORT:=5432}\n
|
||||
\ db: ${env.POSTGRES_DB:=llamastack}\n user: ${env.POSTGRES_USER:=llamastack}\n
|
||||
\ password: ${env.POSTGRES_PASSWORD:=llamastack}\n files:\n - provider_id:
|
||||
meta-reference-files\n provider_type: inline::localfs\n config:\n storage_dir:
|
||||
${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}\n metadata_store:\n
|
||||
\ type: sqlite\n db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
\ \n safety:\n - provider_id: llama-guard\n provider_type: inline::llama-guard\n
|
||||
\ config:\n excluded_categories: []\n agents:\n - provider_id: meta-reference\n
|
||||
\ provider_type: inline::meta-reference\n config:\n persistence_store:\n
|
||||
\ type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n port:
|
||||
${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n user:
|
||||
${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
|
||||
\ responses_store:\n type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n
|
||||
\ port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n
|
||||
\ user: ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
|
||||
\ telemetry:\n - provider_id: meta-reference\n provider_type: inline::meta-reference\n
|
||||
\ config:\n service_name: \"${env.OTEL_SERVICE_NAME:=\\u200B}\"\n sinks:
|
||||
${env.TELEMETRY_SINKS:=console}\n tool_runtime:\n - provider_id: brave-search\n
|
||||
\ provider_type: remote::brave-search\n config:\n api_key: ${env.BRAVE_SEARCH_API_KEY:+}\n
|
||||
\ max_results: 3\n - provider_id: tavily-search\n provider_type: remote::tavily-search\n
|
||||
\ config:\n api_key: ${env.TAVILY_SEARCH_API_KEY:+}\n max_results:
|
||||
3\n - provider_id: rag-runtime\n provider_type: inline::rag-runtime\n config:
|
||||
{}\n - provider_id: model-context-protocol\n provider_type: remote::model-context-protocol\n
|
||||
\ config: {}\nmetadata_store:\n type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n
|
||||
\ port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n user:
|
||||
${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
|
||||
\ table_name: llamastack_kvstore\ninference_store:\n type: postgres\n host:
|
||||
${env.POSTGRES_HOST:=localhost}\n port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n
|
||||
\ user: ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\nmodels:\n-
|
||||
metadata:\n embedding_dimension: 384\n model_id: all-MiniLM-L6-v2\n provider_id:
|
||||
sentence-transformers\n model_type: embedding\n- metadata: {}\n model_id: ${env.INFERENCE_MODEL}\n
|
||||
\ provider_id: vllm-inference\n model_type: llm\n- metadata: {}\n model_id:
|
||||
${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\n provider_id: vllm-safety\n
|
||||
\ model_type: llm\nshields:\n- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\nvector_dbs:
|
||||
[]\ndatasets: []\nscoring_fns: []\nbenchmarks: []\ntool_groups:\n- toolgroup_id:
|
||||
builtin::websearch\n provider_id: tavily-search\n- toolgroup_id: builtin::rag\n
|
||||
\ provider_id: rag-runtime\nserver:\n port: 8321\n auth:\n provider_config:\n
|
||||
\ type: github_token\n"
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
creationTimestamp: null
|
||||
|
|
|
@ -3,6 +3,7 @@ image_name: kubernetes-demo
|
|||
apis:
|
||||
- agents
|
||||
- inference
|
||||
- files
|
||||
- safety
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
|
@ -38,6 +39,14 @@ providers:
|
|||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
|
|
@ -66,7 +66,7 @@ llama stack run starter --port 5050
|
|||
|
||||
Ensure the Llama Stack server version is the same as the Kotlin SDK Library for maximum compatibility.
|
||||
|
||||
Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations)
|
||||
Other inference providers: [Table](../../index.md#supported-llama-stack-implementations)
|
||||
|
||||
How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#settings)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
orphan: true
|
||||
---
|
||||
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
|
||||
# Meta Reference Distribution
|
||||
# Meta Reference GPU Distribution
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 2
|
||||
|
@ -41,7 +41,7 @@ The following environment variables can be configured:
|
|||
|
||||
## Prerequisite: Downloading Models
|
||||
|
||||
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||
|
||||
```
|
||||
$ llama model list --downloaded
|
||||
|
|
|
@ -50,6 +50,7 @@ The following models are available by default:
|
|||
- `meta/llama-3.2-11b-vision-instruct `
|
||||
- `meta/llama-3.2-90b-vision-instruct `
|
||||
- `meta/llama-3.3-70b-instruct `
|
||||
- `nvidia/vila `
|
||||
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
|
||||
- `nvidia/nv-embedqa-e5-v5 `
|
||||
- `nvidia/nv-embedqa-mistral-7b-v2 `
|
||||
|
|
|
@ -18,12 +18,13 @@ embedding_model_id = (
|
|||
).identifier
|
||||
embedding_dimension = em.metadata["embedding_dimension"]
|
||||
|
||||
_ = client.vector_dbs.register(
|
||||
vector_db = client.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id="faiss",
|
||||
)
|
||||
vector_db_id = vector_db.identifier
|
||||
source = "https://www.paulgraham.com/greatwork.html"
|
||||
print("rag_tool> Ingesting document:", source)
|
||||
document = RAGDocument(
|
||||
|
@ -35,7 +36,7 @@ document = RAGDocument(
|
|||
client.tool_runtime.rag_tool.insert(
|
||||
documents=[document],
|
||||
vector_db_id=vector_db_id,
|
||||
chunk_size_in_tokens=50,
|
||||
chunk_size_in_tokens=100,
|
||||
)
|
||||
agent = Agent(
|
||||
client,
|
||||
|
|
|
@ -8,3 +8,4 @@ Here's a list of known external providers that you can use with Llama Stack:
|
|||
| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Inline **and** Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
|
||||
| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
|
||||
| TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) |
|
||||
| MongoDB | VectorIO with MongoDB | Vector_IO | Remote | [mongodb-llama-stack](https://github.com/mongodb-partners/mongodb-llama-stack) |
|
||||
|
|
|
@ -15,8 +15,8 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
|
|||
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
||||
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
||||
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
|
||||
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
|
||||
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
|
||||
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
|
||||
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
|
||||
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
|
||||
|
||||
## Sample Configuration
|
||||
|
|
|
@ -9,7 +9,6 @@ This section contains documentation for all available providers for the **post_t
|
|||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
inline_huggingface-cpu
|
||||
inline_huggingface-gpu
|
||||
inline_torchtune-cpu
|
||||
inline_torchtune-gpu
|
||||
|
|
|
@ -15,8 +15,8 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
|
|||
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
||||
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
||||
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
|
||||
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
|
||||
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
|
||||
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
|
||||
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
|
||||
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
|
||||
|
||||
## Sample Configuration
|
||||
|
|
|
@ -12,6 +12,60 @@ That means you'll get fast and efficient vector retrieval.
|
|||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
There are three implementations of search for PGVectoIndex available:
|
||||
|
||||
1. Vector Search:
|
||||
- How it works:
|
||||
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
|
||||
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
|
||||
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
|
||||
|
||||
-Characteristics:
|
||||
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
|
||||
- Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions)
|
||||
- Best for: Finding conceptually related content, handling synonyms, cross-language search
|
||||
|
||||
2. Keyword Search
|
||||
- How it works:
|
||||
- Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank
|
||||
- Converts text to searchable tokens using to_tsvector('english', text). Default language is English.
|
||||
- Eg. SQL query: SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
|
||||
|
||||
- Characteristics:
|
||||
- Lexical matching - finds exact keyword matches and variations
|
||||
- Uses GIN (Generalized Inverted Index) for fast text search performance
|
||||
- Scoring: Uses PostgreSQL's ts_rank function for relevance scoring
|
||||
- Best for: Exact term matching, proper names, technical terms, Boolean-style queries
|
||||
|
||||
3. Hybrid Search
|
||||
- How it works:
|
||||
- Combines both vector and keyword search results
|
||||
- Runs both searches independently, then merges results using configurable reranking
|
||||
|
||||
- Two reranking strategies available:
|
||||
- Reciprocal Rank Fusion (RRF) - (default: 60.0)
|
||||
- Weighted Average - (default: 0.5)
|
||||
|
||||
- Characteristics:
|
||||
- Best of both worlds: semantic understanding + exact matching
|
||||
- Documents appearing in both searches get boosted scores
|
||||
- Configurable balance between semantic and lexical matching
|
||||
- Best for: General-purpose search where you want both precision and recall
|
||||
|
||||
4. Database Schema
|
||||
The PGVector implementation stores data optimized for all three search types:
|
||||
CREATE TABLE vector_store_xxx (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB, -- Original document
|
||||
embedding vector(dimension), -- For vector search
|
||||
content_text TEXT, -- Raw text content
|
||||
tokenized_content TSVECTOR -- For keyword search
|
||||
);
|
||||
|
||||
-- Indexes for performance
|
||||
CREATE INDEX content_gin_idx ON table USING GIN(tokenized_content); -- Keyword search
|
||||
-- Vector index created automatically by pgvector
|
||||
|
||||
## Usage
|
||||
|
||||
To use PGVector in your Llama Stack project, follow these steps:
|
||||
|
@ -20,6 +74,25 @@ To use PGVector in your Llama Stack project, follow these steps:
|
|||
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## This is an example how you can set up your environment for using PGVector
|
||||
|
||||
1. Export env vars:
|
||||
```bash
|
||||
export ENABLE_PGVECTOR=true
|
||||
export PGVECTOR_HOST=localhost
|
||||
export PGVECTOR_PORT=5432
|
||||
export PGVECTOR_DB=llamastack
|
||||
export PGVECTOR_USER=llamastack
|
||||
export PGVECTOR_PASSWORD=llamastack
|
||||
```
|
||||
|
||||
2. Create DB:
|
||||
```bash
|
||||
psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';"
|
||||
psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;"
|
||||
psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;"
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
You can install PGVector using docker:
|
||||
|
|
|
@ -17,6 +17,7 @@ Weaviate supports:
|
|||
- Metadata filtering
|
||||
- Multi-modal retrieval
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
To use Weaviate in your Llama Stack project, follow these steps:
|
||||
|
|
|
@ -202,7 +202,7 @@ pprint(response)
|
|||
|
||||
Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.
|
||||
|
||||
In this example, we will work with an example RAG dataset you have built previously, label with an annotation, and use LLM-As-Judge with custom judge prompt for scoring. Please checkout our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scorings.
|
||||
In this example, we will work with an example RAG dataset you have built previously, label with an annotation, and use LLM-As-Judge with custom judge prompt for scoring. Please checkout our [Llama Stack Playground](../../building_applications/playground/index.md) for an interactive interface to upload datasets and run scorings.
|
||||
|
||||
```python
|
||||
judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8"
|
||||
|
|
|
@ -478,7 +478,6 @@ llama-stack-client scoring_functions list
|
|||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
|
||||
┃ identifier ┃ provider_id ┃ description ┃ type ┃
|
||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
|
||||
│ basic::bfcl │ basic │ BFCL complex scoring │ scoring_function │
|
||||
│ basic::docvqa │ basic │ DocVQA Visual Question & Answer scoring function │ scoring_function │
|
||||
│ basic::equality │ basic │ Returns 1.0 if the input is equal to the target, 0.0 │ scoring_function │
|
||||
│ │ │ otherwise. │ │
|
||||
|
|
|
@ -79,3 +79,10 @@ class ConflictError(ValueError):
|
|||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class TokenValidationError(ValueError):
|
||||
"""raised when token validation fails during authentication"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
|
|
@ -102,6 +102,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
:cvar benchmarks: Benchmark suite management
|
||||
:cvar tool_groups: Tool group organization
|
||||
:cvar files: File storage and management
|
||||
:cvar prompts: Prompt versions and management
|
||||
:cvar inspect: Built-in system inspection and introspection
|
||||
"""
|
||||
|
||||
|
@ -127,6 +128,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
benchmarks = "benchmarks"
|
||||
tool_groups = "tool_groups"
|
||||
files = "files"
|
||||
prompts = "prompts"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
|
|
@ -5,10 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, Protocol, runtime_checkable
|
||||
from typing import Annotated, ClassVar, Literal, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
@ -49,6 +49,23 @@ class OpenAIFileObject(BaseModel):
|
|||
purpose: OpenAIFilePurpose
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ExpiresAfter(BaseModel):
|
||||
"""
|
||||
Control expiration of uploaded files.
|
||||
|
||||
Params:
|
||||
- anchor, must be "created_at"
|
||||
- seconds, must be int between 3600 and 2592000 (1 hour to 30 days)
|
||||
"""
|
||||
|
||||
MIN: ClassVar[int] = 3600 # 1 hour
|
||||
MAX: ClassVar[int] = 2592000 # 30 days
|
||||
|
||||
anchor: Literal["created_at"]
|
||||
seconds: int = Field(..., ge=3600, le=2592000)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIFileResponse(BaseModel):
|
||||
"""
|
||||
|
@ -92,6 +109,9 @@ class Files(Protocol):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
# TODO: expires_after is producing strange openapi spec, params are showing up as a required w/ oneOf being null
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Upload a file that can be used across various endpoints.
|
||||
|
@ -99,6 +119,7 @@ class Files(Protocol):
|
|||
The file upload should be a multipart form request with:
|
||||
- file: The File object (not file name) to be uploaded.
|
||||
- purpose: The intended purpose of the uploaded file.
|
||||
- expires_after: Optional form values describing expiration for the file. Expected expires_after[anchor] = "created_at", expires_after[seconds] = <int>. Seconds must be between 3600 and 2592000 (1 hour to 30 days).
|
||||
|
||||
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
|
||||
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
|
||||
|
|
9
llama_stack/apis/prompts/__init__.py
Normal file
9
llama_stack/apis/prompts/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .prompts import ListPromptsResponse, Prompt, Prompts
|
||||
|
||||
__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]
|
189
llama_stack/apis/prompts/prompts.py
Normal file
189
llama_stack/apis/prompts/prompts.py
Normal file
|
@ -0,0 +1,189 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
import secrets
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Prompt(BaseModel):
|
||||
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.
|
||||
|
||||
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
|
||||
:param version: Version (integer starting at 1, incremented on save)
|
||||
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
|
||||
:param variables: List of prompt variable names that can be used in the prompt template
|
||||
:param is_default: Boolean indicating whether this version is the default version for this prompt
|
||||
"""
|
||||
|
||||
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
|
||||
version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
|
||||
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
|
||||
variables: list[str] = Field(
|
||||
default_factory=list, description="List of variable names that can be used in the prompt template"
|
||||
)
|
||||
is_default: bool = Field(
|
||||
default=False, description="Boolean indicating whether this version is the default version"
|
||||
)
|
||||
|
||||
@field_validator("prompt_id")
|
||||
@classmethod
|
||||
def validate_prompt_id(cls, prompt_id: str) -> str:
|
||||
if not isinstance(prompt_id, str):
|
||||
raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")
|
||||
|
||||
if not prompt_id.startswith("pmpt_"):
|
||||
raise ValueError("prompt_id must start with 'pmpt_' prefix")
|
||||
|
||||
hex_part = prompt_id[5:]
|
||||
if len(hex_part) != 48:
|
||||
raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")
|
||||
|
||||
for char in hex_part:
|
||||
if char not in "0123456789abcdef":
|
||||
raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")
|
||||
|
||||
return prompt_id
|
||||
|
||||
@field_validator("version")
|
||||
@classmethod
|
||||
def validate_version(cls, prompt_version: int) -> int:
|
||||
if prompt_version < 1:
|
||||
raise ValueError("version must be >= 1")
|
||||
return prompt_version
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_prompt_variables(self):
|
||||
"""Validate that all variables used in the prompt are declared in the variables list."""
|
||||
if not self.prompt:
|
||||
return self
|
||||
|
||||
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
|
||||
declared_variables = set(self.variables)
|
||||
|
||||
undeclared = prompt_variables - declared_variables
|
||||
if undeclared:
|
||||
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def generate_prompt_id(cls) -> str:
|
||||
# Generate 48 hex characters (24 bytes)
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
hex_string = random_bytes.hex()
|
||||
return f"pmpt_{hex_string}"
|
||||
|
||||
|
||||
class ListPromptsResponse(BaseModel):
|
||||
"""Response model to list prompts."""
|
||||
|
||||
data: list[Prompt]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Prompts(Protocol):
|
||||
"""Protocol for prompt management operations."""
|
||||
|
||||
@webmethod(route="/prompts", method="GET")
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
"""List all prompts.
|
||||
|
||||
:returns: A ListPromptsResponse containing all prompts.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}/versions", method="GET")
|
||||
async def list_prompt_versions(
|
||||
self,
|
||||
prompt_id: str,
|
||||
) -> ListPromptsResponse:
|
||||
"""List all versions of a specific prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to list versions for.
|
||||
:returns: A ListPromptsResponse containing all versions of the prompt.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="GET")
|
||||
async def get_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: int | None = None,
|
||||
) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to get.
|
||||
:param version: The version of the prompt to get (defaults to latest).
|
||||
:returns: A Prompt resource.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts", method="POST")
|
||||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt.
|
||||
|
||||
:param prompt: The prompt text content with variable placeholders.
|
||||
:param variables: List of variable names that can be used in the prompt template.
|
||||
:returns: The created Prompt resource.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="PUT")
|
||||
async def update_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
version: int,
|
||||
variables: list[str] | None = None,
|
||||
set_as_default: bool = True,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version).
|
||||
|
||||
:param prompt_id: The identifier of the prompt to update.
|
||||
:param prompt: The updated prompt text content.
|
||||
:param version: The current version of the prompt being updated.
|
||||
:param variables: Updated list of variable names that can be used in the prompt template.
|
||||
:param set_as_default: Set the new version as the default (default=True).
|
||||
:returns: The updated Prompt resource with incremented version.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="DELETE")
|
||||
async def delete_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
) -> None:
|
||||
"""Delete a prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT")
|
||||
async def set_default_version(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: int,
|
||||
) -> Prompt:
|
||||
"""Set which version of a prompt should be the default in get_prompt (latest).
|
||||
|
||||
:param prompt_id: The identifier of the prompt.
|
||||
:param version: The version to set as default.
|
||||
:returns: The prompt with the specified version now set as default.
|
||||
"""
|
||||
...
|
|
@ -19,6 +19,7 @@ class ResourceType(StrEnum):
|
|||
benchmark = "benchmark"
|
||||
tool = "tool"
|
||||
tool_group = "tool_group"
|
||||
prompt = "prompt"
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
|
|
|
@ -45,6 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
|
|||
from llama_stack.core.utils.exec import formulate_run_args, run_command
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
|
||||
|
||||
|
@ -294,6 +295,12 @@ def _generate_run_config(
|
|||
if build_config.external_providers_dir
|
||||
else EXTERNAL_PROVIDERS_DIR,
|
||||
)
|
||||
if not run_config.inference_store:
|
||||
run_config.inference_store = SqliteSqlStoreConfig(
|
||||
**SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=(DISTRIBS_BASE_DIR / image_name).as_posix(), db_name="inference_store.db"
|
||||
)
|
||||
)
|
||||
# build providers dict
|
||||
provider_registry = get_provider_registry(build_config)
|
||||
for api in apis:
|
||||
|
|
|
@ -80,7 +80,7 @@ def get_provider_dependencies(
|
|||
normal_deps = []
|
||||
special_deps = []
|
||||
for package in deps:
|
||||
if "--no-deps" in package or "--index-url" in package:
|
||||
if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
|
||||
special_deps.append(package)
|
||||
else:
|
||||
normal_deps.append(package)
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
|
@ -212,6 +213,7 @@ class AuthProviderType(StrEnum):
|
|||
OAUTH2_TOKEN = "oauth2_token"
|
||||
GITHUB_TOKEN = "github_token"
|
||||
CUSTOM = "custom"
|
||||
KUBERNETES = "kubernetes"
|
||||
|
||||
|
||||
class OAuth2TokenAuthConfig(BaseModel):
|
||||
|
@ -282,8 +284,45 @@ class GitHubTokenAuthConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class KubernetesAuthProviderConfig(BaseModel):
|
||||
"""Configuration for Kubernetes authentication provider."""
|
||||
|
||||
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
|
||||
api_server_url: str = Field(
|
||||
default="https://kubernetes.default.svc",
|
||||
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
|
||||
)
|
||||
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
|
||||
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"username": "roles",
|
||||
"groups": "roles",
|
||||
},
|
||||
description="Mapping of Kubernetes user claims to access attributes",
|
||||
)
|
||||
|
||||
@field_validator("api_server_url")
|
||||
@classmethod
|
||||
def validate_api_server_url(cls, v):
|
||||
parsed = urlparse(v)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
|
||||
if parsed.scheme not in ["http", "https"]:
|
||||
raise ValueError(f"api_server_url scheme must be http or https: {v}")
|
||||
return v
|
||||
|
||||
@field_validator("claims_mapping")
|
||||
@classmethod
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
|
||||
AuthProviderConfig = Annotated[
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
@ -392,6 +431,12 @@ class ServerConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class InferenceStoreConfig(BaseModel):
|
||||
sql_store_config: SqlStoreConfig
|
||||
max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
|
||||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
|
@ -425,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If no
|
|||
a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
inference_store: SqlStoreConfig | None = Field(
|
||||
inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the inference API. If not specified,
|
||||
a default SQLite store will be used.""",
|
||||
Configuration for the persistence store used by the inference API. Can be either a
|
||||
InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
|
||||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
|
|
|
@ -10,7 +10,6 @@ import json
|
|||
import logging # allow-direct-logging
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
@ -148,7 +147,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
|
||||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
self.provider_data = provider_data
|
||||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
|
233
llama_stack/core/prompts/prompts.py
Normal file
233
llama_stack/core/prompts/prompts.py
Normal file
|
@ -0,0 +1,233 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class PromptServiceConfig(BaseModel):
|
||||
"""Configuration for the built-in prompt service.
|
||||
|
||||
:param run_config: Stack run configuration containing distribution info
|
||||
"""
|
||||
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
|
||||
"""Get the prompt service implementation."""
|
||||
impl = PromptServiceImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class PromptServiceImpl(Prompts):
|
||||
"""Built-in prompt service implementation using KVStore."""
|
||||
|
||||
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
self.kvstore: KVStore
|
||||
|
||||
async def initialize(self) -> None:
|
||||
kvstore_config = SqliteKVStoreConfig(
|
||||
db_path=(DISTRIBS_BASE_DIR / self.config.run_config.image_name / "prompts.db").as_posix()
|
||||
)
|
||||
self.kvstore = await kvstore_impl(kvstore_config)
|
||||
|
||||
def _get_default_key(self, prompt_id: str) -> str:
|
||||
"""Get the KVStore key that stores the default version number."""
|
||||
return f"prompts:v1:{prompt_id}:default"
|
||||
|
||||
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
|
||||
"""Get the KVStore key for prompt data, returning default version if applicable."""
|
||||
if version:
|
||||
return self._get_version_key(prompt_id, str(version))
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
resolved_version = await self.kvstore.get(default_key)
|
||||
if resolved_version is None:
|
||||
raise ValueError(f"Prompt {prompt_id}:default not found")
|
||||
return self._get_version_key(prompt_id, resolved_version)
|
||||
|
||||
def _get_version_key(self, prompt_id: str, version: str) -> str:
|
||||
"""Get the KVStore key for a specific prompt version."""
|
||||
return f"prompts:v1:{prompt_id}:{version}"
|
||||
|
||||
def _get_list_key_prefix(self) -> str:
|
||||
"""Get the key prefix for listing prompts."""
|
||||
return "prompts:v1:"
|
||||
|
||||
def _serialize_prompt(self, prompt: Prompt) -> str:
|
||||
"""Serialize a prompt to JSON string for storage."""
|
||||
return json.dumps(
|
||||
{
|
||||
"prompt_id": prompt.prompt_id,
|
||||
"prompt": prompt.prompt,
|
||||
"version": prompt.version,
|
||||
"variables": prompt.variables or [],
|
||||
"is_default": prompt.is_default,
|
||||
}
|
||||
)
|
||||
|
||||
def _deserialize_prompt(self, data: str) -> Prompt:
|
||||
"""Deserialize a prompt from JSON string."""
|
||||
obj = json.loads(data)
|
||||
return Prompt(
|
||||
prompt_id=obj["prompt_id"],
|
||||
prompt=obj["prompt"],
|
||||
version=obj["version"],
|
||||
variables=obj.get("variables", []),
|
||||
is_default=obj.get("is_default", False),
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
"""List all prompts (default versions only)."""
|
||||
prefix = self._get_list_key_prefix()
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
prompts = []
|
||||
for key in keys:
|
||||
if key.endswith(":default"):
|
||||
try:
|
||||
default_version = await self.kvstore.get(key)
|
||||
if default_version:
|
||||
prompt_id = key.replace(prefix, "").replace(":default", "")
|
||||
version_key = self._get_version_key(prompt_id, default_version)
|
||||
data = await self.kvstore.get(version_key)
|
||||
if data:
|
||||
prompt = self._deserialize_prompt(data)
|
||||
prompts.append(prompt)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version."""
|
||||
key = await self._get_prompt_key(prompt_id, version)
|
||||
data = await self.kvstore.get(key)
|
||||
if data is None:
|
||||
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
|
||||
return self._deserialize_prompt(data)
|
||||
|
||||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt."""
|
||||
if variables is None:
|
||||
variables = []
|
||||
|
||||
prompt_obj = Prompt(
|
||||
prompt_id=Prompt.generate_prompt_id(),
|
||||
prompt=prompt,
|
||||
version=1,
|
||||
variables=variables,
|
||||
)
|
||||
|
||||
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
|
||||
data = self._serialize_prompt(prompt_obj)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
default_key = self._get_default_key(prompt_obj.prompt_id)
|
||||
await self.kvstore.set(default_key, str(prompt_obj.version))
|
||||
|
||||
return prompt_obj
|
||||
|
||||
async def update_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
version: int,
|
||||
variables: list[str] | None = None,
|
||||
set_as_default: bool = True,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version)."""
|
||||
if version < 1:
|
||||
raise ValueError("Version must be >= 1")
|
||||
if variables is None:
|
||||
variables = []
|
||||
|
||||
prompt_versions = await self.list_prompt_versions(prompt_id)
|
||||
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
|
||||
|
||||
if version and latest_prompt.version != version:
|
||||
raise ValueError(
|
||||
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
|
||||
)
|
||||
|
||||
current_version = latest_prompt.version if version is None else version
|
||||
new_version = current_version + 1
|
||||
|
||||
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
||||
|
||||
version_key = self._get_version_key(prompt_id, str(new_version))
|
||||
data = self._serialize_prompt(updated_prompt)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
if set_as_default:
|
||||
await self.set_default_version(prompt_id, new_version)
|
||||
|
||||
return updated_prompt
|
||||
|
||||
async def delete_prompt(self, prompt_id: str) -> None:
|
||||
"""Delete a prompt and all its versions."""
|
||||
await self.get_prompt(prompt_id)
|
||||
|
||||
prefix = f"prompts:v1:{prompt_id}:"
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
for key in keys:
|
||||
await self.kvstore.delete(key)
|
||||
|
||||
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
|
||||
"""List all versions of a specific prompt."""
|
||||
prefix = f"prompts:v1:{prompt_id}:"
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
default_version = None
|
||||
prompts = []
|
||||
|
||||
for key in keys:
|
||||
data = await self.kvstore.get(key)
|
||||
if key.endswith(":default"):
|
||||
default_version = data
|
||||
else:
|
||||
if data:
|
||||
prompt_obj = self._deserialize_prompt(data)
|
||||
prompts.append(prompt_obj)
|
||||
|
||||
if not prompts:
|
||||
raise ValueError(f"Prompt {prompt_id} not found")
|
||||
|
||||
for prompt in prompts:
|
||||
prompt.is_default = str(prompt.version) == default_version
|
||||
|
||||
prompts.sort(key=lambda x: x.version)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
|
||||
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
|
||||
version_key = self._get_version_key(prompt_id, str(version))
|
||||
data = await self.kvstore.get(version_key)
|
||||
if data is None:
|
||||
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
await self.kvstore.set(default_key, str(version))
|
||||
|
||||
return self._deserialize_prompt(data)
|
|
@ -19,6 +19,7 @@ from llama_stack.apis.inference import Inference, InferenceProvider
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.prompts import Prompts
|
||||
from llama_stack.apis.providers import Providers as ProvidersAPI
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
|
@ -93,6 +94,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.tool_groups: ToolGroups,
|
||||
Api.tool_runtime: ToolRuntime,
|
||||
Api.files: Files,
|
||||
Api.prompts: Prompts,
|
||||
}
|
||||
|
||||
if external_apis:
|
||||
|
@ -284,7 +286,15 @@ async def instantiate_providers(
|
|||
if provider.provider_id is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
except KeyError as e:
|
||||
missing_api = e.args[0]
|
||||
raise RuntimeError(
|
||||
f"Failed to resolve '{provider.spec.api.value}' provider '{provider.provider_id}' of type '{provider.spec.provider_type}': "
|
||||
f"required dependency '{missing_api.value}' is not available. "
|
||||
f"Please add a '{missing_api.value}' provider to your configuration or check if the provider is properly configured."
|
||||
) from e
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
if a in impls:
|
||||
deps[a] = impls[a]
|
||||
|
|
|
@ -78,7 +78,10 @@ async def get_auto_router_impl(
|
|||
|
||||
# TODO: move pass configs to routers instead
|
||||
if api == Api.inference and run_config.inference_store:
|
||||
inference_store = InferenceStore(run_config.inference_store, policy)
|
||||
inference_store = InferenceStore(
|
||||
config=run_config.inference_store,
|
||||
policy=policy,
|
||||
)
|
||||
await inference_store.initialize()
|
||||
api_to_dep_impl["store"] = inference_store
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
|||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
@ -90,6 +90,11 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("InferenceRouter.shutdown")
|
||||
if self.store:
|
||||
try:
|
||||
await self.store.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during InferenceStore shutdown: {e}")
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
@ -160,7 +165,7 @@ class InferenceRouter(Inference):
|
|||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
|
@ -431,7 +436,7 @@ class InferenceRouter(Inference):
|
|||
model=model_obj,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
|
@ -527,7 +532,7 @@ class InferenceRouter(Inference):
|
|||
|
||||
# Store the response with the ID that will be returned to the client
|
||||
if self.store:
|
||||
await self.store.store_chat_completion(response, messages)
|
||||
asyncio.create_task(self.store.store_chat_completion(response, messages))
|
||||
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
|
@ -537,7 +542,7 @@ class InferenceRouter(Inference):
|
|||
model=model_obj,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
|
@ -664,7 +669,7 @@ class InferenceRouter(Inference):
|
|||
"completion_tokens",
|
||||
"total_tokens",
|
||||
]: # Only log completion and total tokens
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
async_metrics = [
|
||||
|
@ -710,7 +715,7 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
|
||||
|
@ -755,7 +760,7 @@ class InferenceRouter(Inference):
|
|||
choices_data[idx] = {
|
||||
"content_parts": [],
|
||||
"tool_calls_builder": {},
|
||||
"finish_reason": None,
|
||||
"finish_reason": "stop",
|
||||
"logprobs_content_parts": [],
|
||||
}
|
||||
current_choice_data = choices_data[idx]
|
||||
|
@ -806,7 +811,7 @@ class InferenceRouter(Inference):
|
|||
model=model,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
|
||||
yield chunk
|
||||
finally:
|
||||
|
@ -855,4 +860,4 @@ class InferenceRouter(Inference):
|
|||
object="chat.completion",
|
||||
)
|
||||
logger.debug(f"InferenceRouter.completion_response: {final_response}")
|
||||
await self.store.store_chat_completion(final_response, messages)
|
||||
asyncio.create_task(self.store.store_chat_completion(final_response, messages))
|
||||
|
|
|
@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
provider_vector_db_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
) -> VectorDB:
|
||||
provider_vector_db_id = provider_vector_db_id or vector_db_id
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
|
||||
provider = self.impls_by_provider_id[provider_id]
|
||||
logger.warning(
|
||||
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||
)
|
||||
vector_store = await provider.openai_create_vector_store(
|
||||
name=vector_db_name or vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=model.metadata["embedding_dimension"],
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=provider_vector_db_id,
|
||||
)
|
||||
|
||||
vector_store_id = vector_store.id
|
||||
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
||||
logger.warning(
|
||||
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
||||
)
|
||||
|
||||
vector_db_data = {
|
||||
"identifier": vector_db_id,
|
||||
"identifier": vector_store_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_vector_db_id,
|
||||
"provider_resource_id": actual_provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
"vector_db_name": vector_db_name,
|
||||
"vector_db_name": vector_store.name,
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
|
|
|
@ -8,16 +8,18 @@ import ssl
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from urllib.parse import parse_qs, urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.errors import TokenValidationError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubTokenAuthConfig,
|
||||
KubernetesAuthProviderConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
|
@ -162,7 +164,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
auth=auth,
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
if response.status_code != httpx.codes.OK:
|
||||
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||
|
||||
|
@ -272,7 +274,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
if response.status_code != httpx.codes.OK:
|
||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||
|
||||
|
@ -374,6 +376,89 @@ async def _get_github_user_info(access_token: str, github_api_base_url: str) ->
|
|||
}
|
||||
|
||||
|
||||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""
|
||||
Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.
|
||||
This provider integrates with Kubernetes API server by using the
|
||||
/apis/authentication.k8s.io/v1/selfsubjectreviews endpoint to validate tokens and extract user information.
|
||||
"""
|
||||
|
||||
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||
self.config = config
|
||||
|
||||
def _httpx_verify_value(self) -> bool | str:
|
||||
"""
|
||||
Build the value for httpx's `verify` parameter.
|
||||
- False disables verification.
|
||||
- Path string points to a CA bundle.
|
||||
- True uses system defaults.
|
||||
"""
|
||||
if not self.config.verify_tls:
|
||||
return False
|
||||
if self.config.tls_cafile:
|
||||
return self.config.tls_cafile.as_posix()
|
||||
return True
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using Kubernetes SelfSubjectReview API endpoint."""
|
||||
# Build the Kubernetes SelfSubjectReview API endpoint URL
|
||||
review_api_url = urljoin(self.config.api_server_url, "/apis/authentication.k8s.io/v1/selfsubjectreviews")
|
||||
|
||||
# Create SelfSubjectReview request body
|
||||
review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
|
||||
verify = self._httpx_verify_value()
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=verify, timeout=10.0) as client:
|
||||
response = await client.post(
|
||||
review_api_url,
|
||||
json=review_request,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == httpx.codes.UNAUTHORIZED:
|
||||
raise TokenValidationError("Invalid token")
|
||||
if response.status_code != httpx.codes.CREATED:
|
||||
logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
|
||||
raise TokenValidationError(f"Token validation failed: {response.status_code}")
|
||||
|
||||
review_response = response.json()
|
||||
# Extract user information from SelfSubjectReview response
|
||||
status = review_response.get("status", {})
|
||||
if not status:
|
||||
raise ValueError("No status found in SelfSubjectReview response")
|
||||
|
||||
user_info = status.get("userInfo", {})
|
||||
if not user_info:
|
||||
raise ValueError("No userInfo found in SelfSubjectReview response")
|
||||
|
||||
username = user_info.get("username")
|
||||
if not username:
|
||||
raise ValueError("No username found in SelfSubjectReview response")
|
||||
|
||||
# Build user attributes from Kubernetes user info
|
||||
user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)
|
||||
|
||||
return User(
|
||||
principal=username,
|
||||
attributes=user_attributes,
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("Kubernetes SelfSubjectReview API request timed out")
|
||||
raise ValueError("Token validation timeout") from None
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during token validation: {str(e)}")
|
||||
raise ValueError(f"Token validation error: {str(e)}") from e
|
||||
|
||||
async def close(self):
|
||||
"""Close any resources."""
|
||||
pass
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_config = config.provider_config
|
||||
|
@ -384,5 +469,7 @@ def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
|||
return OAuth2TokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
||||
return GitHubTokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, KubernetesAuthProviderConfig):
|
||||
return KubernetesAuthProvider(provider_config)
|
||||
else:
|
||||
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
||||
|
|
|
@ -132,15 +132,17 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
},
|
||||
)
|
||||
elif isinstance(exc, ConflictError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
|
||||
elif isinstance(exc, ResourceNotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
|
||||
elif isinstance(exc, ValueError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
||||
elif isinstance(exc, BadRequestError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
|
||||
elif isinstance(exc, PermissionError | AccessDeniedError):
|
||||
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
|
||||
elif isinstance(exc, ConnectionError | httpx.ConnectError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
|
||||
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
|
||||
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
|
||||
elif isinstance(exc, NotImplementedError):
|
||||
|
@ -513,6 +515,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
|
||||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
apis_to_serve.add("prompts")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.prompts import Prompts
|
||||
from llama_stack.apis.providers import Providers
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
|
@ -37,6 +38,7 @@ from llama_stack.apis.vector_io import VectorIO
|
|||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
|
||||
|
@ -72,6 +74,7 @@ class LlamaStack(
|
|||
ToolRuntime,
|
||||
RAGToolRuntime,
|
||||
Files,
|
||||
Prompts,
|
||||
):
|
||||
pass
|
||||
|
||||
|
@ -105,12 +108,12 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
|
||||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
|
||||
if hasattr(obj, "provider_id"):
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
|
||||
if not obj.provider_id or obj.provider_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
|
||||
# we want to maintain the type information in arguments to method.
|
||||
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
||||
|
@ -225,7 +228,10 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
|
||||
try:
|
||||
result = re.sub(pattern, get_env_var, config)
|
||||
# Only apply type conversion if substitution actually happened
|
||||
if result != config:
|
||||
return _convert_string_to_proper_type(result)
|
||||
return result
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
|
||||
|
@ -302,6 +308,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
)
|
||||
impls[Api.providers] = providers_impl
|
||||
|
||||
prompts_impl = PromptServiceImpl(
|
||||
PromptServiceConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.prompts] = prompts_impl
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
|
@ -326,6 +338,9 @@ async def construct_stack(
|
|||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
|
|
|
@ -34,7 +34,7 @@ distribution_spec:
|
|||
telemetry:
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_type: inline::huggingface-cpu
|
||||
- provider_type: inline::torchtune-cpu
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
|
|
|
@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
|
|||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
template = get_starter_distribution_template()
|
||||
name = "ci-tests"
|
||||
template.name = name
|
||||
template = get_starter_distribution_template(name="ci-tests")
|
||||
template.description = "CI tests for Llama Stack"
|
||||
|
||||
return template
|
||||
|
|
|
@ -89,28 +89,28 @@ providers:
|
|||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/faiss_store.db
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec_registry.db
|
||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||
provider_type: inline::milvus
|
||||
config:
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/milvus_registry.db
|
||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests/}/chroma_remote_registry.db
|
||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
|
@ -121,15 +121,15 @@ providers:
|
|||
password: ${env.PGVECTOR_PASSWORD:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/files_metadata.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
@ -156,13 +156,10 @@ providers:
|
|||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
- provider_id: huggingface-cpu
|
||||
provider_type: inline::huggingface-cpu
|
||||
- provider_id: torchtune-cpu
|
||||
provider_type: inline::torchtune-cpu
|
||||
config:
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
dpo_output_dir: ~/.llama/distributions/ci-tests/dpo_output
|
||||
checkpoint_format: meta
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# Meta Reference Distribution
|
||||
# Meta Reference GPU Distribution
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 2
|
||||
|
@ -29,7 +29,7 @@ The following environment variables can be configured:
|
|||
|
||||
## Prerequisite: Downloading Models
|
||||
|
||||
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||
|
||||
```
|
||||
$ llama model list --downloaded
|
||||
|
|
|
@ -134,6 +134,11 @@ models:
|
|||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: nvidia/vila
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/vila
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 2048
|
||||
context_length: 8192
|
||||
|
|
|
@ -43,7 +43,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
|
|||
"openai",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="openai/gpt-4o",
|
||||
provider_model_id="gpt-4o",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
|
@ -53,7 +53,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
|
|||
"anthropic",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/claude-3-5-sonnet-latest",
|
||||
provider_model_id="claude-3-5-sonnet-latest",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
|
@ -206,13 +206,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
uri="huggingface://datasets/llamastack/math_500?split=test",
|
||||
),
|
||||
),
|
||||
DatasetInput(
|
||||
dataset_id="bfcl",
|
||||
purpose=DatasetPurpose.eval_messages_answer,
|
||||
source=URIDataSource(
|
||||
uri="huggingface://datasets/llamastack/bfcl_v3?split=train",
|
||||
),
|
||||
),
|
||||
DatasetInput(
|
||||
dataset_id="ifeval",
|
||||
purpose=DatasetPurpose.eval_messages_answer,
|
||||
|
@ -250,11 +243,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
dataset_id="math_500",
|
||||
scoring_functions=["basic::regex_parser_math_response"],
|
||||
),
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-bfcl",
|
||||
dataset_id="bfcl",
|
||||
scoring_functions=["basic::bfcl"],
|
||||
),
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-ifeval",
|
||||
dataset_id="ifeval",
|
||||
|
|
|
@ -136,14 +136,14 @@ inference_store:
|
|||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: openai/gpt-4o
|
||||
model_id: gpt-4o
|
||||
provider_id: openai
|
||||
provider_model_id: openai/gpt-4o
|
||||
provider_model_id: gpt-4o
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: anthropic/claude-3-5-sonnet-latest
|
||||
model_id: claude-3-5-sonnet-latest
|
||||
provider_id: anthropic
|
||||
provider_model_id: anthropic/claude-3-5-sonnet-latest
|
||||
provider_model_id: claude-3-5-sonnet-latest
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: gemini/gemini-1.5-flash
|
||||
|
@ -188,12 +188,6 @@ datasets:
|
|||
uri: huggingface://datasets/llamastack/math_500?split=test
|
||||
metadata: {}
|
||||
dataset_id: math_500
|
||||
- purpose: eval/messages-answer
|
||||
source:
|
||||
type: uri
|
||||
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
|
||||
metadata: {}
|
||||
dataset_id: bfcl
|
||||
- purpose: eval/messages-answer
|
||||
source:
|
||||
type: uri
|
||||
|
@ -228,11 +222,6 @@ benchmarks:
|
|||
- basic::regex_parser_math_response
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-math-500
|
||||
- dataset_id: bfcl
|
||||
scoring_functions:
|
||||
- basic::bfcl
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-bfcl
|
||||
- dataset_id: ifeval
|
||||
scoring_functions:
|
||||
- basic::ifeval
|
||||
|
|
|
@ -35,7 +35,7 @@ distribution_spec:
|
|||
telemetry:
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_type: inline::torchtune-gpu
|
||||
- provider_type: inline::huggingface-gpu
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
|
|
|
@ -89,28 +89,28 @@ providers:
|
|||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/faiss_store.db
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec_registry.db
|
||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||
provider_type: inline::milvus
|
||||
config:
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/milvus_registry.db
|
||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu/}/chroma_remote_registry.db
|
||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
|
@ -121,15 +121,15 @@ providers:
|
|||
password: ${env.PGVECTOR_PASSWORD:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/files_metadata.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
@ -156,10 +156,13 @@ providers:
|
|||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
- provider_id: torchtune-gpu
|
||||
provider_type: inline::torchtune-gpu
|
||||
- provider_id: huggingface-gpu
|
||||
provider_type: inline::huggingface-gpu
|
||||
config:
|
||||
checkpoint_format: meta
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
dpo_output_dir: ~/.llama/distributions/starter-gpu/dpo_output
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -11,12 +11,10 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
|
|||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
template = get_starter_distribution_template()
|
||||
name = "starter-gpu"
|
||||
template.name = name
|
||||
template = get_starter_distribution_template(name="starter-gpu")
|
||||
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
|
||||
|
||||
template.providers["post_training"] = [
|
||||
BuildProvider(provider_type="inline::torchtune-gpu"),
|
||||
BuildProvider(provider_type="inline::huggingface-gpu"),
|
||||
]
|
||||
return template
|
||||
|
|
|
@ -35,7 +35,7 @@ distribution_spec:
|
|||
telemetry:
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_type: inline::huggingface-cpu
|
||||
- provider_type: inline::torchtune-cpu
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
|
|
|
@ -156,13 +156,10 @@ providers:
|
|||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
- provider_id: huggingface-cpu
|
||||
provider_type: inline::huggingface-cpu
|
||||
- provider_id: torchtune-cpu
|
||||
provider_type: inline::torchtune-cpu
|
||||
config:
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
dpo_output_dir: ~/.llama/distributions/starter/dpo_output
|
||||
checkpoint_format: meta
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -99,9 +99,8 @@ def get_remote_inference_providers() -> list[Provider]:
|
|||
return inference_providers
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
||||
remote_inference_providers = get_remote_inference_providers()
|
||||
name = "starter"
|
||||
|
||||
providers = {
|
||||
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
|
||||
|
@ -120,7 +119,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
|
||||
"post_training": [BuildProvider(provider_type="inline::torchtune-cpu")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
|
|
|
@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions"]:
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
|
@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
|
|||
)
|
||||
valid = False
|
||||
|
||||
for param, expected_type, type_string in [
|
||||
if batch.endpoint == "/v1/chat/completions":
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]:
|
||||
]
|
||||
else: # /v1/completions
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
|
||||
]
|
||||
|
||||
for param, expected_type, type_string in required_params:
|
||||
if param not in body:
|
||||
errors.append(
|
||||
BatchError(
|
||||
|
@ -591,6 +599,7 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
try:
|
||||
# TODO(SECURITY): review body for security issues
|
||||
if request.url == "/v1/chat/completions":
|
||||
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
||||
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||
|
||||
|
@ -605,6 +614,22 @@ class ReferenceBatchesImpl(Batches):
|
|||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
else: # /v1/completions
|
||||
completion_response = await self.inference_api.openai_completion(**request.body)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(completion_response, "model_dump_json"), (
|
||||
"Completion response must have model_dump_json method"
|
||||
)
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id,
|
||||
"body": completion_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||
return {
|
||||
|
|
|
@ -86,11 +86,16 @@ class LocalfsFilesImpl(Files):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""Upload a file that can be used across various endpoints."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
if expires_after_anchor is not None or expires_after_seconds is not None:
|
||||
raise NotImplementedError("File expiration is not supported by this provider")
|
||||
|
||||
file_id = self._generate_file_id()
|
||||
file_path = self._get_file_path(file_id)
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
|
||||
|
@ -37,7 +36,6 @@ FIXED_FNS = [
|
|||
SubsetOfScoringFn,
|
||||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
BFCLScoringFn,
|
||||
IfEvalScoringFn,
|
||||
DocVQAScoringFn,
|
||||
]
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.bfcl.ast_parser import decode_ast
|
||||
from ..utils.bfcl.checker import ast_checker, is_empty_output
|
||||
from .fn_defs.bfcl import bfcl
|
||||
|
||||
|
||||
def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
|
||||
contain_func_call = False
|
||||
error = None
|
||||
error_type = None
|
||||
checker_result = {}
|
||||
try:
|
||||
prediction = decode_ast(x["generated_answer"], x["language"]) or ""
|
||||
contain_func_call = True
|
||||
# if not is_function_calling_format_output(prediction):
|
||||
if is_empty_output(prediction):
|
||||
contain_func_call = False
|
||||
error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
|
||||
error_type = "ast_decoder:decoder_wrong_output_format"
|
||||
else:
|
||||
checker_result = ast_checker(
|
||||
json.loads(x["function"]),
|
||||
prediction,
|
||||
json.loads(x["ground_truth"]),
|
||||
x["language"],
|
||||
test_category=test_category,
|
||||
model_name="",
|
||||
)
|
||||
except Exception as e:
|
||||
prediction = ""
|
||||
error = f"Invalid syntax. Failed to decode AST. {str(e)}"
|
||||
error_type = "ast_decoder:decoder_failed"
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"contain_func_call": contain_func_call,
|
||||
"valid": checker_result.get("valid", False),
|
||||
"error": error or checker_result.get("error", ""),
|
||||
"error_type": error_type or checker_result.get("error_type", ""),
|
||||
}
|
||||
|
||||
|
||||
def gen_valid(x: dict[str, Any]) -> dict[str, float]:
|
||||
return {"valid": x["valid"]}
|
||||
|
||||
|
||||
def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
|
||||
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
|
||||
# If `test_category` is "irrelevance", the model is expected to output no function call.
|
||||
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
|
||||
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
|
||||
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
|
||||
return {"valid": float(acc)}
|
||||
|
||||
|
||||
class BFCLScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for BFCL
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
bfcl.identifier: bfcl,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "bfcl",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
|
||||
score_result = postprocess(input_row, test_category)
|
||||
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
|
||||
score = gen_relevance_acc(score_result)["valid"]
|
||||
else:
|
||||
score = gen_valid(score_result)["valid"]
|
||||
return {
|
||||
"score": float(score),
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
bfcl = ScoringFn(
|
||||
identifier="basic::bfcl",
|
||||
description="BFCL complex scoring",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="bfcl",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
|
@ -1,296 +0,0 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import ast
|
||||
|
||||
from .tree_sitter import get_parser
|
||||
|
||||
|
||||
def parse_java_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("java")
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
|
||||
if root_node.has_error:
|
||||
raise Exception("Error parsing java the source code.")
|
||||
|
||||
def get_text(node):
|
||||
"""Returns the text represented by the node."""
|
||||
return source_code[node.start_byte : node.end_byte]
|
||||
|
||||
def traverse_node(node, nested=False):
|
||||
if node.type == "string_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding quotes from string literals
|
||||
return get_text(node)[1:-1]
|
||||
elif node.type == "character_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding single quotes from character literals
|
||||
return get_text(node)[1:-1]
|
||||
"""Traverse the node to collect texts for complex structures."""
|
||||
if node.type in [
|
||||
"identifier",
|
||||
"class_literal",
|
||||
"type_identifier",
|
||||
"method_invocation",
|
||||
]:
|
||||
return get_text(node)
|
||||
elif node.type == "array_creation_expression":
|
||||
# Handle array creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
value_node = node.child_by_field_name("value")
|
||||
type_text = traverse_node(type_node, True)
|
||||
value_text = traverse_node(value_node, True)
|
||||
return f"new {type_text}[]{value_text}"
|
||||
elif node.type == "object_creation_expression":
|
||||
# Handle object creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
type_text = traverse_node(type_node, True)
|
||||
if arguments_node:
|
||||
# Process each argument carefully, avoiding unnecessary punctuation
|
||||
argument_texts = []
|
||||
for child in arguments_node.children:
|
||||
if child.type not in [
|
||||
",",
|
||||
"(",
|
||||
")",
|
||||
]: # Exclude commas and parentheses
|
||||
argument_text = traverse_node(child, True)
|
||||
argument_texts.append(argument_text)
|
||||
arguments_text = ", ".join(argument_texts)
|
||||
return f"new {type_text}({arguments_text})"
|
||||
else:
|
||||
return f"new {type_text}()"
|
||||
elif node.type == "set":
|
||||
# Handling sets specifically
|
||||
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
|
||||
return "{" + ", ".join(items) + "}"
|
||||
|
||||
elif node.child_count > 0:
|
||||
return "".join(traverse_node(child, True) for child in node.children)
|
||||
else:
|
||||
return get_text(node)
|
||||
|
||||
def extract_arguments(args_node):
|
||||
arguments = {}
|
||||
for child in args_node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# For named parameters
|
||||
name_node, value_node = child.children[0], child.children[2]
|
||||
name = get_text(name_node)
|
||||
value = traverse_node(value_node)
|
||||
if name in arguments:
|
||||
if not isinstance(arguments[name], list):
|
||||
arguments[name] = [arguments[name]]
|
||||
arguments[name].append(value)
|
||||
else:
|
||||
arguments[name] = value
|
||||
# arguments.append({'name': name, 'value': value})
|
||||
elif child.type in ["identifier", "class_literal", "set"]:
|
||||
# For unnamed parameters and handling sets
|
||||
value = traverse_node(child)
|
||||
if None in arguments:
|
||||
if not isinstance(arguments[None], list):
|
||||
arguments[None] = [arguments[None]]
|
||||
arguments[None].append(value)
|
||||
else:
|
||||
arguments[None] = value
|
||||
return arguments
|
||||
|
||||
def traverse(node):
|
||||
if node.type == "method_invocation":
|
||||
# Extract the function name and its arguments
|
||||
method_name = get_text(node.child_by_field_name("name"))
|
||||
class_name_node = node.child_by_field_name("object")
|
||||
if class_name_node:
|
||||
class_name = get_text(class_name_node)
|
||||
function_name = f"{class_name}.{method_name}"
|
||||
else:
|
||||
function_name = method_name
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
if arguments_node:
|
||||
arguments = extract_arguments(arguments_node)
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
return [{function_name: arguments}]
|
||||
|
||||
else:
|
||||
for child in node.children:
|
||||
result = traverse(child)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = traverse(root_node)
|
||||
return result if result else {}
|
||||
|
||||
|
||||
def parse_javascript_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("javascript")
|
||||
# Parse the source code
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
if root_node.has_error:
|
||||
raise Exception("Error js parsing the source code.")
|
||||
|
||||
# Function to recursively extract argument details
|
||||
def extract_arguments(node):
|
||||
args = {}
|
||||
for child in node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# Extract left (name) and right (value) parts of the assignment
|
||||
name = child.children[0].text.decode("utf-8")
|
||||
value = child.children[2].text.decode("utf-8")
|
||||
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1] # Trim the quotation marks
|
||||
if name in args:
|
||||
if not isinstance(args[name], list):
|
||||
args[name] = [args[name]]
|
||||
args[name].append(value)
|
||||
else:
|
||||
args[name] = value
|
||||
|
||||
elif child.type == "identifier" or child.type == "true":
|
||||
# Handle non-named arguments and boolean values
|
||||
value = child.text.decode("utf-8")
|
||||
if None in args:
|
||||
if not isinstance(args[None], list):
|
||||
args[None] = [args[None]]
|
||||
args[None].append(value)
|
||||
else:
|
||||
args[None] = value
|
||||
return args
|
||||
|
||||
# Find the function call and extract its name and arguments
|
||||
if root_node.type == "program":
|
||||
for child in root_node.children:
|
||||
if child.type == "expression_statement":
|
||||
for sub_child in child.children:
|
||||
if sub_child.type == "call_expression":
|
||||
function_name = sub_child.children[0].text.decode("utf8")
|
||||
arguments_node = sub_child.children[1]
|
||||
parameters = extract_arguments(arguments_node)
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
result = [{function_name: parameters}]
|
||||
return result
|
||||
|
||||
|
||||
def ast_parse(input_str, language="Python"):
|
||||
if language == "Python":
|
||||
cleaned_input = input_str.strip("[]'")
|
||||
parsed = ast.parse(cleaned_input, mode="eval")
|
||||
extracted = []
|
||||
if isinstance(parsed.body, ast.Call):
|
||||
extracted.append(resolve_ast_call(parsed.body))
|
||||
else:
|
||||
for elem in parsed.body.elts:
|
||||
extracted.append(resolve_ast_call(elem))
|
||||
return extracted
|
||||
elif language == "Java":
|
||||
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
|
||||
elif language == "JavaScript":
|
||||
return parse_javascript_function_call(input_str[1:-1])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported language: {language}")
|
||||
|
||||
|
||||
def resolve_ast_call(elem):
|
||||
# Handle nested attributes for deeply nested module paths
|
||||
func_parts = []
|
||||
func_part = elem.func
|
||||
while isinstance(func_part, ast.Attribute):
|
||||
func_parts.append(func_part.attr)
|
||||
func_part = func_part.value
|
||||
if isinstance(func_part, ast.Name):
|
||||
func_parts.append(func_part.id)
|
||||
func_name = ".".join(reversed(func_parts))
|
||||
args_dict = {}
|
||||
# Parse when args are simply passed as an unnamed dictionary arg
|
||||
for arg in elem.args:
|
||||
if isinstance(arg, ast.Dict):
|
||||
for key, value in zip(arg.keys, arg.values):
|
||||
if isinstance(key, ast.Constant):
|
||||
arg_name = key.value
|
||||
output = resolve_ast_by_type(value)
|
||||
args_dict[arg_name] = output
|
||||
for arg in elem.keywords:
|
||||
output = resolve_ast_by_type(arg.value)
|
||||
args_dict[arg.arg] = output
|
||||
return {func_name: args_dict}
|
||||
|
||||
|
||||
def resolve_ast_by_type(value):
|
||||
if isinstance(value, ast.Constant):
|
||||
if value.value is Ellipsis:
|
||||
output = "..."
|
||||
else:
|
||||
output = value.value
|
||||
elif isinstance(value, ast.UnaryOp):
|
||||
output = -value.operand.value
|
||||
elif isinstance(value, ast.List):
|
||||
output = [resolve_ast_by_type(v) for v in value.elts]
|
||||
elif isinstance(value, ast.Dict):
|
||||
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
|
||||
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
|
||||
output = value.value
|
||||
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
|
||||
output = eval(ast.unparse(value))
|
||||
elif isinstance(value, ast.Name):
|
||||
output = value.id
|
||||
elif isinstance(value, ast.Call):
|
||||
if len(value.keywords) == 0:
|
||||
output = ast.unparse(value)
|
||||
else:
|
||||
output = resolve_ast_call(value)
|
||||
elif isinstance(value, ast.Tuple):
|
||||
output = tuple(resolve_ast_by_type(v) for v in value.elts)
|
||||
elif isinstance(value, ast.Lambda):
|
||||
output = eval(ast.unparse(value.body[0].value))
|
||||
elif isinstance(value, ast.Ellipsis):
|
||||
output = "..."
|
||||
elif isinstance(value, ast.Subscript):
|
||||
try:
|
||||
output = ast.unparse(value.body[0].value)
|
||||
except:
|
||||
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
||||
else:
|
||||
raise Exception(f"Unsupported AST type: {type(value)}")
|
||||
return output
|
||||
|
||||
|
||||
def decode_ast(result, language="Python"):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decoded_output = ast_parse(func, language)
|
||||
return decoded_output
|
||||
|
||||
|
||||
def decode_execute(result):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decode_output = ast_parse(func)
|
||||
execution_list = []
|
||||
for function_call in decode_output:
|
||||
for key, value in function_call.items():
|
||||
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
|
||||
return execution_list
|
|
@ -1,989 +0,0 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
# Comment out for now until we actually use the rest checker in evals
|
||||
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
|
||||
|
||||
|
||||
class NoAPIKeyError(Exception):
|
||||
def __init__(self):
|
||||
self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
|
||||
|
||||
|
||||
JAVA_TYPE_CONVERSION = {
|
||||
"byte": int,
|
||||
"short": int,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"double": float,
|
||||
"long": int,
|
||||
"boolean": bool,
|
||||
"char": str,
|
||||
"Array": list,
|
||||
"ArrayList": list,
|
||||
"Set": set,
|
||||
"HashMap": dict,
|
||||
"Hashtable": dict,
|
||||
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
|
||||
"Stack": list,
|
||||
"String": str,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
JS_TYPE_CONVERSION = {
|
||||
"String": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"Bigint": int,
|
||||
"Boolean": bool,
|
||||
"dict": dict,
|
||||
"array": list,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# We switch to conditional import for the following two imports to avoid unnecessary installations.
|
||||
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
|
||||
# from js_type_converter import js_type_converter
|
||||
# from java_type_converter import java_type_converter
|
||||
|
||||
PYTHON_TYPE_MAPPING = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"tuple": list,
|
||||
"dict": dict,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# This is the list of types that we need to recursively check its values
|
||||
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
|
||||
|
||||
|
||||
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
|
||||
|
||||
|
||||
#### Helper functions for AST ####
|
||||
def find_description(func_descriptions, name):
|
||||
if type(func_descriptions) == list:
|
||||
for func_description in func_descriptions:
|
||||
if func_description["name"] == name:
|
||||
return func_description
|
||||
return None
|
||||
else:
|
||||
# it is a dict, there is only one function
|
||||
return func_descriptions
|
||||
|
||||
|
||||
def get_possible_answer_type(possible_answer: list):
|
||||
for answer in possible_answer:
|
||||
if answer != "": # Optional parameter
|
||||
return type(answer)
|
||||
return None
|
||||
|
||||
|
||||
def type_checker(
|
||||
param: str,
|
||||
value,
|
||||
possible_answer: list,
|
||||
expected_type_description: str,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
):
|
||||
# NOTE: This type checker only supports nested type checking for one level deep.
|
||||
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
|
||||
|
||||
result: Any = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"is_variable": False,
|
||||
"error_type": "type_error:simple",
|
||||
}
|
||||
|
||||
is_variable = False
|
||||
# check for the case where a variable is used instead of a actual value.
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if possible_answer_type != expected_type_converted:
|
||||
is_variable = True
|
||||
|
||||
# value is the same type as in function description
|
||||
if type(value) == expected_type_converted:
|
||||
# We don't need to do recursive check for simple types
|
||||
if nested_type_converted == None:
|
||||
result["is_variable"] = is_variable
|
||||
return result
|
||||
else:
|
||||
for possible_answer_item in possible_answer:
|
||||
flag = True # Each parameter should match to at least one possible answer type.
|
||||
# Here, we assume that each item should be the same type. We could also relax it.
|
||||
if type(possible_answer_item) == list:
|
||||
for value_item in value:
|
||||
checker_result = type_checker(
|
||||
param,
|
||||
value_item,
|
||||
possible_answer_item,
|
||||
str(nested_type_converted),
|
||||
nested_type_converted,
|
||||
None,
|
||||
)
|
||||
if not checker_result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": [], "is_variable": is_variable}
|
||||
|
||||
result["valid"] = False
|
||||
result["error"] = [
|
||||
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
|
||||
]
|
||||
result["error_type"] = "type_error:nested"
|
||||
|
||||
# value is not as expected, check for the case where a variable is used instead of a actual value
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if type(value) == possible_answer_type:
|
||||
result["is_variable"] = True
|
||||
return result
|
||||
|
||||
result["valid"] = False
|
||||
result["error"].append(
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:simple"
|
||||
return result
|
||||
|
||||
|
||||
def standardize_string(input_string: str):
|
||||
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
|
||||
# It will also convert all the single quotes to double quotes
|
||||
# This is used to compare the model output with the possible answers
|
||||
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
|
||||
regex_string = r"[ \,\.\/\-\_\*\^]"
|
||||
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
|
||||
|
||||
|
||||
def string_checker(param: str, model_output: str, possible_answer: list):
|
||||
standardize_possible_answer = []
|
||||
standardize_model_output = standardize_string(model_output)
|
||||
for i in range(len(possible_answer)):
|
||||
if type(possible_answer[i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[i]))
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
|
||||
],
|
||||
"error_type": "value_error:string",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def list_checker(param: str, model_output: list, possible_answer: list):
|
||||
# Convert the tuple to a list
|
||||
|
||||
standardize_model_output = list(model_output)
|
||||
|
||||
# If the element in the list is a string, we need to standardize it
|
||||
for i in range(len(standardize_model_output)):
|
||||
if type(standardize_model_output[i]) == str:
|
||||
standardize_model_output[i] = standardize_string(model_output[i])
|
||||
|
||||
standardize_possible_answer: Any = []
|
||||
# We also need to standardize the possible answers
|
||||
for i in range(len(possible_answer)):
|
||||
standardize_possible_answer.append([])
|
||||
for j in range(len(possible_answer[i])):
|
||||
if type(possible_answer[i][j]) == str:
|
||||
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
|
||||
else:
|
||||
standardize_possible_answer[i].append(possible_answer[i][j])
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
|
||||
],
|
||||
"error_type": "value_error:list/tuple",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def dict_checker(param: str, model_output: dict, possible_answers: list):
|
||||
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
|
||||
# The current dataset only contains simple dictionaries, so this is sufficient.
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
for i in range(len(possible_answers)):
|
||||
if possible_answers[i] == "":
|
||||
continue
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
|
||||
flag = True
|
||||
|
||||
possible_answer = possible_answers[i]
|
||||
# possible_anwer is a single dictionary
|
||||
|
||||
for key, value in model_output.items():
|
||||
if key not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
standardize_value = value
|
||||
# If the value is a string, we need to standardize it
|
||||
if type(value) == str:
|
||||
standardize_value = standardize_string(value)
|
||||
|
||||
# We also need to standardize the possible answers if they are string
|
||||
standardize_possible_answer = []
|
||||
for i in range(len(possible_answer[key])):
|
||||
if type(possible_answer[key][i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
|
||||
else:
|
||||
standardize_possible_answer.append(possible_answer[key][i])
|
||||
|
||||
if standardize_value not in standardize_possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
|
||||
)
|
||||
result["error_type"] = "value_error:dict_value"
|
||||
flag = False
|
||||
break
|
||||
|
||||
for key, value in possible_answer.items():
|
||||
if key not in model_output and "" not in value:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_dict_checker(param: str, model_output: list, possible_answers: list):
|
||||
# This function takes in a list of dictionaries and checks if each dictionary is valid
|
||||
# The order of the dictionaries in the list must match the order of the possible answers
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
|
||||
|
||||
for answer_index in range(len(possible_answers)):
|
||||
flag = True # True means so far, all dictionaries are valid
|
||||
|
||||
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
|
||||
if len(model_output) != len(possible_answers[answer_index]):
|
||||
result["valid"] = False
|
||||
result["error"] = ["Wrong number of dictionaries in the list."]
|
||||
result["error_type"] = "value_error:list_dict_count"
|
||||
flag = False
|
||||
continue
|
||||
|
||||
for dict_index in range(len(model_output)):
|
||||
result = dict_checker(
|
||||
param,
|
||||
model_output[dict_index],
|
||||
[possible_answers[answer_index][dict_index]],
|
||||
)
|
||||
if not result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def simple_function_checker(
|
||||
func_description: dict,
|
||||
model_output: dict,
|
||||
possible_answer: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
possible_answer = list(possible_answer.values())[0]
|
||||
# Extract function name and parameters details
|
||||
func_name = func_description["name"]
|
||||
param_details = func_description["parameters"]["properties"]
|
||||
required_params = func_description["parameters"]["required"]
|
||||
|
||||
# Initialize a result dictionary
|
||||
result = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"error_type": "simple_function_checker:unclear",
|
||||
}
|
||||
|
||||
# Check if function name matches
|
||||
if func_name not in model_output:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Function name {repr(func_name)} not found in model output."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:wrong_func_name"
|
||||
return result
|
||||
|
||||
model_params = model_output[func_name]
|
||||
|
||||
# Check for required parameters in model output
|
||||
for param in required_params:
|
||||
if param not in model_params:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:missing_required"
|
||||
return result
|
||||
|
||||
# Validate types and values for each parameter in model output
|
||||
for param, value in model_params.items():
|
||||
if param not in param_details or param not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:unexpected_param"
|
||||
return result
|
||||
|
||||
full_param_details = param_details[param]
|
||||
expected_type_description = full_param_details["type"] # This is a string
|
||||
is_variable = False
|
||||
nested_type_converted = None
|
||||
|
||||
if language == "Java":
|
||||
from evals.utils.bfcl.java_type_converter import java_type_converter
|
||||
|
||||
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JAVA_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:java"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
|
||||
value = java_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = java_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "JavaScript":
|
||||
from evals.utils.bfcl.js_type_converter import js_type_converter
|
||||
|
||||
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JS_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:js"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
|
||||
value = js_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = js_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "Python":
|
||||
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
|
||||
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
|
||||
|
||||
# We convert all tuple value to list when the expected type is tuple.
|
||||
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
|
||||
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
|
||||
if expected_type_description == "tuple" and type(value) == tuple:
|
||||
value = list(value)
|
||||
|
||||
# Allow python auto conversion from int to float
|
||||
if language == "Python" and expected_type_description == "float" and type(value) == int:
|
||||
value = float(value)
|
||||
|
||||
# Type checking
|
||||
# In fact, we only check for Python here.
|
||||
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
|
||||
type_check_result = type_checker(
|
||||
param,
|
||||
value,
|
||||
possible_answer[param],
|
||||
expected_type_description,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
)
|
||||
is_variable = type_check_result["is_variable"]
|
||||
if not type_check_result["valid"]:
|
||||
return type_check_result
|
||||
|
||||
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
|
||||
# We can just treat the variable as a string and use the normal flow.
|
||||
if not is_variable:
|
||||
# Special handle for dictionaries
|
||||
if expected_type_converted == dict:
|
||||
result = dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for list of dictionaries
|
||||
elif expected_type_converted == list and nested_type_converted == dict:
|
||||
result = list_dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for strings
|
||||
elif expected_type_converted == str:
|
||||
# We don't check for case sensitivity for string, as long as it's not a variable
|
||||
result = string_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
elif expected_type_converted == list:
|
||||
result = list_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Check if the value is within the possible answers
|
||||
if value not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
|
||||
)
|
||||
result["error_type"] = "value_error:others"
|
||||
return result
|
||||
|
||||
# Check for optional parameters not provided but allowed
|
||||
for param in possible_answer:
|
||||
if param not in model_params and "" not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Optional parameter {repr(param)} not provided and not marked as optional."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:missing_optional"
|
||||
return result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parallel_function_checker_enforce_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_enforce_order:wrong_count",
|
||||
}
|
||||
|
||||
func_name_list = list(possible_answers.keys())
|
||||
possible_answers_list = []
|
||||
|
||||
for key, value in possible_answers.items():
|
||||
possible_answers_list.append({key: value})
|
||||
|
||||
for i in range(len(possible_answers_list)):
|
||||
func_description = find_description(func_descriptions, func_name_list[i])
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[i],
|
||||
possible_answers_list[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
if not result["valid"]:
|
||||
return result
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def parallel_function_checker_no_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_no_order:wrong_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
|
||||
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
|
||||
# It must be this way because we need ground truth to fetch the correct function description
|
||||
for i in range(len(possible_answers)):
|
||||
# possible_answers[i] is a dictionary with only one key
|
||||
func_name_expected = list(possible_answers[i].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
|
||||
all_errors = []
|
||||
|
||||
for index in range(len(model_output)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[index],
|
||||
possible_answers[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_output_item": model_output[index],
|
||||
"possible_answer_item": possible_answers[i],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "parallel_function_checker_no_order:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def multiple_function_checker(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "multiple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
# possible_answers is a list of only one dictionary with only one key
|
||||
func_name_expected = list(possible_answers[0].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
return simple_function_checker(
|
||||
func_description,
|
||||
model_output[0],
|
||||
possible_answers[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
if type(exec_output) != type(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == dict:
|
||||
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
|
||||
# This happens when the key is a timestamp or a random number.
|
||||
if is_sanity_check:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
for key, value in expected_result.items():
|
||||
if key not in exec_output:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
for key, value in exec_output.items():
|
||||
if key not in expected_result:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == list:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:list_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
#### Helper functions for Exec ####
|
||||
def executable_checker_simple(
|
||||
function_call: str,
|
||||
expected_result,
|
||||
expected_result_type: str,
|
||||
is_sanity_check=False,
|
||||
):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
exec_dict: Any = {}
|
||||
|
||||
try:
|
||||
exec(
|
||||
"from executable_python_function import *" + "\nresult=" + function_call,
|
||||
exec_dict,
|
||||
)
|
||||
exec_output = exec_dict["result"]
|
||||
except NoAPIKeyError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Error in execution: {repr(function_call)}. Error: {str(e)}"
|
||||
)
|
||||
result["error_type"] = "executable_checker:execution_error"
|
||||
return result
|
||||
|
||||
# We need to special handle the case where the execution result is a tuple and convert it to a list
|
||||
# Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
|
||||
if isinstance(exec_output, tuple):
|
||||
exec_output = list(exec_output)
|
||||
|
||||
if expected_result_type == "exact_match":
|
||||
if exec_output != expected_result:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
elif expected_result_type == "real_time_match":
|
||||
# Allow for 5% difference
|
||||
if (type(expected_result) == float or type(expected_result) == int) and (
|
||||
type(exec_output) == float or type(exec_output) == int
|
||||
):
|
||||
if not (
|
||||
expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
<= exec_output
|
||||
<= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
):
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
else:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
else:
|
||||
# structural match
|
||||
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
|
||||
if not pattern_match_result["valid"]:
|
||||
return pattern_match_result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def executable_checker_parallel_no_order(
|
||||
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
|
||||
):
|
||||
if len(decoded_result) != len(expected_exec_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
|
||||
],
|
||||
"error_type": "value_error:exec_result_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
for i in range(len(expected_exec_result)):
|
||||
all_errors = []
|
||||
for index in range(len(decoded_result)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = executable_checker_simple(
|
||||
decoded_result[index],
|
||||
expected_exec_result[i],
|
||||
expected_exec_result_type[i],
|
||||
False,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_executed_output": (
|
||||
result["model_executed_output"] if "model_executed_output" in result else None
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "executable_checker:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
|
||||
#### Main function ####
|
||||
def executable_checker_rest(func_call, idx):
|
||||
# Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
|
||||
EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution
|
||||
with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
|
||||
EVAL_GROUND_TRUTH = f.readlines()
|
||||
if "https://geocode.maps.co" in func_call:
|
||||
time.sleep(2)
|
||||
if "requests_get" in func_call:
|
||||
func_call = func_call.replace("requests_get", "requests.get")
|
||||
try:
|
||||
response = eval(func_call)
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution failed. {str(e)}"],
|
||||
"error_type": "executable_checker_rest:execution_error",
|
||||
}
|
||||
|
||||
try:
|
||||
if response.status_code == 200:
|
||||
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
|
||||
try:
|
||||
if isinstance(eval_GT_json, dict):
|
||||
if isinstance(response.json(), dict):
|
||||
if set(eval_GT_json.keys()) == set(response.json().keys()):
|
||||
return {"valid": True, "error": [], "error_type": ""}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dictionary, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
|
||||
elif isinstance(eval_GT_json, list):
|
||||
if isinstance(response.json(), list):
|
||||
if len(eval_GT_json) != len(response.json()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Response list length inconsistency."],
|
||||
"error_type": "value_error:exec_result_rest_count",
|
||||
}
|
||||
|
||||
else:
|
||||
for i in range(len(eval_GT_json)):
|
||||
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dict or list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
|
||||
],
|
||||
"error_type": "executable_checker_rest:response_format_error",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution result status code is not 200, got {response.status_code}"],
|
||||
"error_type": "executable_checker_rest:wrong_status_code",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Cannot get status code of the response. Error: {str(e)}"],
|
||||
"error_type": "executable_checker_rest:cannot_get_status_code",
|
||||
}
|
||||
|
||||
|
||||
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
|
||||
if "parallel" in test_category:
|
||||
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
elif "multiple" in test_category:
|
||||
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
else:
|
||||
if len(model_output) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
return simple_function_checker(
|
||||
func_description[0],
|
||||
model_output[0],
|
||||
possible_answer[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def exec_checker(decoded_result: list, func_description: dict, test_category: str):
|
||||
if "multiple" in test_category or "parallel" in test_category:
|
||||
return executable_checker_parallel_no_order(
|
||||
decoded_result,
|
||||
func_description["execution_result"],
|
||||
func_description["execution_result_type"],
|
||||
)
|
||||
|
||||
else:
|
||||
if len(decoded_result) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_exec_checker:wrong_count",
|
||||
}
|
||||
return executable_checker_simple(
|
||||
decoded_result[0],
|
||||
func_description["execution_result"][0],
|
||||
func_description["execution_result_type"][0],
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
def is_empty_output(decoded_output):
|
||||
# This function is a patch to the ast decoder for relevance detection
|
||||
# Sometimes the ast decoder will parse successfully, but the input doens't really have a function call
|
||||
# [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
|
||||
if not is_function_calling_format_output(decoded_output):
|
||||
return True
|
||||
if len(decoded_output) == 0:
|
||||
return True
|
||||
if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
|
||||
return True
|
||||
|
||||
|
||||
def is_function_calling_format_output(decoded_output):
|
||||
# Ensure the output is a list of dictionaries
|
||||
if type(decoded_output) == list:
|
||||
for item in decoded_output:
|
||||
if type(item) != dict:
|
||||
return False
|
||||
return True
|
||||
return False
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
|
||||
import it from here so that we can centrally manage things as necessary.
|
||||
"""
|
||||
|
||||
# These currently work with tree-sitter 0.23.0
|
||||
# NOTE: Don't import tree-sitter or any of the language modules in the main module
|
||||
# because not all environments have them. Import lazily inside functions where needed.
|
||||
|
||||
import importlib
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import tree_sitter
|
||||
|
||||
|
||||
def get_language(language: str) -> "tree_sitter.Language":
|
||||
import tree_sitter
|
||||
|
||||
language_module_name = f"tree_sitter_{language}"
|
||||
try:
|
||||
language_module = importlib.import_module(language_module_name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
f"Language {language} is not found. Please install the tree-sitter-{language} package."
|
||||
) from exc
|
||||
return tree_sitter.Language(language_module.language())
|
||||
|
||||
|
||||
def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
|
||||
import tree_sitter
|
||||
|
||||
lang = get_language(language)
|
||||
return tree_sitter.Parser(lang, **kwargs)
|
|
@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
|
|||
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
|
||||
from .memory import MemoryToolRuntimeImpl
|
||||
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -5,10 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import mimetypes
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import UploadFile
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
|
@ -30,13 +36,18 @@ from llama_stack.apis.tools import (
|
|||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_io import (
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
VectorStoreChunkingStrategyStatic,
|
||||
VectorStoreChunkingStrategyStaticConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
parse_data_url,
|
||||
)
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
|
@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
config: RagToolRuntimeConfig,
|
||||
vector_io_api: VectorIO,
|
||||
inference_api: Inference,
|
||||
files_api: Files,
|
||||
):
|
||||
self.config = config
|
||||
self.vector_io_api = vector_io_api
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
@ -78,26 +91,49 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
chunks = []
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
# TODO: we should add enrichment here as URLs won't be added to the metadata by default
|
||||
chunks.extend(
|
||||
make_overlapped_chunks(
|
||||
doc.document_id,
|
||||
content,
|
||||
chunk_size_in_tokens,
|
||||
chunk_size_in_tokens // 4,
|
||||
doc.metadata,
|
||||
)
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
if not documents:
|
||||
return
|
||||
|
||||
await self.vector_io_api.insert_chunks(
|
||||
chunks=chunks,
|
||||
vector_db_id=vector_db_id,
|
||||
for doc in documents:
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
parts = parse_data_url(doc.content.uri)
|
||||
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
|
||||
mime_type = parts["mimetype"]
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(doc.content.uri)
|
||||
file_data = response.content
|
||||
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
|
||||
else:
|
||||
content_str = await content_from_doc(doc)
|
||||
file_data = content_str.encode("utf-8")
|
||||
mime_type = doc.mime_type or "text/plain"
|
||||
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def query(
|
||||
|
@ -131,8 +167,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
for vector_db_id in vector_db_ids
|
||||
]
|
||||
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
|
||||
chunks = [c for r in results for c in r.chunks]
|
||||
scores = [s for r in results for s in r.scores]
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
|
||||
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
|
||||
for chunk, score in zip(result.chunks, result.scores, strict=False):
|
||||
if not hasattr(chunk, "metadata") or chunk.metadata is None:
|
||||
chunk.metadata = {}
|
||||
chunk.metadata["vector_db_id"] = vector_db_id
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
if not chunks:
|
||||
return RAGQueryResult(content=None)
|
||||
|
@ -167,6 +213,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
metadata_keys_to_exclude_from_context = [
|
||||
"token_count",
|
||||
"metadata_token_count",
|
||||
"vector_db_id",
|
||||
]
|
||||
metadata_for_context = {}
|
||||
for k in chunk_metadata_keys_to_include_from_context:
|
||||
|
@ -191,6 +238,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
||||
"chunks": [c.content for c in chunks[: len(picked)]],
|
||||
"scores": scores[: len(picked)],
|
||||
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@ -30,11 +30,11 @@ from llama_stack.providers.utils.kvstore.api import KVStore
|
|||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
|
||||
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
|
||||
|
@ -66,59 +66,6 @@ def _create_sqlite_connection(db_path):
|
|||
return connection
|
||||
|
||||
|
||||
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||
"""Normalize scores to [0,1] range using min-max normalization."""
|
||||
if not scores:
|
||||
return {}
|
||||
min_score = min(scores.values())
|
||||
max_score = max(scores.values())
|
||||
score_range = max_score - min_score
|
||||
if score_range > 0:
|
||||
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
|
||||
return dict.fromkeys(scores, 1.0)
|
||||
|
||||
|
||||
def _weighted_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
alpha: float = 0.5,
|
||||
) -> dict[str, float]:
|
||||
"""ReRanker that uses weighted average of scores."""
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
normalized_vector_scores = _normalize_scores(vector_scores)
|
||||
normalized_keyword_scores = _normalize_scores(keyword_scores)
|
||||
|
||||
return {
|
||||
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
|
||||
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
|
||||
for doc_id in all_ids
|
||||
}
|
||||
|
||||
|
||||
def _rrf_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
impact_factor: float = 60.0,
|
||||
) -> dict[str, float]:
|
||||
"""ReRanker that uses Reciprocal Rank Fusion."""
|
||||
# Convert scores to ranks
|
||||
vector_ranks = {
|
||||
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
keyword_ranks = {
|
||||
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
rrf_scores = {}
|
||||
for doc_id in all_ids:
|
||||
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
|
||||
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
|
||||
return rrf_scores
|
||||
|
||||
|
||||
def _make_sql_identifier(name: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
||||
|
||||
|
@ -398,14 +345,10 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||
}
|
||||
|
||||
# Combine scores using the specified reranker
|
||||
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||
alpha = reranker_params.get("alpha", 0.5)
|
||||
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
|
||||
else:
|
||||
# Default to RRF for None, RRF, or any unknown types
|
||||
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
|
||||
# Combine scores using the reranking utility
|
||||
combined_scores = WeightedInMemoryAggregator.combine_search_results(
|
||||
vector_scores, keyword_scores, reranker_type, reranker_params
|
||||
)
|
||||
|
||||
# Sort by combined score and get top k results
|
||||
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
|
|
|
@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.batches,
|
||||
provider_type="inline::reference",
|
||||
pip_packages=["openai"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.batches.reference",
|
||||
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -30,7 +30,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="huggingface",
|
||||
pip_packages=[
|
||||
"datasets",
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.huggingface",
|
||||
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
|
||||
|
@ -42,7 +42,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"datasets",
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.nvidia",
|
||||
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
|
||||
|
|
|
@ -40,8 +40,9 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::sentence-transformers",
|
||||
# CrossEncoder depends on torchao.quantization
|
||||
pip_packages=[
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu",
|
||||
"torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
|
||||
"sentence-transformers --no-deps",
|
||||
],
|
||||
module="llama_stack.providers.inline.inference.sentence_transformers",
|
||||
|
@ -74,7 +75,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vllm",
|
||||
pip_packages=["openai"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.vllm",
|
||||
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
|
||||
description="Remote vLLM inference provider for connecting to vLLM servers.",
|
||||
|
@ -115,7 +116,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai",
|
||||
"fireworks-ai<=0.17.16",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
|
@ -150,9 +151,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="databricks",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||
|
@ -162,9 +161,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||
|
@ -174,7 +171,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="runpod",
|
||||
pip_packages=["openai"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.runpod",
|
||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||
|
@ -291,7 +288,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="watsonx",
|
||||
pip_packages=["ibm_watson_machine_learning"],
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue