diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 263828e1c..af2058b9a 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,8 +1,10 @@ # What does this PR do? - +[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] - - +[//]: # (If resolving an issue, uncomment and update the line below) +[//]: # (Closes #[issue-number]) ## Test Plan - +[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] + +[//]: # (## Documentation) diff --git a/.github/actions/setup-runner/action.yml b/.github/actions/setup-runner/action.yml deleted file mode 100644 index 6cba4fdc3..000000000 --- a/.github/actions/setup-runner/action.yml +++ /dev/null @@ -1,22 +0,0 @@ -name: Setup runner -description: Prepare a runner for the tests (install uv, python, project dependencies, etc.) -runs: - using: "composite" - steps: - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - activate-environment: true - version: 0.7.6 - - - name: Install dependencies - shell: bash - run: | - uv sync --all-groups - uv pip install ollama faiss-cpu - # always test against the latest version of the client - # TODO: this is not necessarily a good idea. we need to test against both published and latest - # to find out backwards compatibility issues. - uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main - uv pip install -e . diff --git a/.github/workflows/Dockerfile b/.github/workflows/Dockerfile deleted file mode 100644 index 9261bd174..000000000 --- a/.github/workflows/Dockerfile +++ /dev/null @@ -1 +0,0 @@ -FROM localhost:5000/distribution-kvant:dev \ No newline at end of file diff --git a/.github/workflows_upstream/changelog.yml b/.github/workflows/changelog.yml similarity index 100% rename from .github/workflows_upstream/changelog.yml rename to .github/workflows/changelog.yml diff --git a/.github/workflows/ci-playground.yaml b/.github/workflows/ci-playground.yaml deleted file mode 100644 index 251782855..000000000 --- a/.github/workflows/ci-playground.yaml +++ /dev/null @@ -1,73 +0,0 @@ -name: Build and Push playground container -run-name: Build and Push playground container -on: - workflow_dispatch: - #schedule: - # - cron: "0 10 * * *" - push: - branches: - - main - - kvant - tags: - - 'v*' - pull_request: - branches: - - main - - kvant -env: - IMAGE: git.kvant.cloud/${{github.repository}}-playground -jobs: - build-playground: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set current time - uses: https://github.com/gerred/actions/current-time@master - id: current_time - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Login to git.kvant.cloud registry - uses: docker/login-action@v3 - with: - registry: git.kvant.cloud - username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }} - password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }} - - - name: Docker meta - id: meta - uses: docker/metadata-action@v5 - with: - # list of Docker images to use as base name for tags - images: | - ${{env.IMAGE}} - # generate Docker tags based on the following events/attributes - tags: | - type=schedule - type=ref,event=branch - type=ref,event=pr - type=ref,event=tag - type=semver,pattern={{version}} - - - name: Build and push to gitea registry - uses: 
docker/build-push-action@v6 - with: - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - context: . - file: llama_stack/distribution/ui/Containerfile - provenance: mode=max - sbom: true - build-args: | - BUILD_DATE=${{ steps.current_time.outputs.time }} - cache-from: | - type=registry,ref=${{ env.IMAGE }}:buildcache - type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }} - type=registry,ref=${{ env.IMAGE }}:main - cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml deleted file mode 100644 index 87f196cc2..000000000 --- a/.github/workflows/ci.yaml +++ /dev/null @@ -1,98 +0,0 @@ -name: Build and Push container -run-name: Build and Push container -on: - workflow_dispatch: - #schedule: - # - cron: "0 10 * * *" - push: - branches: - - main - - kvant - tags: - - 'v*' - pull_request: - branches: - - main - - kvant -env: - IMAGE: git.kvant.cloud/${{github.repository}} -jobs: - build: - runs-on: ubuntu-latest - services: - registry: - image: registry:2 - ports: - - 5000:5000 - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set current time - uses: https://github.com/gerred/actions/current-time@master - id: current_time - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - with: - driver-opts: network=host - - - name: Login to git.kvant.cloud registry - uses: docker/login-action@v3 - with: - registry: git.kvant.cloud - username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }} - password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }} - - - name: Docker meta - id: meta - uses: docker/metadata-action@v5 - with: - # list of Docker images to use as base name for tags - images: | - ${{env.IMAGE}} - # generate Docker tags based on the following events/attributes - tags: | - type=schedule - type=ref,event=branch - type=ref,event=pr - type=ref,event=tag - type=semver,pattern={{version}} - - - name: Install uv - uses: https://github.com/astral-sh/setup-uv@v5 - with: - # Install a specific version of uv. - version: "0.7.8" - - - name: Build - env: - USE_COPY_NOT_MOUNT: true - LLAMA_STACK_DIR: . - run: | - uvx --from . 
llama stack build --template kvant --image-type container - - # docker tag distribution-kvant:dev ${{env.IMAGE}}:kvant - # docker push ${{env.IMAGE}}:kvant - - docker tag distribution-kvant:dev localhost:5000/distribution-kvant:dev - docker push localhost:5000/distribution-kvant:dev - - - name: Build and push to gitea registry - uses: docker/build-push-action@v6 - with: - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - context: .github/workflows - provenance: mode=max - sbom: true - build-args: | - BUILD_DATE=${{ steps.current_time.outputs.time }} - cache-from: | - type=registry,ref=${{ env.IMAGE }}:buildcache - type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }} - type=registry,ref=${{ env.IMAGE }}:main - cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true diff --git a/.github/workflows_upstream/gha_workflow_llama_stack_tests.yml b/.github/workflows/gha_workflow_llama_stack_tests.yml similarity index 100% rename from .github/workflows_upstream/gha_workflow_llama_stack_tests.yml rename to .github/workflows/gha_workflow_llama_stack_tests.yml diff --git a/.github/workflows_upstream/install-script-ci.yml b/.github/workflows/install-script-ci.yml similarity index 100% rename from .github/workflows_upstream/install-script-ci.yml rename to .github/workflows/install-script-ci.yml diff --git a/.github/workflows_upstream/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml similarity index 64% rename from .github/workflows_upstream/integration-auth-tests.yml rename to .github/workflows/integration-auth-tests.yml index a3a746246..33fb4e802 100644 --- a/.github/workflows_upstream/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -23,18 +23,23 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - auth-provider: [oauth2_token] + auth-provider: [kubernetes] fail-fast: false # we want to run all tests regardless of failure steps: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Install uv + uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + with: + python-version: "3.10" + activate-environment: true - - name: Build Llama Stack + - name: Set Up Environment and Install Dependencies run: | + uv sync --extra dev --extra test + uv pip install -e . 
llama stack build --template ollama --image-type venv - name: Install minikube @@ -42,53 +47,29 @@ jobs: uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19 - name: Start minikube - if: ${{ matrix.auth-provider == 'oauth2_token' }} + if: ${{ matrix.auth-provider == 'kubernetes' }} run: | minikube start kubectl get pods -A - name: Configure Kube Auth - if: ${{ matrix.auth-provider == 'oauth2_token' }} + if: ${{ matrix.auth-provider == 'kubernetes' }} run: | kubectl create namespace llama-stack kubectl create serviceaccount llama-stack-auth -n llama-stack kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token - cat <> $GITHUB_ENV + echo "KUBERNETES_API_SERVER_URL=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.server}')" >> $GITHUB_ENV echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV - echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV - echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV - name: Set Kube Auth Config and run server env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" - if: ${{ matrix.auth-provider == 'oauth2_token' }} + if: ${{ matrix.auth-provider == 'kubernetes' }} run: | run_dir=$(mktemp -d) cat <<'EOF' > $run_dir/run.yaml @@ -100,10 +81,10 @@ jobs: port: 8321 EOF yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml - yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml - yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml cat $run_dir/run.yaml + source .venv/bin/activate nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 & - name: Wait for Llama Stack server to be ready diff --git a/.github/workflows_upstream/integration-tests.yml b/.github/workflows/integration-tests.yml similarity index 75% rename from .github/workflows_upstream/integration-tests.yml rename to .github/workflows/integration-tests.yml index d78e82c9d..d755ff0ae 100644 --- a/.github/workflows_upstream/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -24,7 +24,7 @@ jobs: matrix: # Listing tests manually since some of them currently fail # TODO: generate matrix list from tests/integration when fixed - test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime] + test-type: [agents, inference, datasets, inspect, scoring, post_training, providers] client-type: [library, http] fail-fast: false # we want to run all tests regardless of failure @@ -32,14 +32,24 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Install uv + uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + with: + python-version: "3.10" + 
activate-environment: true - name: Setup ollama uses: ./.github/actions/setup-ollama - - name: Build Llama Stack + - name: Set Up Environment and Install Dependencies run: | + uv sync --extra dev --extra test + uv pip install ollama faiss-cpu + # always test against the latest version of the client + # TODO: this is not necessarily a good idea. we need to test against both published and latest + # to find out backwards compatibility issues. + uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main + uv pip install -e . llama stack build --template ollama --image-type venv - name: Start Llama Stack server in background @@ -47,7 +57,8 @@ jobs: env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" run: | - LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv & + source .venv/bin/activate + nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 & - name: Wait for Llama Stack server to be ready if: matrix.client-type == 'http' @@ -75,12 +86,6 @@ jobs: exit 1 fi - - name: Check Storage and Memory Available Before Tests - if: ${{ always() }} - run: | - free -h - df -h - - name: Run Integration Tests env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" @@ -90,24 +95,17 @@ jobs: else stack_config="http://localhost:8321" fi - uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ + uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ --text-model="meta-llama/Llama-3.2-3B-Instruct" \ --embedding-model=all-MiniLM-L6-v2 - - name: Check Storage and Memory Available After Tests - if: ${{ always() }} - run: | - free -h - df -h - - name: Write ollama logs to file - if: ${{ always() }} run: | sudo journalctl -u ollama.service > ollama.log - name: Upload all logs to artifacts - if: ${{ always() }} + if: always() uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }} diff --git a/.github/workflows_upstream/pre-commit.yml b/.github/workflows/pre-commit.yml similarity index 97% rename from .github/workflows_upstream/pre-commit.yml rename to .github/workflows/pre-commit.yml index 2bbd52c53..4df04fbb0 100644 --- a/.github/workflows_upstream/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -29,7 +29,6 @@ jobs: - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 env: SKIP: no-commit-to-branch - RUFF_OUTPUT_FORMAT: github - name: Verify if there are any diff files after pre-commit run: | diff --git a/.github/workflows_upstream/providers-build.yml b/.github/workflows/providers-build.yml similarity index 73% rename from .github/workflows_upstream/providers-build.yml rename to .github/workflows/providers-build.yml index cf53459b9..0fd7904d4 100644 --- a/.github/workflows_upstream/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -50,8 +50,21 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Set up Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: '3.10' + + - name: Install uv + uses: 
astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + with: + python-version: "3.10" + + - name: Install LlamaStack + run: | + uv venv + source .venv/bin/activate + uv pip install -e . - name: Print build dependencies run: | @@ -66,6 +79,7 @@ jobs: - name: Print dependencies in the image if: matrix.image-type == 'venv' run: | + source test/bin/activate uv pip list build-single-provider: @@ -74,8 +88,21 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Set up Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: '3.10' + + - name: Install uv + uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + with: + python-version: "3.10" + + - name: Install LlamaStack + run: | + uv venv + source .venv/bin/activate + uv pip install -e . - name: Build a single provider run: | @@ -87,8 +114,21 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Set up Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: '3.10' + + - name: Install uv + uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + with: + python-version: "3.10" + + - name: Install LlamaStack + run: | + uv venv + source .venv/bin/activate + uv pip install -e . - name: Build a single provider run: | @@ -112,8 +152,21 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Set up Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: '3.10' + + - name: Install uv + uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1 + with: + python-version: "3.10" + + - name: Install LlamaStack + run: | + uv venv + source .venv/bin/activate + uv pip install -e . - name: Pin template to UBI9 base run: | diff --git a/.github/workflows_upstream/semantic-pr.yml b/.github/workflows/semantic-pr.yml similarity index 100% rename from .github/workflows_upstream/semantic-pr.yml rename to .github/workflows/semantic-pr.yml diff --git a/.github/workflows_upstream/stale_bot.yml b/.github/workflows/stale_bot.yml similarity index 100% rename from .github/workflows_upstream/stale_bot.yml rename to .github/workflows/stale_bot.yml diff --git a/.github/workflows_upstream/test-external-providers.yml b/.github/workflows/test-external-providers.yml similarity index 88% rename from .github/workflows_upstream/test-external-providers.yml rename to .github/workflows/test-external-providers.yml index 06ab7cf3c..8c75dde25 100644 --- a/.github/workflows_upstream/test-external-providers.yml +++ b/.github/workflows/test-external-providers.yml @@ -25,8 +25,15 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Install uv + uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + with: + python-version: "3.10" + + - name: Set Up Environment and Install Dependencies + run: | + uv sync --extra dev --extra test + uv pip install -e . 
- name: Apply image type to config file run: | @@ -52,6 +59,7 @@ jobs: env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" run: | + source ci-test/bin/activate uv run pip list nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & diff --git a/.github/workflows_upstream/tests.yml b/.github/workflows/tests.yml similarity index 100% rename from .github/workflows_upstream/tests.yml rename to .github/workflows/tests.yml diff --git a/.github/workflows_upstream/unit-tests.yml b/.github/workflows/unit-tests.yml similarity index 73% rename from .github/workflows_upstream/unit-tests.yml rename to .github/workflows/unit-tests.yml index fc0459f0f..64a5bba37 100644 --- a/.github/workflows_upstream/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -30,11 +30,17 @@ jobs: - "3.12" - "3.13" steps: - - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: ${{ matrix.python }} + + - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + with: + python-version: ${{ matrix.python }} + enable-cache: false - name: Run unit tests run: | diff --git a/.github/workflows_upstream/update-readthedocs.yml b/.github/workflows/update-readthedocs.yml similarity index 81% rename from .github/workflows_upstream/update-readthedocs.yml rename to .github/workflows/update-readthedocs.yml index 981332a77..094942368 100644 --- a/.github/workflows_upstream/update-readthedocs.yml +++ b/.github/workflows/update-readthedocs.yml @@ -37,8 +37,16 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Set up Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: '3.11' + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + + - name: Sync with uv + run: uv sync --extra docs - name: Build HTML run: | diff --git a/.gitignore b/.gitignore index 747acdc7b..0ef25cdf1 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,6 @@ dev_requirements.txt build .DS_Store llama_stack/configs/* -.cursor/ xcuserdata/ *.hmap .DS_Store @@ -24,4 +23,3 @@ venv/ pytest-report.xml .coverage .python-version -data diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aaec469e4..e78fcd158 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,7 @@ repos: - black==24.3.0 - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.7.8 + rev: 0.6.3 hooks: - id: uv-lock - id: uv-export @@ -61,7 +61,6 @@ repos: "--frozen", "--no-hashes", "--no-emit-project", - "--no-default-groups", "--output-file=requirements.txt" ] @@ -89,17 +88,20 @@ repos: - id: distro-codegen name: Distribution Template Codegen additional_dependencies: - - uv==0.7.8 - entry: uv run --group codegen ./scripts/distro_codegen.py + - uv==0.6.0 + entry: uv run --extra codegen ./scripts/distro_codegen.py language: python pass_filenames: false require_serial: true files: 
^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$ + +- repo: local + hooks: - id: openapi-codegen name: API Spec Codegen additional_dependencies: - - uv==0.7.8 - entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null' + - uv==0.6.2 + entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null' language: python pass_filenames: false require_serial: true diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 461977a6c..f114dbf9b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -5,21 +5,28 @@ # Required version: 2 -# Build documentation in the "docs/" directory with Sphinx -sphinx: - configuration: docs/source/conf.py - # Set the OS, Python version and other tools you might need build: os: ubuntu-22.04 tools: python: "3.12" - jobs: - pre_create_environment: - - asdf plugin add uv - - asdf install uv latest - - asdf global uv latest - create_environment: - - uv venv "${READTHEDOCS_VIRTUALENV_PATH}" - install: - - UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --group docs + # You can also specify other tool versions: + # nodejs: "19" + # rust: "1.64" + # golang: "1.19" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/source/conf.py + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index f7644a5af..ec2468b46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,14 +3,14 @@ # v0.2.7 Published on: 2025-05-16T20:38:10Z -## Highlights - -This is a small update. But a couple highlights: - -* feat: function tools in OpenAI Responses by @bbrowning in https://github.com/meta-llama/llama-stack/pull/2094, getting closer to ready. Streaming is the next missing piece. -* feat: Adding support for customizing chunk context in RAG insertion and querying by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2134 -* feat: scaffolding for Llama Stack UI by @ehhuang in https://github.com/meta-llama/llama-stack/pull/2149, more to come in the coming releases. - +## Highlights + +This is a small update. But a couple highlights: + +* feat: function tools in OpenAI Responses by @bbrowning in https://github.com/meta-llama/llama-stack/pull/2094, getting closer to ready. Streaming is the next missing piece. +* feat: Adding support for customizing chunk context in RAG insertion and querying by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2134 +* feat: scaffolding for Llama Stack UI by @ehhuang in https://github.com/meta-llama/llama-stack/pull/2149, more to come in the coming releases. + --- @@ -31,42 +31,42 @@ Published on: 2025-05-04T20:16:49Z # v0.2.4 Published on: 2025-04-29T17:26:01Z -## Highlights - -* One-liner to install and run Llama Stack yay! by @reluctantfuturist in https://github.com/meta-llama/llama-stack/pull/1383 -* support for NVIDIA NeMo datastore by @raspawar in https://github.com/meta-llama/llama-stack/pull/1852 -* (yuge!) Kubernetes authentication by @leseb in https://github.com/meta-llama/llama-stack/pull/1778 -* (yuge!) 
OpenAI Responses API by @bbrowning in https://github.com/meta-llama/llama-stack/pull/1989 -* add api.llama provider, llama-guard-4 model by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2058 - +## Highlights + +* One-liner to install and run Llama Stack yay! by @reluctantfuturist in https://github.com/meta-llama/llama-stack/pull/1383 +* support for NVIDIA NeMo datastore by @raspawar in https://github.com/meta-llama/llama-stack/pull/1852 +* (yuge!) Kubernetes authentication by @leseb in https://github.com/meta-llama/llama-stack/pull/1778 +* (yuge!) OpenAI Responses API by @bbrowning in https://github.com/meta-llama/llama-stack/pull/1989 +* add api.llama provider, llama-guard-4 model by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2058 + --- # v0.2.3 Published on: 2025-04-25T22:46:21Z -## Highlights - -* OpenAI compatible inference endpoints and client-SDK support. `client.chat.completions.create()` now works. -* significant improvements and functionality added to the nVIDIA distribution -* many improvements to the test verification suite. -* new inference providers: Ramalama, IBM WatsonX -* many improvements to the Playground UI - +## Highlights + +* OpenAI compatible inference endpoints and client-SDK support. `client.chat.completions.create()` now works. +* significant improvements and functionality added to the nVIDIA distribution +* many improvements to the test verification suite. +* new inference providers: Ramalama, IBM WatsonX +* many improvements to the Playground UI + --- # v0.2.2 Published on: 2025-04-13T01:19:49Z -## Main changes - -- Bring Your Own Provider (@leseb) - use out-of-tree provider code to execute the distribution server -- OpenAI compatible inference API in progress (@bbrowning) -- Provider verifications (@ehhuang) -- Many updates and fixes to playground -- Several llama4 related fixes - +## Main changes + +- Bring Your Own Provider (@leseb) - use out-of-tree provider code to execute the distribution server +- OpenAI compatible inference API in progress (@bbrowning) +- Provider verifications (@ehhuang) +- Many updates and fixes to playground +- Several llama4 related fixes + --- @@ -80,10 +80,10 @@ Published on: 2025-04-05T23:13:00Z # v0.2.0 Published on: 2025-04-05T19:04:29Z -## Llama 4 Support - -Checkout more at https://www.llama.com - +## Llama 4 Support + +Checkout more at https://www.llama.com + --- @@ -91,58 +91,58 @@ Checkout more at https://www.llama.com # v0.1.9 Published on: 2025-03-29T00:52:23Z -### Build and Test Agents -* Agents: Entire document context with attachments -* RAG: Documentation with sqlite-vec faiss comparison -* Getting started: Fixes to getting started notebook. - -### Agent Evals and Model Customization -* (**New**) Post-training: Add nemo customizer - -### Better Engineering -* Moved sqlite-vec to non-blocking calls -* Don't return a payload on file delete - - +### Build and Test Agents +* Agents: Entire document context with attachments +* RAG: Documentation with sqlite-vec faiss comparison +* Getting started: Fixes to getting started notebook. + +### Agent Evals and Model Customization +* (**New**) Post-training: Add nemo customizer + +### Better Engineering +* Moved sqlite-vec to non-blocking calls +* Don't return a payload on file delete + + --- # v0.1.8 Published on: 2025-03-24T01:28:50Z -# v0.1.8 Release Notes - -### Build and Test Agents -* Safety: Integrated NVIDIA as a safety provider. -* VectorDB: Added Qdrant as an inline provider. -* Agents: Added support for multiple tool groups in agents. 
-* Agents: Simplified imports for Agents in client package - - -### Agent Evals and Model Customization -* Introduced DocVQA and IfEval benchmarks. - -### Deploying and Monitoring Agents -* Introduced a Containerfile and image workflow for the Playground. -* Implemented support for Bearer (API Key) authentication. -* Added attribute-based access control for resources. -* Fixes on docker deployments: use --pull always and standardized the default port to 8321 -* Deprecated: /v1/inspect/providers use /v1/providers/ instead - -### Better Engineering -* Consolidated scripts under the ./scripts directory. -* Addressed mypy violations in various modules. -* Added Dependabot scans for Python dependencies. -* Implemented a scheduled workflow to update the changelog automatically. -* Enforced concurrency to reduce CI loads. - - -### New Contributors -* @cmodi-meta made their first contribution in https://github.com/meta-llama/llama-stack/pull/1650 -* @jeffmaury made their first contribution in https://github.com/meta-llama/llama-stack/pull/1671 -* @derekhiggins made their first contribution in https://github.com/meta-llama/llama-stack/pull/1698 -* @Bobbins228 made their first contribution in https://github.com/meta-llama/llama-stack/pull/1745 - +# v0.1.8 Release Notes + +### Build and Test Agents +* Safety: Integrated NVIDIA as a safety provider. +* VectorDB: Added Qdrant as an inline provider. +* Agents: Added support for multiple tool groups in agents. +* Agents: Simplified imports for Agents in client package + + +### Agent Evals and Model Customization +* Introduced DocVQA and IfEval benchmarks. + +### Deploying and Monitoring Agents +* Introduced a Containerfile and image workflow for the Playground. +* Implemented support for Bearer (API Key) authentication. +* Added attribute-based access control for resources. +* Fixes on docker deployments: use --pull always and standardized the default port to 8321 +* Deprecated: /v1/inspect/providers use /v1/providers/ instead + +### Better Engineering +* Consolidated scripts under the ./scripts directory. +* Addressed mypy violations in various modules. +* Added Dependabot scans for Python dependencies. +* Implemented a scheduled workflow to update the changelog automatically. +* Enforced concurrency to reduce CI loads. 
+ + +### New Contributors +* @cmodi-meta made their first contribution in https://github.com/meta-llama/llama-stack/pull/1650 +* @jeffmaury made their first contribution in https://github.com/meta-llama/llama-stack/pull/1671 +* @derekhiggins made their first contribution in https://github.com/meta-llama/llama-stack/pull/1698 +* @Bobbins228 made their first contribution in https://github.com/meta-llama/llama-stack/pull/1745 + **Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.7...v0.1.8 --- @@ -150,73 +150,73 @@ Published on: 2025-03-24T01:28:50Z # v0.1.7 Published on: 2025-03-14T22:30:51Z -## 0.1.7 Release Notes - -### Build and Test Agents -* Inference: ImageType is now refactored to LlamaStackImageType -* Inference: Added tests to measure TTFT -* Inference: Bring back usage metrics -* Agents: Added endpoint for get agent, list agents and list sessions -* Agents: Automated conversion of type hints in client tool for lite llm format -* Agents: Deprecated ToolResponseMessage in agent.resume API -* Added Provider API for listing and inspecting provider info - -### Agent Evals and Model Customization -* Eval: Added new eval benchmarks Math 500 and BFCL v3 -* Deploy and Monitoring of Agents -* Telemetry: Fix tracing to work across coroutines - -### Better Engineering -* Display code coverage for unit tests -* Updated call sites (inference, tool calls, agents) to move to async non blocking calls -* Unit tests also run on Python 3.11, 3.12, and 3.13 -* Added ollama inference to Integration tests CI -* Improved documentation across examples, testing, CLI, updated providers table ) - - - +## 0.1.7 Release Notes + +### Build and Test Agents +* Inference: ImageType is now refactored to LlamaStackImageType +* Inference: Added tests to measure TTFT +* Inference: Bring back usage metrics +* Agents: Added endpoint for get agent, list agents and list sessions +* Agents: Automated conversion of type hints in client tool for lite llm format +* Agents: Deprecated ToolResponseMessage in agent.resume API +* Added Provider API for listing and inspecting provider info + +### Agent Evals and Model Customization +* Eval: Added new eval benchmarks Math 500 and BFCL v3 +* Deploy and Monitoring of Agents +* Telemetry: Fix tracing to work across coroutines + +### Better Engineering +* Display code coverage for unit tests +* Updated call sites (inference, tool calls, agents) to move to async non blocking calls +* Unit tests also run on Python 3.11, 3.12, and 3.13 +* Added ollama inference to Integration tests CI +* Improved documentation across examples, testing, CLI, updated providers table ) + + + --- # v0.1.6 Published on: 2025-03-08T04:35:08Z -## 0.1.6 Release Notes - -### Build and Test Agents -* Inference: Fixed support for inline vllm provider -* (**New**) Agent: Build & Monitor Agent Workflows with Llama Stack + Anthropic's Best Practice [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb) -* (**New**) Agent: Revamped agent [documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html) with more details and examples -* Agent: Unify tools and Python SDK Agents API -* Agent: AsyncAgent Python SDK wrapper supporting async client tool calls -* Agent: Support python functions without @client_tool decorator as client tools -* Agent: deprecation for allow_resume_turn flag, and remove need to specify tool_prompt_format -* VectorIO: MilvusDB support added - -### Agent Evals and Model Customization -* 
(**New**) Agent: Llama Stack RAG Lifecycle [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb) -* Eval: Documentation for eval, scoring, adding new benchmarks -* Eval: Distribution template to run benchmarks on llama & non-llama models -* Eval: Ability to register new custom LLM-as-judge scoring functions -* (**New**) Looking for contributors for open benchmarks. See [documentation](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) for details. - -### Deploy and Monitoring of Agents -* Better support for different log levels across all components for better monitoring - -### Better Engineering -* Enhance OpenAPI spec to include Error types across all APIs -* Moved all tests to /tests and created unit tests to run on each PR -* Removed all dependencies on llama-models repo - +## 0.1.6 Release Notes + +### Build and Test Agents +* Inference: Fixed support for inline vllm provider +* (**New**) Agent: Build & Monitor Agent Workflows with Llama Stack + Anthropic's Best Practice [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb) +* (**New**) Agent: Revamped agent [documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html) with more details and examples +* Agent: Unify tools and Python SDK Agents API +* Agent: AsyncAgent Python SDK wrapper supporting async client tool calls +* Agent: Support python functions without @client_tool decorator as client tools +* Agent: deprecation for allow_resume_turn flag, and remove need to specify tool_prompt_format +* VectorIO: MilvusDB support added + +### Agent Evals and Model Customization +* (**New**) Agent: Llama Stack RAG Lifecycle [Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb) +* Eval: Documentation for eval, scoring, adding new benchmarks +* Eval: Distribution template to run benchmarks on llama & non-llama models +* Eval: Ability to register new custom LLM-as-judge scoring functions +* (**New**) Looking for contributors for open benchmarks. See [documentation](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) for details. 
+ +### Deploy and Monitoring of Agents +* Better support for different log levels across all components for better monitoring + +### Better Engineering +* Enhance OpenAPI spec to include Error types across all APIs +* Moved all tests to /tests and created unit tests to run on each PR +* Removed all dependencies on llama-models repo + --- # v0.1.5.1 Published on: 2025-02-28T22:37:44Z -## 0.1.5.1 Release Notes -* Fixes for security risk in https://github.com/meta-llama/llama-stack/pull/1327 and https://github.com/meta-llama/llama-stack/pull/1328 - +## 0.1.5.1 Release Notes +* Fixes for security risk in https://github.com/meta-llama/llama-stack/pull/1327 and https://github.com/meta-llama/llama-stack/pull/1328 + **Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.5...v0.1.5.1 --- @@ -224,176 +224,176 @@ Published on: 2025-02-28T22:37:44Z # v0.1.5 Published on: 2025-02-28T18:14:01Z -## 0.1.5 Release Notes -### Build Agents -* Inference: Support more non-llama models (openai, anthropic, gemini) -* Inference: Can use the provider's model name in addition to the HF alias -* Inference: Fixed issues with calling tools that weren't specified in the prompt -* RAG: Improved system prompt for RAG and no more need for hard-coded rag-tool calling -* Embeddings: Added support for Nemo retriever embedding models -* Tools: Added support for MCP tools in Ollama Distribution -* Distributions: Added new Groq distribution - -### Customize Models -* Save post-trained checkpoint in SafeTensor format to allow Ollama inference provider to use the post-trained model - -### Monitor agents -* More comprehensive logging of agent steps including client tools -* Telemetry inputs/outputs are now structured and queryable -* Ability to retrieve agents session, turn, step by ids - -### Better Engineering -* Moved executorch Swift code out of this repo into the llama-stack-client-swift repo, similar to kotlin -* Move most logging to use logger instead of prints -* Completed text /chat-completion and /completion tests - +## 0.1.5 Release Notes +### Build Agents +* Inference: Support more non-llama models (openai, anthropic, gemini) +* Inference: Can use the provider's model name in addition to the HF alias +* Inference: Fixed issues with calling tools that weren't specified in the prompt +* RAG: Improved system prompt for RAG and no more need for hard-coded rag-tool calling +* Embeddings: Added support for Nemo retriever embedding models +* Tools: Added support for MCP tools in Ollama Distribution +* Distributions: Added new Groq distribution + +### Customize Models +* Save post-trained checkpoint in SafeTensor format to allow Ollama inference provider to use the post-trained model + +### Monitor agents +* More comprehensive logging of agent steps including client tools +* Telemetry inputs/outputs are now structured and queryable +* Ability to retrieve agents session, turn, step by ids + +### Better Engineering +* Moved executorch Swift code out of this repo into the llama-stack-client-swift repo, similar to kotlin +* Move most logging to use logger instead of prints +* Completed text /chat-completion and /completion tests + --- # v0.1.4 Published on: 2025-02-25T00:02:43Z -## v0.1.4 Release Notes -Here are the key changes coming as part of this release: - -### Build and Test Agents -* Inference: Added support for non-llama models -* Inference: Added option to list all downloaded models and remove models -* Agent: Introduce new api agents.resume_turn to include client side tool execution in the same turn 
-* Agent: AgentConfig introduces new variable “tool_config” that allows for better tool configuration and system prompt overrides -* Agent: Added logging for agent step start and completion times -* Agent: Added support for logging for tool execution metadata -* Embedding: Updated /inference/embeddings to support asymmetric models, truncation and variable sized outputs -* Embedding: Updated embedding models for Ollama, Together, and Fireworks with available defaults -* VectorIO: Improved performance of sqlite-vec using chunked writes -### Agent Evals and Model Customization -* Deprecated api /eval-tasks. Use /eval/benchmark instead -* Added CPU training support for TorchTune -### Deploy and Monitoring of Agents -* Consistent view of client and server tool calls in telemetry -### Better Engineering -* Made tests more data-driven for consistent evaluation -* Fixed documentation links and improved API reference generation -* Various small fixes for build scripts and system reliability - - +## v0.1.4 Release Notes +Here are the key changes coming as part of this release: + +### Build and Test Agents +* Inference: Added support for non-llama models +* Inference: Added option to list all downloaded models and remove models +* Agent: Introduce new api agents.resume_turn to include client side tool execution in the same turn +* Agent: AgentConfig introduces new variable “tool_config” that allows for better tool configuration and system prompt overrides +* Agent: Added logging for agent step start and completion times +* Agent: Added support for logging for tool execution metadata +* Embedding: Updated /inference/embeddings to support asymmetric models, truncation and variable sized outputs +* Embedding: Updated embedding models for Ollama, Together, and Fireworks with available defaults +* VectorIO: Improved performance of sqlite-vec using chunked writes +### Agent Evals and Model Customization +* Deprecated api /eval-tasks. Use /eval/benchmark instead +* Added CPU training support for TorchTune +### Deploy and Monitoring of Agents +* Consistent view of client and server tool calls in telemetry +### Better Engineering +* Made tests more data-driven for consistent evaluation +* Fixed documentation links and improved API reference generation +* Various small fixes for build scripts and system reliability + + --- # v0.1.3 Published on: 2025-02-14T20:24:32Z -## v0.1.3 Release - -Here are some key changes that are coming as part of this release. 
- -### Build and Test Agents -Streamlined the initial development experience -- Added support for llama stack run --image-type venv -- Enhanced vector store options with new sqlite-vec provider and improved Qdrant integration -- vLLM improvements for tool calling and logprobs -- Better handling of sporadic code_interpreter tool calls - -### Agent Evals -Better benchmarking and Agent performance assessment -- Renamed eval API /eval-task to /benchmarks -- Improved documentation and notebooks for RAG and evals - -### Deploy and Monitoring of Agents -Improved production readiness -- Added usage metrics collection for chat completions -- CLI improvements for provider information -- Improved error handling and system reliability -- Better model endpoint handling and accessibility -- Improved signal handling on distro server - -### Better Engineering -Infrastructure and code quality improvements -- Faster text-based chat completion tests -- Improved testing for non-streaming agent apis -- Standardized import formatting with ruff linter -- Added conventional commits standard -- Fixed documentation parsing issues - +## v0.1.3 Release + +Here are some key changes that are coming as part of this release. + +### Build and Test Agents +Streamlined the initial development experience +- Added support for llama stack run --image-type venv +- Enhanced vector store options with new sqlite-vec provider and improved Qdrant integration +- vLLM improvements for tool calling and logprobs +- Better handling of sporadic code_interpreter tool calls + +### Agent Evals +Better benchmarking and Agent performance assessment +- Renamed eval API /eval-task to /benchmarks +- Improved documentation and notebooks for RAG and evals + +### Deploy and Monitoring of Agents +Improved production readiness +- Added usage metrics collection for chat completions +- CLI improvements for provider information +- Improved error handling and system reliability +- Better model endpoint handling and accessibility +- Improved signal handling on distro server + +### Better Engineering +Infrastructure and code quality improvements +- Faster text-based chat completion tests +- Improved testing for non-streaming agent apis +- Standardized import formatting with ruff linter +- Added conventional commits standard +- Fixed documentation parsing issues + --- # v0.1.2 Published on: 2025-02-07T22:06:49Z -# TL;DR -- Several stabilizations to development flows after the switch to `uv` -- Migrated CI workflows to new OSS repo - [llama-stack-ops](https://github.com/meta-llama/llama-stack-ops) -- Added automated rebuilds for ReadTheDocs -- Llama Stack server supports HTTPS -- Added system prompt overrides support -- Several bug fixes and improvements to documentation (check out Kubernetes deployment guide by @terrytangyuan ) - +# TL;DR +- Several stabilizations to development flows after the switch to `uv` +- Migrated CI workflows to new OSS repo - [llama-stack-ops](https://github.com/meta-llama/llama-stack-ops) +- Added automated rebuilds for ReadTheDocs +- Llama Stack server supports HTTPS +- Added system prompt overrides support +- Several bug fixes and improvements to documentation (check out Kubernetes deployment guide by @terrytangyuan ) + --- # v0.1.1 Published on: 2025-02-02T02:29:24Z -A bunch of small / big improvements everywhere including support for Windows, switching to `uv` and many provider improvements. - +A bunch of small / big improvements everywhere including support for Windows, switching to `uv` and many provider improvements. 
+ --- # v0.1.0 Published on: 2025-01-24T17:47:47Z -We are excited to announce a stable API release of Llama Stack, which enables developers to build RAG applications and Agents using tools and safety shields, monitor and those agents with telemetry, and evaluate the agent with scoring functions. - -## Context -GenAI application developers need more than just an LLM - they need to integrate tools, connect with their data sources, establish guardrails, and ground the LLM responses effectively. Currently, developers must piece together various tools and APIs, complicating the development lifecycle and increasing costs. The result is that developers are spending more time on these integrations rather than focusing on the application logic itself. The bespoke coupling of components also makes it challenging to adopt state-of-the-art solutions in the rapidly evolving GenAI space. This is particularly difficult for open models like Llama, as best practices are not widely established in the open. - -Llama Stack was created to provide developers with a comprehensive and coherent interface that simplifies AI application development and codifies best practices across the Llama ecosystem. Since our launch in September 2024, we have seen a huge uptick in interest in Llama Stack APIs by both AI developers and from partners building AI services with Llama models. Partners like Nvidia, Fireworks, and Ollama have collaborated with us to develop implementations across various APIs, including inference, memory, and safety. - -With Llama Stack, you can easily build a RAG agent which can also search the web, do complex math, and custom tool calling. You can use telemetry to inspect those traces, and convert telemetry into evals datasets. And with Llama Stack’s plugin architecture and prepackage distributions, you choose to run your agent anywhere - in the cloud with our partners, deploy your own environment using virtualenv, conda, or Docker, operate locally with Ollama, or even run on mobile devices with our SDKs. Llama Stack offers unprecedented flexibility while also simplifying the developer experience. - -## Release -After iterating on the APIs for the last 3 months, today we’re launching a stable release (V1) of the Llama Stack APIs and the corresponding llama-stack server and client packages(v0.1.0). We now have automated tests for providers. These tests make sure that all provider implementations are verified. Developers can now easily and reliably select distributions or providers based on their specific requirements. - -There are example standalone apps in llama-stack-apps. 
- - -## Key Features of this release - -- **Unified API Layer** - - Inference: Run LLM models - - RAG: Store and retrieve knowledge for RAG - - Agents: Build multi-step agentic workflows - - Tools: Register tools that can be called by the agent - - Safety: Apply content filtering and safety policies - - Evaluation: Test model and agent quality - - Telemetry: Collect and analyze usage data and complex agentic traces - - Post Training ( Coming Soon ): Fine tune models for specific use cases - -- **Rich Provider Ecosystem** - - Local Development: Meta's Reference, Ollama - - Cloud: Fireworks, Together, Nvidia, AWS Bedrock, Groq, Cerebras - - On-premises: Nvidia NIM, vLLM, TGI, Dell-TGI - - On-device: iOS and Android support - -- **Built for Production** - - Pre-packaged distributions for common deployment scenarios - - Backwards compatibility across model versions - - Comprehensive evaluation capabilities - - Full observability and monitoring - -- **Multiple developer interfaces** - - CLI: Command line interface - - Python SDK - - Swift iOS SDK - - Kotlin Android SDK - -- **Sample llama stack applications** - - Python - - iOS - - Android - - +We are excited to announce a stable API release of Llama Stack, which enables developers to build RAG applications and Agents using tools and safety shields, monitor and those agents with telemetry, and evaluate the agent with scoring functions. + +## Context +GenAI application developers need more than just an LLM - they need to integrate tools, connect with their data sources, establish guardrails, and ground the LLM responses effectively. Currently, developers must piece together various tools and APIs, complicating the development lifecycle and increasing costs. The result is that developers are spending more time on these integrations rather than focusing on the application logic itself. The bespoke coupling of components also makes it challenging to adopt state-of-the-art solutions in the rapidly evolving GenAI space. This is particularly difficult for open models like Llama, as best practices are not widely established in the open. + +Llama Stack was created to provide developers with a comprehensive and coherent interface that simplifies AI application development and codifies best practices across the Llama ecosystem. Since our launch in September 2024, we have seen a huge uptick in interest in Llama Stack APIs by both AI developers and from partners building AI services with Llama models. Partners like Nvidia, Fireworks, and Ollama have collaborated with us to develop implementations across various APIs, including inference, memory, and safety. + +With Llama Stack, you can easily build a RAG agent which can also search the web, do complex math, and custom tool calling. You can use telemetry to inspect those traces, and convert telemetry into evals datasets. And with Llama Stack’s plugin architecture and prepackage distributions, you choose to run your agent anywhere - in the cloud with our partners, deploy your own environment using virtualenv, conda, or Docker, operate locally with Ollama, or even run on mobile devices with our SDKs. Llama Stack offers unprecedented flexibility while also simplifying the developer experience. + +## Release +After iterating on the APIs for the last 3 months, today we’re launching a stable release (V1) of the Llama Stack APIs and the corresponding llama-stack server and client packages(v0.1.0). We now have automated tests for providers. These tests make sure that all provider implementations are verified. 
Developers can now easily and reliably select distributions or providers based on their specific requirements. + +There are example standalone apps in llama-stack-apps. + + +## Key Features of this release + +- **Unified API Layer** + - Inference: Run LLM models + - RAG: Store and retrieve knowledge for RAG + - Agents: Build multi-step agentic workflows + - Tools: Register tools that can be called by the agent + - Safety: Apply content filtering and safety policies + - Evaluation: Test model and agent quality + - Telemetry: Collect and analyze usage data and complex agentic traces + - Post Training ( Coming Soon ): Fine tune models for specific use cases + +- **Rich Provider Ecosystem** + - Local Development: Meta's Reference, Ollama + - Cloud: Fireworks, Together, Nvidia, AWS Bedrock, Groq, Cerebras + - On-premises: Nvidia NIM, vLLM, TGI, Dell-TGI + - On-device: iOS and Android support + +- **Built for Production** + - Pre-packaged distributions for common deployment scenarios + - Backwards compatibility across model versions + - Comprehensive evaluation capabilities + - Full observability and monitoring + +- **Multiple developer interfaces** + - CLI: Command line interface + - Python SDK + - Swift iOS SDK + - Kotlin Android SDK + +- **Sample llama stack applications** + - Python + - iOS + - Android + + --- @@ -407,8 +407,8 @@ Published on: 2025-01-22T22:24:01Z # v0.0.63 Published on: 2024-12-18T07:17:43Z -A small but important bug-fix release to update the URL datatype for the client-SDKs. The issue affected multimodal agentic turns especially. - +A small but important bug-fix release to update the URL datatype for the client-SDKs. The issue affected multimodal agentic turns especially. + **Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.0.62...v0.0.63 --- @@ -444,39 +444,40 @@ Published on: 2024-11-22T00:36:09Z # v0.0.53 Published on: 2024-11-20T22:18:00Z -🚀 Initial Release Notes for Llama Stack! - -### Added -- Resource-oriented design for models, shields, memory banks, datasets and eval tasks -- Persistence for registered objects with distribution -- Ability to persist memory banks created for FAISS -- PostgreSQL KVStore implementation -- Environment variable placeholder support in run.yaml files -- Comprehensive Zero-to-Hero notebooks and quickstart guides -- Support for quantized models in Ollama -- Vision models support for Together, Fireworks, Meta-Reference, and Ollama, and vLLM -- Bedrock distribution with safety shields support -- Evals API with task registration and scoring functions -- MMLU and SimpleQA benchmark scoring functions -- Huggingface dataset provider integration for benchmarks -- Support for custom dataset registration from local paths -- Benchmark evaluation CLI tools with visualization tables -- RAG evaluation scoring functions and metrics -- Local persistence for datasets and eval tasks - -### Changed -- Split safety into distinct providers (llama-guard, prompt-guard, code-scanner) -- Changed provider naming convention (`impls` → `inline`, `adapters` → `remote`) -- Updated API signatures for dataset and eval task registration -- Restructured folder organization for providers -- Enhanced Docker build configuration -- Added version prefixing for REST API routes -- Enhanced evaluation task registration workflow -- Improved benchmark evaluation output formatting -- Restructured evals folder organization for better modularity - -### Removed -- `llama stack configure` command - +🚀 Initial Release Notes for Llama Stack! 
+ +### Added +- Resource-oriented design for models, shields, memory banks, datasets and eval tasks +- Persistence for registered objects with distribution +- Ability to persist memory banks created for FAISS +- PostgreSQL KVStore implementation +- Environment variable placeholder support in run.yaml files +- Comprehensive Zero-to-Hero notebooks and quickstart guides +- Support for quantized models in Ollama +- Vision models support for Together, Fireworks, Meta-Reference, and Ollama, and vLLM +- Bedrock distribution with safety shields support +- Evals API with task registration and scoring functions +- MMLU and SimpleQA benchmark scoring functions +- Huggingface dataset provider integration for benchmarks +- Support for custom dataset registration from local paths +- Benchmark evaluation CLI tools with visualization tables +- RAG evaluation scoring functions and metrics +- Local persistence for datasets and eval tasks + +### Changed +- Split safety into distinct providers (llama-guard, prompt-guard, code-scanner) +- Changed provider naming convention (`impls` → `inline`, `adapters` → `remote`) +- Updated API signatures for dataset and eval task registration +- Restructured folder organization for providers +- Enhanced Docker build configuration +- Added version prefixing for REST API routes +- Enhanced evaluation task registration workflow +- Improved benchmark evaluation output formatting +- Restructured evals folder organization for better modularity + +### Removed +- `llama stack configure` command + --- + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 10e3f6cee..d7c3e3e2f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -167,11 +167,14 @@ If you have made changes to a provider's configuration in any form (introducing If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme. ```bash +cd docs +uv sync --extra docs + # This rebuilds the documentation pages. -uv run --group docs make -C docs/ html +uv run make html # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation. -uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all +uv run sphinx-autobuild source build/html --write-all ``` ### Update API Documentation @@ -179,7 +182,7 @@ uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command: ```bash -uv run ./docs/openapi_generator/run_openapi_generator.sh +uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh ``` The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing. 
diff --git a/MANIFEST.in b/MANIFEST.in index 88bd11767..879a9cbd4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include pyproject.toml +include llama_stack/templates/dependencies.json include llama_stack/models/llama/llama3/tokenizer.model include llama_stack/models/llama/llama4/tokenizer.model include llama_stack/distribution/*.sh diff --git a/README.md b/README.md index 37f1aa0f3..5dfe3577a 100644 --- a/README.md +++ b/README.md @@ -107,29 +107,26 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on ### API Providers Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack. -| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** | -|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:| -| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | | -| SambaNova | Hosted | | ✅ | | ✅ | | | -| Cerebras | Hosted | | ✅ | | | | | -| Fireworks | Hosted | ✅ | ✅ | ✅ | | | | -| AWS Bedrock | Hosted | | ✅ | | ✅ | | | -| Together | Hosted | ✅ | ✅ | | ✅ | | | -| Groq | Hosted | | ✅ | | | | | -| Ollama | Single Node | | ✅ | | | | | -| TGI | Hosted and Single Node | | ✅ | | | | | -| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | | -| Chroma | Single Node | | | ✅ | | | | -| PG Vector | Single Node | | | ✅ | | | | -| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | | -| vLLM | Hosted and Single Node | | ✅ | | | | | -| OpenAI | Hosted | | ✅ | | | | | -| Anthropic | Hosted | | ✅ | | | | | -| Gemini | Hosted | | ✅ | | | | | -| watsonx | Hosted | | ✅ | | | | | -| HuggingFace | Single Node | | | | | | ✅ | -| TorchTune | Single Node | | | | | | ✅ | -| NVIDIA NEMO | Hosted | | | | | | ✅ | +| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | +|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:| +| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | +| SambaNova | Hosted | | ✅ | | | | +| Cerebras | Hosted | | ✅ | | | | +| Fireworks | Hosted | ✅ | ✅ | ✅ | | | +| AWS Bedrock | Hosted | | ✅ | | ✅ | | +| Together | Hosted | ✅ | ✅ | | ✅ | | +| Groq | Hosted | | ✅ | | | | +| Ollama | Single Node | | ✅ | | | | +| TGI | Hosted and Single Node | | ✅ | | | | +| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | +| Chroma | Single Node | | | ✅ | | | +| PG Vector | Single Node | | | ✅ | | | +| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | +| vLLM | Hosted and Single Node | | ✅ | | | | +| OpenAI | Hosted | | ✅ | | | | +| Anthropic | Hosted | | ✅ | | | | +| Gemini | Hosted | | ✅ | | | | +| watsonx | Hosted | | ✅ | | | | ### Distributions diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index d88462909..9032e5968 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -518,74 +518,6 @@ } }, "/v1/openai/v1/responses": { - "get": { - "responses": { - "200": { - "description": "A ListOpenAIResponseObject.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ListOpenAIResponseObject" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, 
- "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Agents" - ], - "description": "List all OpenAI responses.", - "parameters": [ - { - "name": "after", - "in": "query", - "description": "The ID of the last response to return.", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "limit", - "in": "query", - "description": "The number of responses to return.", - "required": false, - "schema": { - "type": "integer" - } - }, - { - "name": "model", - "in": "query", - "description": "The model to filter responses by.", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "order", - "in": "query", - "description": "The order to sort responses by when sorted by created_at ('asc' or 'desc').", - "required": false, - "schema": { - "$ref": "#/components/schemas/Order" - } - } - ] - }, "post": { "responses": { "200": { @@ -1266,49 +1198,6 @@ ] } }, - "/v1/openai/v1/chat/completions/{completion_id}": { - "get": { - "responses": { - "200": { - "description": "A OpenAICompletionWithInputMessages.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenAICompletionWithInputMessages" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Inference" - ], - "description": "Describe a chat completion by its ID.", - "parameters": [ - { - "name": "completion_id", - "in": "path", - "description": "ID of the chat completion.", - "required": true, - "schema": { - "type": "string" - } - } - ] - } - }, "/v1/datasets/{dataset_id}": { "get": { "responses": { @@ -1463,7 +1352,7 @@ ] } }, - "/v1/openai/v1/responses/{response_id}": { + "/v1/openai/v1/responses/{id}": { "get": { "responses": { "200": { @@ -1495,7 +1384,7 @@ "description": "Retrieve an OpenAI response by its ID.", "parameters": [ { - "name": "response_id", + "name": "id", "in": "path", "description": "The ID of the OpenAI response to retrieve.", "required": true, @@ -2685,124 +2574,6 @@ } } }, - "/v1/openai/v1/chat/completions": { - "get": { - "responses": { - "200": { - "description": "A ListOpenAIChatCompletionResponse.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ListOpenAIChatCompletionResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Inference" - ], - "description": "List all chat completions.", - "parameters": [ - { - "name": "after", - "in": "query", - "description": "The ID of the last chat completion to return.", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "limit", - "in": "query", - "description": "The maximum number of chat completions to return.", - "required": false, - "schema": { - "type": "integer" - } - }, - { - "name": "model", - "in": "query", - "description": "The model to filter by.", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "order", - "in": "query", - "description": "The order to sort the chat completions by: \"asc\" or \"desc\". 
Defaults to \"desc\".", - "required": false, - "schema": { - "$ref": "#/components/schemas/Order" - } - } - ] - }, - "post": { - "responses": { - "200": { - "description": "An OpenAIChatCompletion.", - "content": { - "application/json": { - "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/OpenAIChatCompletion" - }, - { - "$ref": "#/components/schemas/OpenAIChatCompletionChunk" - } - ] - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Inference" - ], - "description": "Generate an OpenAI-compatible chat completion for the given messages using the specified model.", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenaiChatCompletionRequest" - } - } - }, - "required": true - } - } - }, "/v1/datasets": { "get": { "responses": { @@ -2994,97 +2765,6 @@ } } }, - "/v1/openai/v1/responses/{response_id}/input_items": { - "get": { - "responses": { - "200": { - "description": "An ListOpenAIResponseInputItem.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ListOpenAIResponseInputItem" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Agents" - ], - "description": "List input items for a given OpenAI response.", - "parameters": [ - { - "name": "response_id", - "in": "path", - "description": "The ID of the response to retrieve input items for.", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "after", - "in": "query", - "description": "An item ID to list items after, used for pagination.", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "before", - "in": "query", - "description": "An item ID to list items before, used for pagination.", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "include", - "in": "query", - "description": "Additional fields to include in the response.", - "required": false, - "schema": { - "type": "array", - "items": { - "type": "string" - } - } - }, - { - "name": "limit", - "in": "query", - "description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.", - "required": false, - "schema": { - "type": "integer" - } - }, - { - "name": "order", - "in": "query", - "description": "The order to return the input items in. 
Default is desc.", - "required": false, - "schema": { - "$ref": "#/components/schemas/Order" - } - } - ] - } - }, "/v1/providers": { "get": { "responses": { @@ -3564,6 +3244,56 @@ } } }, + "/v1/openai/v1/chat/completions": { + "post": { + "responses": { + "200": { + "description": "An OpenAIChatCompletion.", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIChatCompletion" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionChunk" + } + ] + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "Generate an OpenAI-compatible chat completion for the given messages using the specified model.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenaiChatCompletionRequest" + } + } + }, + "required": true + } + } + }, "/v1/openai/v1/completions": { "post": { "responses": { @@ -3607,49 +3337,6 @@ } } }, - "/v1/openai/v1/embeddings": { - "post": { - "responses": { - "200": { - "description": "An OpenAIEmbeddingsResponse containing the embeddings.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenAIEmbeddingsResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Inference" - ], - "description": "Generate OpenAI-compatible embeddings for the given input using the specified model.", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/OpenaiEmbeddingsRequest" - } - } - }, - "required": true - } - } - }, "/v1/openai/v1/models": { "get": { "responses": { @@ -6944,9 +6631,6 @@ }, { "$ref": "#/components/schemas/OpenAIResponseInputToolFunction" - }, - { - "$ref": "#/components/schemas/OpenAIResponseInputToolMCP" } ], "discriminator": { @@ -6954,8 +6638,7 @@ "mapping": { "web_search": "#/components/schemas/OpenAIResponseInputToolWebSearch", "file_search": "#/components/schemas/OpenAIResponseInputToolFileSearch", - "function": "#/components/schemas/OpenAIResponseInputToolFunction", - "mcp": "#/components/schemas/OpenAIResponseInputToolMCP" + "function": "#/components/schemas/OpenAIResponseInputToolFunction" } } }, @@ -7045,110 +6728,6 @@ ], "title": "OpenAIResponseInputToolFunction" }, - "OpenAIResponseInputToolMCP": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "mcp", - "default": "mcp" - }, - "server_label": { - "type": "string" - }, - "server_url": { - "type": "string" - }, - "headers": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "require_approval": { - "oneOf": [ - { - "type": "string", - "const": "always" - }, - { - "type": "string", - "const": "never" - }, - { - "type": "object", - "properties": { - "always": { - "type": "array", - "items": { - "type": "string" - } - }, - 
"never": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "additionalProperties": false, - "title": "ApprovalFilter" - } - ], - "default": "never" - }, - "allowed_tools": { - "oneOf": [ - { - "type": "array", - "items": { - "type": "string" - } - }, - { - "type": "object", - "properties": { - "tool_names": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "additionalProperties": false, - "title": "AllowedToolsFilter" - } - ] - } - }, - "additionalProperties": false, - "required": [ - "type", - "server_label", - "server_url", - "require_approval" - ], - "title": "OpenAIResponseInputToolMCP" - }, "OpenAIResponseInputToolWebSearch": { "type": "object", "properties": { @@ -7261,15 +6840,15 @@ "OpenAIResponseOutputMessageFunctionToolCall": { "type": "object", "properties": { + "arguments": { + "type": "string" + }, "call_id": { "type": "string" }, "name": { "type": "string" }, - "arguments": { - "type": "string" - }, "type": { "type": "string", "const": "function_call", @@ -7284,10 +6863,12 @@ }, "additionalProperties": false, "required": [ + "arguments", "call_id", "name", - "arguments", - "type" + "type", + "id", + "status" ], "title": "OpenAIResponseOutputMessageFunctionToolCall" }, @@ -7335,9 +6916,6 @@ "type": "string", "description": "The underlying LLM used for completions." }, - "instructions": { - "type": "string" - }, "previous_response_id": { "type": "string", "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses." @@ -7453,12 +7031,6 @@ }, { "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" - }, - { - "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall" - }, - { - "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools" } ], "discriminator": { @@ -7466,126 +7038,15 @@ "mapping": { "message": "#/components/schemas/OpenAIResponseMessage", "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall", - "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall", - "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall", - "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools" + "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" } } }, - "OpenAIResponseOutputMessageMCPCall": { - "type": "object", - "properties": { - "id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "mcp_call", - "default": "mcp_call" - }, - "arguments": { - "type": "string" - }, - "name": { - "type": "string" - }, - "server_label": { - "type": "string" - }, - "error": { - "type": "string" - }, - "output": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "id", - "type", - "arguments", - "name", - "server_label" - ], - "title": "OpenAIResponseOutputMessageMCPCall" - }, - "OpenAIResponseOutputMessageMCPListTools": { - "type": "object", - "properties": { - "id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "mcp_list_tools", - "default": "mcp_list_tools" - }, - "server_label": { - "type": "string" - }, - "tools": { - "type": "array", - "items": { - "type": "object", - "properties": { - "input_schema": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": 
"object" - } - ] - } - }, - "name": { - "type": "string" - }, - "description": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "input_schema", - "name" - ], - "title": "MCPListToolsTool" - } - } - }, - "additionalProperties": false, - "required": [ - "id", - "type", - "server_label", - "tools" - ], - "title": "OpenAIResponseOutputMessageMCPListTools" - }, "OpenAIResponseObjectStream": { "oneOf": [ { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated" }, - { - "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta" - }, { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } @@ -7594,7 +7055,6 @@ "propertyName": "type", "mapping": { "response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated", - "response.output_text.delta": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta", "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } } @@ -7637,41 +7097,6 @@ ], "title": "OpenAIResponseObjectStreamResponseCreated" }, - "OpenAIResponseObjectStreamResponseOutputTextDelta": { - "type": "object", - "properties": { - "content_index": { - "type": "integer" - }, - "delta": { - "type": "string" - }, - "item_id": { - "type": "string" - }, - "output_index": { - "type": "integer" - }, - "sequence_number": { - "type": "integer" - }, - "type": { - "type": "string", - "const": "response.output_text.delta", - "default": "response.output_text.delta" - } - }, - "additionalProperties": false, - "required": [ - "content_index", - "delta", - "item_id", - "output_index", - "sequence_number", - "type" - ], - "title": "OpenAIResponseObjectStreamResponseOutputTextDelta" - }, "CreateUploadSessionRequest": { "type": "object", "properties": { @@ -8356,482 +7781,6 @@ ], "title": "Benchmark" }, - "OpenAIAssistantMessageParam": { - "type": "object", - "properties": { - "role": { - "type": "string", - "const": "assistant", - "default": "assistant", - "description": "Must be \"assistant\" to identify this as the model's response" - }, - "content": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" - } - } - ], - "description": "The content of the model's response" - }, - "name": { - "type": "string", - "description": "(Optional) The name of the assistant message participant." - }, - "tool_calls": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionToolCall" - }, - "description": "List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object." - } - }, - "additionalProperties": false, - "required": [ - "role" - ], - "title": "OpenAIAssistantMessageParam", - "description": "A message containing the model's (assistant) response in an OpenAI-compatible chat completion request." 
- }, - "OpenAIChatCompletionContentPartImageParam": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "image_url", - "default": "image_url" - }, - "image_url": { - "$ref": "#/components/schemas/OpenAIImageURL" - } - }, - "additionalProperties": false, - "required": [ - "type", - "image_url" - ], - "title": "OpenAIChatCompletionContentPartImageParam" - }, - "OpenAIChatCompletionContentPartParam": { - "oneOf": [ - { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" - }, - { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" - } - ], - "discriminator": { - "propertyName": "type", - "mapping": { - "text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam", - "image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" - } - } - }, - "OpenAIChatCompletionContentPartTextParam": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "text", - "default": "text" - }, - "text": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "type", - "text" - ], - "title": "OpenAIChatCompletionContentPartTextParam" - }, - "OpenAIChatCompletionToolCall": { - "type": "object", - "properties": { - "index": { - "type": "integer" - }, - "id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "function", - "default": "function" - }, - "function": { - "$ref": "#/components/schemas/OpenAIChatCompletionToolCallFunction" - } - }, - "additionalProperties": false, - "required": [ - "type" - ], - "title": "OpenAIChatCompletionToolCall" - }, - "OpenAIChatCompletionToolCallFunction": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "arguments": { - "type": "string" - } - }, - "additionalProperties": false, - "title": "OpenAIChatCompletionToolCallFunction" - }, - "OpenAIChoice": { - "type": "object", - "properties": { - "message": { - "$ref": "#/components/schemas/OpenAIMessageParam", - "description": "The message from the model" - }, - "finish_reason": { - "type": "string", - "description": "The reason the model stopped generating" - }, - "index": { - "type": "integer", - "description": "The index of the choice" - }, - "logprobs": { - "$ref": "#/components/schemas/OpenAIChoiceLogprobs", - "description": "(Optional) The log probabilities for the tokens in the message" - } - }, - "additionalProperties": false, - "required": [ - "message", - "finish_reason", - "index" - ], - "title": "OpenAIChoice", - "description": "A choice from an OpenAI-compatible chat completion response." - }, - "OpenAIChoiceLogprobs": { - "type": "object", - "properties": { - "content": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAITokenLogProb" - }, - "description": "(Optional) The log probabilities for the tokens in the message" - }, - "refusal": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAITokenLogProb" - }, - "description": "(Optional) The log probabilities for the tokens in the message" - } - }, - "additionalProperties": false, - "title": "OpenAIChoiceLogprobs", - "description": "The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response." 
- }, - "OpenAIDeveloperMessageParam": { - "type": "object", - "properties": { - "role": { - "type": "string", - "const": "developer", - "default": "developer", - "description": "Must be \"developer\" to identify this as a developer message" - }, - "content": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" - } - } - ], - "description": "The content of the developer message" - }, - "name": { - "type": "string", - "description": "(Optional) The name of the developer message participant." - } - }, - "additionalProperties": false, - "required": [ - "role", - "content" - ], - "title": "OpenAIDeveloperMessageParam", - "description": "A message from the developer in an OpenAI-compatible chat completion request." - }, - "OpenAIImageURL": { - "type": "object", - "properties": { - "url": { - "type": "string" - }, - "detail": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "url" - ], - "title": "OpenAIImageURL" - }, - "OpenAIMessageParam": { - "oneOf": [ - { - "$ref": "#/components/schemas/OpenAIUserMessageParam" - }, - { - "$ref": "#/components/schemas/OpenAISystemMessageParam" - }, - { - "$ref": "#/components/schemas/OpenAIAssistantMessageParam" - }, - { - "$ref": "#/components/schemas/OpenAIToolMessageParam" - }, - { - "$ref": "#/components/schemas/OpenAIDeveloperMessageParam" - } - ], - "discriminator": { - "propertyName": "role", - "mapping": { - "user": "#/components/schemas/OpenAIUserMessageParam", - "system": "#/components/schemas/OpenAISystemMessageParam", - "assistant": "#/components/schemas/OpenAIAssistantMessageParam", - "tool": "#/components/schemas/OpenAIToolMessageParam", - "developer": "#/components/schemas/OpenAIDeveloperMessageParam" - } - } - }, - "OpenAISystemMessageParam": { - "type": "object", - "properties": { - "role": { - "type": "string", - "const": "system", - "default": "system", - "description": "Must be \"system\" to identify this as a system message" - }, - "content": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" - } - } - ], - "description": "The content of the \"system prompt\". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions)." - }, - "name": { - "type": "string", - "description": "(Optional) The name of the system message participant." - } - }, - "additionalProperties": false, - "required": [ - "role", - "content" - ], - "title": "OpenAISystemMessageParam", - "description": "A system message providing instructions or context to the model." - }, - "OpenAITokenLogProb": { - "type": "object", - "properties": { - "token": { - "type": "string" - }, - "bytes": { - "type": "array", - "items": { - "type": "integer" - } - }, - "logprob": { - "type": "number" - }, - "top_logprobs": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAITopLogProb" - } - } - }, - "additionalProperties": false, - "required": [ - "token", - "logprob", - "top_logprobs" - ], - "title": "OpenAITokenLogProb", - "description": "The log probability for a token from an OpenAI-compatible chat completion response." 
- }, - "OpenAIToolMessageParam": { - "type": "object", - "properties": { - "role": { - "type": "string", - "const": "tool", - "default": "tool", - "description": "Must be \"tool\" to identify this as a tool response" - }, - "tool_call_id": { - "type": "string", - "description": "Unique identifier for the tool call this response is for" - }, - "content": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" - } - } - ], - "description": "The response content from the tool" - } - }, - "additionalProperties": false, - "required": [ - "role", - "tool_call_id", - "content" - ], - "title": "OpenAIToolMessageParam", - "description": "A message representing the result of a tool invocation in an OpenAI-compatible chat completion request." - }, - "OpenAITopLogProb": { - "type": "object", - "properties": { - "token": { - "type": "string" - }, - "bytes": { - "type": "array", - "items": { - "type": "integer" - } - }, - "logprob": { - "type": "number" - } - }, - "additionalProperties": false, - "required": [ - "token", - "logprob" - ], - "title": "OpenAITopLogProb", - "description": "The top log probability for a token from an OpenAI-compatible chat completion response." - }, - "OpenAIUserMessageParam": { - "type": "object", - "properties": { - "role": { - "type": "string", - "const": "user", - "default": "user", - "description": "Must be \"user\" to identify this as a user message" - }, - "content": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" - } - } - ], - "description": "The content of the message, which can include text and other media" - }, - "name": { - "type": "string", - "description": "(Optional) The name of the user message participant." - } - }, - "additionalProperties": false, - "required": [ - "role", - "content" - ], - "title": "OpenAIUserMessageParam", - "description": "A message from the user in an OpenAI-compatible chat completion request." 
- }, - "OpenAICompletionWithInputMessages": { - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "The ID of the chat completion" - }, - "choices": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIChoice" - }, - "description": "List of choices" - }, - "object": { - "type": "string", - "const": "chat.completion", - "default": "chat.completion", - "description": "The object type, which will be \"chat.completion\"" - }, - "created": { - "type": "integer", - "description": "The Unix timestamp in seconds when the chat completion was created" - }, - "model": { - "type": "string", - "description": "The model that was used to generate the chat completion" - }, - "input_messages": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIMessageParam" - } - } - }, - "additionalProperties": false, - "required": [ - "id", - "choices", - "object", - "created", - "model", - "input_messages" - ], - "title": "OpenAICompletionWithInputMessages" - }, "DataSource": { "oneOf": [ { @@ -9637,6 +8586,9 @@ "toolgroup_id": { "type": "string" }, + "tool_host": { + "$ref": "#/components/schemas/ToolHost" + }, "description": { "type": "string" }, @@ -9678,11 +8630,21 @@ "provider_id", "type", "toolgroup_id", + "tool_host", "description", "parameters" ], "title": "Tool" }, + "ToolHost": { + "type": "string", + "enum": [ + "distribution", + "client", + "model_context_protocol" + ], + "title": "ToolHost" + }, "ToolGroup": { "type": "object", "properties": { @@ -10063,8 +9025,7 @@ "type": "object", "properties": { "content": { - "$ref": "#/components/schemas/InterleavedContent", - "description": "The content of the chunk, which can be interleaved text, images, or other types." + "$ref": "#/components/schemas/InterleavedContent" }, "metadata": { "type": "object", @@ -10089,15 +9050,7 @@ "type": "object" } ] - }, - "description": "Metadata associated with the chunk, such as document ID, source, or other relevant information." - }, - "embedding": { - "type": "array", - "items": { - "type": "number" - }, - "description": "Optional embedding for the chunk. If not provided, it will be computed later." + } } }, "additionalProperties": false, @@ -10105,10 +9058,9 @@ "content", "metadata" ], - "title": "Chunk", - "description": "A chunk of content that can be inserted into a vector database." + "title": "Chunk" }, - "description": "The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types. `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional. If `metadata` is provided, you configure how Llama Stack formats the chunk during generation. If `embedding` is not provided, it will be computed later." + "description": "The chunks to insert." 
}, "ttl_seconds": { "type": "integer", @@ -10394,91 +9346,6 @@ ], "title": "ListBenchmarksResponse" }, - "Order": { - "type": "string", - "enum": [ - "asc", - "desc" - ], - "title": "Order" - }, - "ListOpenAIChatCompletionResponse": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "The ID of the chat completion" - }, - "choices": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIChoice" - }, - "description": "List of choices" - }, - "object": { - "type": "string", - "const": "chat.completion", - "default": "chat.completion", - "description": "The object type, which will be \"chat.completion\"" - }, - "created": { - "type": "integer", - "description": "The Unix timestamp in seconds when the chat completion was created" - }, - "model": { - "type": "string", - "description": "The model that was used to generate the chat completion" - }, - "input_messages": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIMessageParam" - } - } - }, - "additionalProperties": false, - "required": [ - "id", - "choices", - "object", - "created", - "model", - "input_messages" - ], - "title": "OpenAICompletionWithInputMessages" - } - }, - "has_more": { - "type": "boolean" - }, - "first_id": { - "type": "string" - }, - "last_id": { - "type": "string" - }, - "object": { - "type": "string", - "const": "list", - "default": "list" - } - }, - "additionalProperties": false, - "required": [ - "data", - "has_more", - "first_id", - "last_id", - "object" - ], - "title": "ListOpenAIChatCompletionResponse" - }, "ListDatasetsResponse": { "type": "object", "properties": { @@ -10529,130 +9396,6 @@ ], "title": "ListModelsResponse" }, - "ListOpenAIResponseInputItem": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIResponseInput" - } - }, - "object": { - "type": "string", - "const": "list", - "default": "list" - } - }, - "additionalProperties": false, - "required": [ - "data", - "object" - ], - "title": "ListOpenAIResponseInputItem" - }, - "ListOpenAIResponseObject": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIResponseObjectWithInput" - } - }, - "has_more": { - "type": "boolean" - }, - "first_id": { - "type": "string" - }, - "last_id": { - "type": "string" - }, - "object": { - "type": "string", - "const": "list", - "default": "list" - } - }, - "additionalProperties": false, - "required": [ - "data", - "has_more", - "first_id", - "last_id", - "object" - ], - "title": "ListOpenAIResponseObject" - }, - "OpenAIResponseObjectWithInput": { - "type": "object", - "properties": { - "created_at": { - "type": "integer" - }, - "error": { - "$ref": "#/components/schemas/OpenAIResponseError" - }, - "id": { - "type": "string" - }, - "model": { - "type": "string" - }, - "object": { - "type": "string", - "const": "response", - "default": "response" - }, - "output": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIResponseOutput" - } - }, - "parallel_tool_calls": { - "type": "boolean", - "default": false - }, - "previous_response_id": { - "type": "string" - }, - "status": { - "type": "string" - }, - "temperature": { - "type": "number" - }, - "top_p": { - "type": "number" - }, - "truncation": { - "type": "string" - }, - "user": { - "type": "string" - }, - "input": { - "type": "array", - "items": { - "$ref": 
"#/components/schemas/OpenAIResponseInput" - } - } - }, - "additionalProperties": false, - "required": [ - "created_at", - "id", - "model", - "object", - "output", - "parallel_tool_calls", - "status", - "input" - ], - "title": "OpenAIResponseObjectWithInput" - }, "ListProvidersResponse": { "type": "object", "properties": { @@ -11113,6 +9856,192 @@ ], "title": "LogEventRequest" }, + "OpenAIAssistantMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "assistant", + "default": "assistant", + "description": "Must be \"assistant\" to identify this as the model's response" + }, + "content": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" + } + } + ], + "description": "The content of the model's response" + }, + "name": { + "type": "string", + "description": "(Optional) The name of the assistant message participant." + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIChatCompletionToolCall" + }, + "description": "List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object." + } + }, + "additionalProperties": false, + "required": [ + "role" + ], + "title": "OpenAIAssistantMessageParam", + "description": "A message containing the model's (assistant) response in an OpenAI-compatible chat completion request." + }, + "OpenAIChatCompletionContentPartImageParam": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image_url", + "default": "image_url" + }, + "image_url": { + "$ref": "#/components/schemas/OpenAIImageURL" + } + }, + "additionalProperties": false, + "required": [ + "type", + "image_url" + ], + "title": "OpenAIChatCompletionContentPartImageParam" + }, + "OpenAIChatCompletionContentPartParam": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam", + "image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + } + }, + "OpenAIChatCompletionContentPartTextParam": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text", + "default": "text" + }, + "text": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "text" + ], + "title": "OpenAIChatCompletionContentPartTextParam" + }, + "OpenAIChatCompletionToolCall": { + "type": "object", + "properties": { + "index": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "function", + "default": "function" + }, + "function": { + "$ref": "#/components/schemas/OpenAIChatCompletionToolCallFunction" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "OpenAIChatCompletionToolCall" + }, + "OpenAIChatCompletionToolCallFunction": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "arguments": { + "type": "string" + } + }, + "additionalProperties": false, + "title": "OpenAIChatCompletionToolCallFunction" + }, + "OpenAIDeveloperMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "developer", + "default": "developer", + "description": "Must be \"developer\" to identify this as a developer message" + }, + "content": { + 
"oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" + } + } + ], + "description": "The content of the developer message" + }, + "name": { + "type": "string", + "description": "(Optional) The name of the developer message participant." + } + }, + "additionalProperties": false, + "required": [ + "role", + "content" + ], + "title": "OpenAIDeveloperMessageParam", + "description": "A message from the developer in an OpenAI-compatible chat completion request." + }, + "OpenAIImageURL": { + "type": "object", + "properties": { + "url": { + "type": "string" + }, + "detail": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "url" + ], + "title": "OpenAIImageURL" + }, "OpenAIJSONSchema": { "type": "object", "properties": { @@ -11157,6 +10086,35 @@ ], "title": "OpenAIJSONSchema" }, + "OpenAIMessageParam": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIUserMessageParam" + }, + { + "$ref": "#/components/schemas/OpenAISystemMessageParam" + }, + { + "$ref": "#/components/schemas/OpenAIAssistantMessageParam" + }, + { + "$ref": "#/components/schemas/OpenAIToolMessageParam" + }, + { + "$ref": "#/components/schemas/OpenAIDeveloperMessageParam" + } + ], + "discriminator": { + "propertyName": "role", + "mapping": { + "user": "#/components/schemas/OpenAIUserMessageParam", + "system": "#/components/schemas/OpenAISystemMessageParam", + "assistant": "#/components/schemas/OpenAIAssistantMessageParam", + "tool": "#/components/schemas/OpenAIToolMessageParam", + "developer": "#/components/schemas/OpenAIDeveloperMessageParam" + } + } + }, "OpenAIResponseFormatJSONObject": { "type": "object", "properties": { @@ -11227,6 +10185,115 @@ ], "title": "OpenAIResponseFormatText" }, + "OpenAISystemMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "system", + "default": "system", + "description": "Must be \"system\" to identify this as a system message" + }, + "content": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" + } + } + ], + "description": "The content of the \"system prompt\". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions)." + }, + "name": { + "type": "string", + "description": "(Optional) The name of the system message participant." + } + }, + "additionalProperties": false, + "required": [ + "role", + "content" + ], + "title": "OpenAISystemMessageParam", + "description": "A system message providing instructions or context to the model." 
+ }, + "OpenAIToolMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "tool", + "default": "tool", + "description": "Must be \"tool\" to identify this as a tool response" + }, + "tool_call_id": { + "type": "string", + "description": "Unique identifier for the tool call this response is for" + }, + "content": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" + } + } + ], + "description": "The response content from the tool" + } + }, + "additionalProperties": false, + "required": [ + "role", + "tool_call_id", + "content" + ], + "title": "OpenAIToolMessageParam", + "description": "A message representing the result of a tool invocation in an OpenAI-compatible chat completion request." + }, + "OpenAIUserMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "user", + "default": "user", + "description": "Must be \"user\" to identify this as a user message" + }, + "content": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" + } + } + ], + "description": "The content of the message, which can include text and other media" + }, + "name": { + "type": "string", + "description": "(Optional) The name of the user message participant." + } + }, + "additionalProperties": false, + "required": [ + "role", + "content" + ], + "title": "OpenAIUserMessageParam", + "description": "A message from the user in an OpenAI-compatible chat completion request." + }, "OpenaiChatCompletionRequest": { "type": "object", "properties": { @@ -11556,6 +10623,35 @@ "title": "OpenAIChatCompletionChunk", "description": "Chunk from a streaming response to an OpenAI-compatible chat completion request." }, + "OpenAIChoice": { + "type": "object", + "properties": { + "message": { + "$ref": "#/components/schemas/OpenAIMessageParam", + "description": "The message from the model" + }, + "finish_reason": { + "type": "string", + "description": "The reason the model stopped generating" + }, + "index": { + "type": "integer", + "description": "The index of the choice" + }, + "logprobs": { + "$ref": "#/components/schemas/OpenAIChoiceLogprobs", + "description": "(Optional) The log probabilities for the tokens in the message" + } + }, + "additionalProperties": false, + "required": [ + "message", + "finish_reason", + "index" + ], + "title": "OpenAIChoice", + "description": "A choice from an OpenAI-compatible chat completion response." + }, "OpenAIChoiceDelta": { "type": "object", "properties": { @@ -11583,6 +10679,28 @@ "title": "OpenAIChoiceDelta", "description": "A delta from an OpenAI-compatible chat completion streaming response." }, + "OpenAIChoiceLogprobs": { + "type": "object", + "properties": { + "content": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAITokenLogProb" + }, + "description": "(Optional) The log probabilities for the tokens in the message" + }, + "refusal": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAITokenLogProb" + }, + "description": "(Optional) The log probabilities for the tokens in the message" + } + }, + "additionalProperties": false, + "title": "OpenAIChoiceLogprobs", + "description": "The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response." 
+ }, "OpenAIChunkChoice": { "type": "object", "properties": { @@ -11612,6 +10730,61 @@ "title": "OpenAIChunkChoice", "description": "A chunk choice from an OpenAI-compatible chat completion streaming response." }, + "OpenAITokenLogProb": { + "type": "object", + "properties": { + "token": { + "type": "string" + }, + "bytes": { + "type": "array", + "items": { + "type": "integer" + } + }, + "logprob": { + "type": "number" + }, + "top_logprobs": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAITopLogProb" + } + } + }, + "additionalProperties": false, + "required": [ + "token", + "logprob", + "top_logprobs" + ], + "title": "OpenAITokenLogProb", + "description": "The log probability for a token from an OpenAI-compatible chat completion response." + }, + "OpenAITopLogProb": { + "type": "object", + "properties": { + "token": { + "type": "string" + }, + "bytes": { + "type": "array", + "items": { + "type": "integer" + } + }, + "logprob": { + "type": "number" + } + }, + "additionalProperties": false, + "required": [ + "token", + "logprob" + ], + "title": "OpenAITopLogProb", + "description": "The top log probability for a token from an OpenAI-compatible chat completion response." + }, "OpenaiCompletionRequest": { "type": "object", "properties": { @@ -11820,139 +10993,6 @@ "title": "OpenAICompletionChoice", "description": "A choice from an OpenAI-compatible completion response." }, - "OpenaiEmbeddingsRequest": { - "type": "object", - "properties": { - "model": { - "type": "string", - "description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint." - }, - "input": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array", - "items": { - "type": "string" - } - } - ], - "description": "Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings." - }, - "encoding_format": { - "type": "string", - "description": "(Optional) The format to return the embeddings in. Can be either \"float\" or \"base64\". Defaults to \"float\"." - }, - "dimensions": { - "type": "integer", - "description": "(Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models." - }, - "user": { - "type": "string", - "description": "(Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse." - } - }, - "additionalProperties": false, - "required": [ - "model", - "input" - ], - "title": "OpenaiEmbeddingsRequest" - }, - "OpenAIEmbeddingData": { - "type": "object", - "properties": { - "object": { - "type": "string", - "const": "embedding", - "default": "embedding", - "description": "The object type, which will be \"embedding\"" - }, - "embedding": { - "oneOf": [ - { - "type": "array", - "items": { - "type": "number" - } - }, - { - "type": "string" - } - ], - "description": "The embedding vector as a list of floats (when encoding_format=\"float\") or as a base64-encoded string (when encoding_format=\"base64\")" - }, - "index": { - "type": "integer", - "description": "The index of the embedding in the input list" - } - }, - "additionalProperties": false, - "required": [ - "object", - "embedding", - "index" - ], - "title": "OpenAIEmbeddingData", - "description": "A single embedding data object from an OpenAI-compatible embeddings response." 
- }, - "OpenAIEmbeddingUsage": { - "type": "object", - "properties": { - "prompt_tokens": { - "type": "integer", - "description": "The number of tokens in the input" - }, - "total_tokens": { - "type": "integer", - "description": "The total number of tokens used" - } - }, - "additionalProperties": false, - "required": [ - "prompt_tokens", - "total_tokens" - ], - "title": "OpenAIEmbeddingUsage", - "description": "Usage information for an OpenAI-compatible embeddings response." - }, - "OpenAIEmbeddingsResponse": { - "type": "object", - "properties": { - "object": { - "type": "string", - "const": "list", - "default": "list", - "description": "The object type, which will be \"list\"" - }, - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIEmbeddingData" - }, - "description": "List of embedding data objects" - }, - "model": { - "type": "string", - "description": "The model that was used to generate the embeddings" - }, - "usage": { - "$ref": "#/components/schemas/OpenAIEmbeddingUsage", - "description": "Usage information" - } - }, - "additionalProperties": false, - "required": [ - "object", - "data", - "model", - "usage" - ], - "title": "OpenAIEmbeddingsResponse", - "description": "Response from an OpenAI-compatible embeddings request." - }, "OpenAIModel": { "type": "object", "properties": { @@ -12323,10 +11363,6 @@ "type": "string", "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" - }, - "mode": { - "type": "string", - "description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"." } }, "additionalProperties": false, @@ -12471,8 +11507,7 @@ "type": "object", "properties": { "content": { - "$ref": "#/components/schemas/InterleavedContent", - "description": "The content of the chunk, which can be interleaved text, images, or other types." + "$ref": "#/components/schemas/InterleavedContent" }, "metadata": { "type": "object", @@ -12497,15 +11532,7 @@ "type": "object" } ] - }, - "description": "Metadata associated with the chunk, such as document ID, source, or other relevant information." - }, - "embedding": { - "type": "array", - "items": { - "type": "number" - }, - "description": "Optional embedding for the chunk. If not provided, it will be computed later." + } } }, "additionalProperties": false, @@ -12513,8 +11540,7 @@ "content", "metadata" ], - "title": "Chunk", - "description": "A chunk of content that can be inserted into a vector database." + "title": "Chunk" } }, "scores": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 7638c3cbd..a988e0eab 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -349,53 +349,6 @@ paths: $ref: '#/components/schemas/CreateAgentTurnRequest' required: true /v1/openai/v1/responses: - get: - responses: - '200': - description: A ListOpenAIResponseObject. 
- content: - application/json: - schema: - $ref: '#/components/schemas/ListOpenAIResponseObject' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Agents - description: List all OpenAI responses. - parameters: - - name: after - in: query - description: The ID of the last response to return. - required: false - schema: - type: string - - name: limit - in: query - description: The number of responses to return. - required: false - schema: - type: integer - - name: model - in: query - description: The model to filter responses by. - required: false - schema: - type: string - - name: order - in: query - description: >- - The order to sort responses by when sorted by created_at ('asc' or 'desc'). - required: false - schema: - $ref: '#/components/schemas/Order' post: responses: '200': @@ -874,35 +827,6 @@ paths: required: true schema: type: string - /v1/openai/v1/chat/completions/{completion_id}: - get: - responses: - '200': - description: A OpenAICompletionWithInputMessages. - content: - application/json: - schema: - $ref: '#/components/schemas/OpenAICompletionWithInputMessages' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Inference - description: Describe a chat completion by its ID. - parameters: - - name: completion_id - in: path - description: ID of the chat completion. - required: true - schema: - type: string /v1/datasets/{dataset_id}: get: responses: @@ -1010,7 +934,7 @@ paths: required: true schema: type: string - /v1/openai/v1/responses/{response_id}: + /v1/openai/v1/responses/{id}: get: responses: '200': @@ -1033,7 +957,7 @@ paths: - Agents description: Retrieve an OpenAI response by its ID. parameters: - - name: response_id + - name: id in: path description: >- The ID of the OpenAI response to retrieve. @@ -1871,89 +1795,6 @@ paths: schema: $ref: '#/components/schemas/RegisterBenchmarkRequest' required: true - /v1/openai/v1/chat/completions: - get: - responses: - '200': - description: A ListOpenAIChatCompletionResponse. - content: - application/json: - schema: - $ref: '#/components/schemas/ListOpenAIChatCompletionResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Inference - description: List all chat completions. - parameters: - - name: after - in: query - description: >- - The ID of the last chat completion to return. - required: false - schema: - type: string - - name: limit - in: query - description: >- - The maximum number of chat completions to return. - required: false - schema: - type: integer - - name: model - in: query - description: The model to filter by. - required: false - schema: - type: string - - name: order - in: query - description: >- - The order to sort the chat completions by: "asc" or "desc". Defaults to - "desc". - required: false - schema: - $ref: '#/components/schemas/Order' - post: - responses: - '200': - description: An OpenAIChatCompletion. 
- content: - application/json: - schema: - oneOf: - - $ref: '#/components/schemas/OpenAIChatCompletion' - - $ref: '#/components/schemas/OpenAIChatCompletionChunk' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Inference - description: >- - Generate an OpenAI-compatible chat completion for the given messages using - the specified model. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/OpenaiChatCompletionRequest' - required: true /v1/datasets: get: responses: @@ -2085,75 +1926,6 @@ paths: schema: $ref: '#/components/schemas/RegisterModelRequest' required: true - /v1/openai/v1/responses/{response_id}/input_items: - get: - responses: - '200': - description: An ListOpenAIResponseInputItem. - content: - application/json: - schema: - $ref: '#/components/schemas/ListOpenAIResponseInputItem' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Agents - description: >- - List input items for a given OpenAI response. - parameters: - - name: response_id - in: path - description: >- - The ID of the response to retrieve input items for. - required: true - schema: - type: string - - name: after - in: query - description: >- - An item ID to list items after, used for pagination. - required: false - schema: - type: string - - name: before - in: query - description: >- - An item ID to list items before, used for pagination. - required: false - schema: - type: string - - name: include - in: query - description: >- - Additional fields to include in the response. - required: false - schema: - type: array - items: - type: string - - name: limit - in: query - description: >- - A limit on the number of objects to be returned. Limit can range between - 1 and 100, and the default is 20. - required: false - schema: - type: integer - - name: order - in: query - description: >- - The order to return the input items in. Default is desc. - required: false - schema: - $ref: '#/components/schemas/Order' /v1/providers: get: responses: @@ -2489,6 +2261,39 @@ paths: schema: $ref: '#/components/schemas/LogEventRequest' required: true + /v1/openai/v1/chat/completions: + post: + responses: + '200': + description: An OpenAIChatCompletion. + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/OpenAIChatCompletion' + - $ref: '#/components/schemas/OpenAIChatCompletionChunk' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: >- + Generate an OpenAI-compatible chat completion for the given messages using + the specified model. 
+ parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/OpenaiChatCompletionRequest' + required: true /v1/openai/v1/completions: post: responses: @@ -2520,38 +2325,6 @@ paths: schema: $ref: '#/components/schemas/OpenaiCompletionRequest' required: true - /v1/openai/v1/embeddings: - post: - responses: - '200': - description: >- - An OpenAIEmbeddingsResponse containing the embeddings. - content: - application/json: - schema: - $ref: '#/components/schemas/OpenAIEmbeddingsResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Inference - description: >- - Generate OpenAI-compatible embeddings for the given input using the specified - model. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/OpenaiEmbeddingsRequest' - required: true /v1/openai/v1/models: get: responses: @@ -4910,14 +4683,12 @@ components: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolFileSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolFunction' - - $ref: '#/components/schemas/OpenAIResponseInputToolMCP' discriminator: propertyName: type mapping: web_search: '#/components/schemas/OpenAIResponseInputToolWebSearch' file_search: '#/components/schemas/OpenAIResponseInputToolFileSearch' function: '#/components/schemas/OpenAIResponseInputToolFunction' - mcp: '#/components/schemas/OpenAIResponseInputToolMCP' OpenAIResponseInputToolFileSearch: type: object properties: @@ -4972,66 +4743,6 @@ components: - type - name title: OpenAIResponseInputToolFunction - OpenAIResponseInputToolMCP: - type: object - properties: - type: - type: string - const: mcp - default: mcp - server_label: - type: string - server_url: - type: string - headers: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - require_approval: - oneOf: - - type: string - const: always - - type: string - const: never - - type: object - properties: - always: - type: array - items: - type: string - never: - type: array - items: - type: string - additionalProperties: false - title: ApprovalFilter - default: never - allowed_tools: - oneOf: - - type: array - items: - type: string - - type: object - properties: - tool_names: - type: array - items: - type: string - additionalProperties: false - title: AllowedToolsFilter - additionalProperties: false - required: - - type - - server_label - - server_url - - require_approval - title: OpenAIResponseInputToolMCP OpenAIResponseInputToolWebSearch: type: object properties: @@ -5107,12 +4818,12 @@ components: "OpenAIResponseOutputMessageFunctionToolCall": type: object properties: + arguments: + type: string call_id: type: string name: type: string - arguments: - type: string type: type: string const: function_call @@ -5123,10 +4834,12 @@ components: type: string additionalProperties: false required: + - arguments - call_id - name - - arguments - type + - id + - status title: >- OpenAIResponseOutputMessageFunctionToolCall "OpenAIResponseOutputMessageWebSearchToolCall": @@ -5160,8 +4873,6 @@ components: model: type: string description: The underlying LLM used for completions. 
- instructions: - type: string previous_response_id: type: string description: >- @@ -5244,95 +4955,20 @@ components: - $ref: '#/components/schemas/OpenAIResponseMessage' - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' - - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' - - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' discriminator: propertyName: type mapping: message: '#/components/schemas/OpenAIResponseMessage' web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' - mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' - mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' - OpenAIResponseOutputMessageMCPCall: - type: object - properties: - id: - type: string - type: - type: string - const: mcp_call - default: mcp_call - arguments: - type: string - name: - type: string - server_label: - type: string - error: - type: string - output: - type: string - additionalProperties: false - required: - - id - - type - - arguments - - name - - server_label - title: OpenAIResponseOutputMessageMCPCall - OpenAIResponseOutputMessageMCPListTools: - type: object - properties: - id: - type: string - type: - type: string - const: mcp_list_tools - default: mcp_list_tools - server_label: - type: string - tools: - type: array - items: - type: object - properties: - input_schema: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - name: - type: string - description: - type: string - additionalProperties: false - required: - - input_schema - - name - title: MCPListToolsTool - additionalProperties: false - required: - - id - - type - - server_label - - tools - title: OpenAIResponseOutputMessageMCPListTools OpenAIResponseObjectStream: oneOf: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' - - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' discriminator: propertyName: type mapping: response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' - response.output_text.delta: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta' response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' "OpenAIResponseObjectStreamResponseCompleted": type: object @@ -5364,33 +5000,6 @@ components: - type title: >- OpenAIResponseObjectStreamResponseCreated - "OpenAIResponseObjectStreamResponseOutputTextDelta": - type: object - properties: - content_index: - type: integer - delta: - type: string - item_id: - type: string - output_index: - type: integer - sequence_number: - type: integer - type: - type: string - const: response.output_text.delta - default: response.output_text.delta - additionalProperties: false - required: - - content_index - - delta - - item_id - - output_index - - sequence_number - - type - title: >- - OpenAIResponseObjectStreamResponseOutputTextDelta CreateUploadSessionRequest: type: object properties: @@ -5870,369 +5479,6 @@ components: - scoring_functions - metadata title: Benchmark - OpenAIAssistantMessageParam: - type: object - properties: - role: - type: string - const: assistant - default: assistant - description: >- - Must be 
"assistant" to identify this as the model's response - content: - oneOf: - - type: string - - type: array - items: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' - description: The content of the model's response - name: - type: string - description: >- - (Optional) The name of the assistant message participant. - tool_calls: - type: array - items: - $ref: '#/components/schemas/OpenAIChatCompletionToolCall' - description: >- - List of tool calls. Each tool call is an OpenAIChatCompletionToolCall - object. - additionalProperties: false - required: - - role - title: OpenAIAssistantMessageParam - description: >- - A message containing the model's (assistant) response in an OpenAI-compatible - chat completion request. - "OpenAIChatCompletionContentPartImageParam": - type: object - properties: - type: - type: string - const: image_url - default: image_url - image_url: - $ref: '#/components/schemas/OpenAIImageURL' - additionalProperties: false - required: - - type - - image_url - title: >- - OpenAIChatCompletionContentPartImageParam - OpenAIChatCompletionContentPartParam: - oneOf: - - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' - - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' - discriminator: - propertyName: type - mapping: - text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' - image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' - OpenAIChatCompletionContentPartTextParam: - type: object - properties: - type: - type: string - const: text - default: text - text: - type: string - additionalProperties: false - required: - - type - - text - title: OpenAIChatCompletionContentPartTextParam - OpenAIChatCompletionToolCall: - type: object - properties: - index: - type: integer - id: - type: string - type: - type: string - const: function - default: function - function: - $ref: '#/components/schemas/OpenAIChatCompletionToolCallFunction' - additionalProperties: false - required: - - type - title: OpenAIChatCompletionToolCall - OpenAIChatCompletionToolCallFunction: - type: object - properties: - name: - type: string - arguments: - type: string - additionalProperties: false - title: OpenAIChatCompletionToolCallFunction - OpenAIChoice: - type: object - properties: - message: - $ref: '#/components/schemas/OpenAIMessageParam' - description: The message from the model - finish_reason: - type: string - description: The reason the model stopped generating - index: - type: integer - description: The index of the choice - logprobs: - $ref: '#/components/schemas/OpenAIChoiceLogprobs' - description: >- - (Optional) The log probabilities for the tokens in the message - additionalProperties: false - required: - - message - - finish_reason - - index - title: OpenAIChoice - description: >- - A choice from an OpenAI-compatible chat completion response. - OpenAIChoiceLogprobs: - type: object - properties: - content: - type: array - items: - $ref: '#/components/schemas/OpenAITokenLogProb' - description: >- - (Optional) The log probabilities for the tokens in the message - refusal: - type: array - items: - $ref: '#/components/schemas/OpenAITokenLogProb' - description: >- - (Optional) The log probabilities for the tokens in the message - additionalProperties: false - title: OpenAIChoiceLogprobs - description: >- - The log probabilities for the tokens in the message from an OpenAI-compatible - chat completion response. 
- OpenAIDeveloperMessageParam: - type: object - properties: - role: - type: string - const: developer - default: developer - description: >- - Must be "developer" to identify this as a developer message - content: - oneOf: - - type: string - - type: array - items: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' - description: The content of the developer message - name: - type: string - description: >- - (Optional) The name of the developer message participant. - additionalProperties: false - required: - - role - - content - title: OpenAIDeveloperMessageParam - description: >- - A message from the developer in an OpenAI-compatible chat completion request. - OpenAIImageURL: - type: object - properties: - url: - type: string - detail: - type: string - additionalProperties: false - required: - - url - title: OpenAIImageURL - OpenAIMessageParam: - oneOf: - - $ref: '#/components/schemas/OpenAIUserMessageParam' - - $ref: '#/components/schemas/OpenAISystemMessageParam' - - $ref: '#/components/schemas/OpenAIAssistantMessageParam' - - $ref: '#/components/schemas/OpenAIToolMessageParam' - - $ref: '#/components/schemas/OpenAIDeveloperMessageParam' - discriminator: - propertyName: role - mapping: - user: '#/components/schemas/OpenAIUserMessageParam' - system: '#/components/schemas/OpenAISystemMessageParam' - assistant: '#/components/schemas/OpenAIAssistantMessageParam' - tool: '#/components/schemas/OpenAIToolMessageParam' - developer: '#/components/schemas/OpenAIDeveloperMessageParam' - OpenAISystemMessageParam: - type: object - properties: - role: - type: string - const: system - default: system - description: >- - Must be "system" to identify this as a system message - content: - oneOf: - - type: string - - type: array - items: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' - description: >- - The content of the "system prompt". If multiple system messages are provided, - they are concatenated. The underlying Llama Stack code may also add other - system messages (for example, for formatting tool definitions). - name: - type: string - description: >- - (Optional) The name of the system message participant. - additionalProperties: false - required: - - role - - content - title: OpenAISystemMessageParam - description: >- - A system message providing instructions or context to the model. - OpenAITokenLogProb: - type: object - properties: - token: - type: string - bytes: - type: array - items: - type: integer - logprob: - type: number - top_logprobs: - type: array - items: - $ref: '#/components/schemas/OpenAITopLogProb' - additionalProperties: false - required: - - token - - logprob - - top_logprobs - title: OpenAITokenLogProb - description: >- - The log probability for a token from an OpenAI-compatible chat completion - response. - OpenAIToolMessageParam: - type: object - properties: - role: - type: string - const: tool - default: tool - description: >- - Must be "tool" to identify this as a tool response - tool_call_id: - type: string - description: >- - Unique identifier for the tool call this response is for - content: - oneOf: - - type: string - - type: array - items: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' - description: The response content from the tool - additionalProperties: false - required: - - role - - tool_call_id - - content - title: OpenAIToolMessageParam - description: >- - A message representing the result of a tool invocation in an OpenAI-compatible - chat completion request. 
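For readers skimming these message-param schemas, a concrete payload helps. The sketch below pairs an assistant message that issues a tool call (per OpenAIAssistantMessageParam and OpenAIChatCompletionToolCall) with the tool-result message that answers it (per OpenAIToolMessageParam). All identifiers and values are hypothetical placeholders, not taken from this spec.

```python
# Sketch only: payloads shaped to the message-param schemas above.
# The tool name, arguments, and call id are hypothetical placeholders.
assistant_message = {
    "role": "assistant",  # const "assistant"; only `role` is required by the schema
    "tool_calls": [
        {
            "index": 0,
            "id": "call_abc123",      # hypothetical tool-call id
            "type": "function",       # const "function" per OpenAIChatCompletionToolCall
            "function": {
                "name": "get_weather",            # hypothetical function name
                "arguments": '{"city": "Paris"}', # arguments are a JSON-encoded string
            },
        }
    ],
}

tool_message = {
    "role": "tool",                 # const "tool" per OpenAIToolMessageParam
    "tool_call_id": "call_abc123",  # must echo the id of the call being answered
    "content": '{"temperature_c": 18, "condition": "cloudy"}',
}
```

In a full exchange, the tool message is appended to the conversation after the assistant's tool call so the model can use the tool output on the next turn.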
- OpenAITopLogProb: - type: object - properties: - token: - type: string - bytes: - type: array - items: - type: integer - logprob: - type: number - additionalProperties: false - required: - - token - - logprob - title: OpenAITopLogProb - description: >- - The top log probability for a token from an OpenAI-compatible chat completion - response. - OpenAIUserMessageParam: - type: object - properties: - role: - type: string - const: user - default: user - description: >- - Must be "user" to identify this as a user message - content: - oneOf: - - type: string - - type: array - items: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' - description: >- - The content of the message, which can include text and other media - name: - type: string - description: >- - (Optional) The name of the user message participant. - additionalProperties: false - required: - - role - - content - title: OpenAIUserMessageParam - description: >- - A message from the user in an OpenAI-compatible chat completion request. - OpenAICompletionWithInputMessages: - type: object - properties: - id: - type: string - description: The ID of the chat completion - choices: - type: array - items: - $ref: '#/components/schemas/OpenAIChoice' - description: List of choices - object: - type: string - const: chat.completion - default: chat.completion - description: >- - The object type, which will be "chat.completion" - created: - type: integer - description: >- - The Unix timestamp in seconds when the chat completion was created - model: - type: string - description: >- - The model that was used to generate the chat completion - input_messages: - type: array - items: - $ref: '#/components/schemas/OpenAIMessageParam' - additionalProperties: false - required: - - id - - choices - - object - - created - - model - - input_messages - title: OpenAICompletionWithInputMessages DataSource: oneOf: - $ref: '#/components/schemas/URIDataSource' @@ -6774,6 +6020,8 @@ components: default: tool toolgroup_id: type: string + tool_host: + $ref: '#/components/schemas/ToolHost' description: type: string parameters: @@ -6796,9 +6044,17 @@ components: - provider_id - type - toolgroup_id + - tool_host - description - parameters title: Tool + ToolHost: + type: string + enum: + - distribution + - client + - model_context_protocol + title: ToolHost ToolGroup: type: object properties: @@ -7056,9 +6312,6 @@ components: properties: content: $ref: '#/components/schemas/InterleavedContent' - description: >- - The content of the chunk, which can be interleaved text, images, - or other types. metadata: type: object additionalProperties: @@ -7069,29 +6322,12 @@ components: - type: string - type: array - type: object - description: >- - Metadata associated with the chunk, such as document ID, source, - or other relevant information. - embedding: - type: array - items: - type: number - description: >- - Optional embedding for the chunk. If not provided, it will be computed - later. additionalProperties: false required: - content - metadata title: Chunk - description: >- - A chunk of content that can be inserted into a vector database. - description: >- - The chunks to insert. Each `Chunk` should contain content which can be - interleaved text, images, or other types. `metadata`: `dict[str, Any]` - and `embedding`: `List[float]` are optional. If `metadata` is provided, - you configure how Llama Stack formats the chunk during generation. If - `embedding` is not provided, it will be computed later. + description: The chunks to insert. 
ttl_seconds: type: integer description: The time to live of the chunks. @@ -7261,73 +6497,6 @@ components: required: - data title: ListBenchmarksResponse - Order: - type: string - enum: - - asc - - desc - title: Order - ListOpenAIChatCompletionResponse: - type: object - properties: - data: - type: array - items: - type: object - properties: - id: - type: string - description: The ID of the chat completion - choices: - type: array - items: - $ref: '#/components/schemas/OpenAIChoice' - description: List of choices - object: - type: string - const: chat.completion - default: chat.completion - description: >- - The object type, which will be "chat.completion" - created: - type: integer - description: >- - The Unix timestamp in seconds when the chat completion was created - model: - type: string - description: >- - The model that was used to generate the chat completion - input_messages: - type: array - items: - $ref: '#/components/schemas/OpenAIMessageParam' - additionalProperties: false - required: - - id - - choices - - object - - created - - model - - input_messages - title: OpenAICompletionWithInputMessages - has_more: - type: boolean - first_id: - type: string - last_id: - type: string - object: - type: string - const: list - default: list - additionalProperties: false - required: - - data - - has_more - - first_id - - last_id - - object - title: ListOpenAIChatCompletionResponse ListDatasetsResponse: type: object properties: @@ -7364,96 +6533,6 @@ components: required: - data title: ListModelsResponse - ListOpenAIResponseInputItem: - type: object - properties: - data: - type: array - items: - $ref: '#/components/schemas/OpenAIResponseInput' - object: - type: string - const: list - default: list - additionalProperties: false - required: - - data - - object - title: ListOpenAIResponseInputItem - ListOpenAIResponseObject: - type: object - properties: - data: - type: array - items: - $ref: '#/components/schemas/OpenAIResponseObjectWithInput' - has_more: - type: boolean - first_id: - type: string - last_id: - type: string - object: - type: string - const: list - default: list - additionalProperties: false - required: - - data - - has_more - - first_id - - last_id - - object - title: ListOpenAIResponseObject - OpenAIResponseObjectWithInput: - type: object - properties: - created_at: - type: integer - error: - $ref: '#/components/schemas/OpenAIResponseError' - id: - type: string - model: - type: string - object: - type: string - const: response - default: response - output: - type: array - items: - $ref: '#/components/schemas/OpenAIResponseOutput' - parallel_tool_calls: - type: boolean - default: false - previous_response_id: - type: string - status: - type: string - temperature: - type: number - top_p: - type: number - truncation: - type: string - user: - type: string - input: - type: array - items: - $ref: '#/components/schemas/OpenAIResponseInput' - additionalProperties: false - required: - - created_at - - id - - model - - object - - output - - parallel_tool_calls - - status - - input - title: OpenAIResponseObjectWithInput ListProvidersResponse: type: object properties: @@ -7756,6 +6835,142 @@ components: - event - ttl_seconds title: LogEventRequest + OpenAIAssistantMessageParam: + type: object + properties: + role: + type: string + const: assistant + default: assistant + description: >- + Must be "assistant" to identify this as the model's response + content: + oneOf: + - type: string + - type: array + items: + $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' + description: 
The content of the model's response + name: + type: string + description: >- + (Optional) The name of the assistant message participant. + tool_calls: + type: array + items: + $ref: '#/components/schemas/OpenAIChatCompletionToolCall' + description: >- + List of tool calls. Each tool call is an OpenAIChatCompletionToolCall + object. + additionalProperties: false + required: + - role + title: OpenAIAssistantMessageParam + description: >- + A message containing the model's (assistant) response in an OpenAI-compatible + chat completion request. + "OpenAIChatCompletionContentPartImageParam": + type: object + properties: + type: + type: string + const: image_url + default: image_url + image_url: + $ref: '#/components/schemas/OpenAIImageURL' + additionalProperties: false + required: + - type + - image_url + title: >- + OpenAIChatCompletionContentPartImageParam + OpenAIChatCompletionContentPartParam: + oneOf: + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + discriminator: + propertyName: type + mapping: + text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + OpenAIChatCompletionContentPartTextParam: + type: object + properties: + type: + type: string + const: text + default: text + text: + type: string + additionalProperties: false + required: + - type + - text + title: OpenAIChatCompletionContentPartTextParam + OpenAIChatCompletionToolCall: + type: object + properties: + index: + type: integer + id: + type: string + type: + type: string + const: function + default: function + function: + $ref: '#/components/schemas/OpenAIChatCompletionToolCallFunction' + additionalProperties: false + required: + - type + title: OpenAIChatCompletionToolCall + OpenAIChatCompletionToolCallFunction: + type: object + properties: + name: + type: string + arguments: + type: string + additionalProperties: false + title: OpenAIChatCompletionToolCallFunction + OpenAIDeveloperMessageParam: + type: object + properties: + role: + type: string + const: developer + default: developer + description: >- + Must be "developer" to identify this as a developer message + content: + oneOf: + - type: string + - type: array + items: + $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' + description: The content of the developer message + name: + type: string + description: >- + (Optional) The name of the developer message participant. + additionalProperties: false + required: + - role + - content + title: OpenAIDeveloperMessageParam + description: >- + A message from the developer in an OpenAI-compatible chat completion request. 
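The content-part schemas above (text and image_url parts) compose into user messages with mixed content. A minimal sketch, assuming a placeholder image URL; `detail` is the optional free-form string allowed by OpenAIImageURL, defined next:

```python
# Sketch only: a user message mixing a text part and an image part, per the
# OpenAIChatCompletionContentPartTextParam / ...ImageParam schemas above.
# The URL and `detail` value are placeholders.
user_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is shown in this picture?"},
        {
            "type": "image_url",
            "image_url": {"url": "https://example.com/photo.jpg", "detail": "auto"},
        },
    ],
}
```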
+ OpenAIImageURL: + type: object + properties: + url: + type: string + detail: + type: string + additionalProperties: false + required: + - url + title: OpenAIImageURL OpenAIJSONSchema: type: object properties: @@ -7779,6 +6994,21 @@ components: required: - name title: OpenAIJSONSchema + OpenAIMessageParam: + oneOf: + - $ref: '#/components/schemas/OpenAIUserMessageParam' + - $ref: '#/components/schemas/OpenAISystemMessageParam' + - $ref: '#/components/schemas/OpenAIAssistantMessageParam' + - $ref: '#/components/schemas/OpenAIToolMessageParam' + - $ref: '#/components/schemas/OpenAIDeveloperMessageParam' + discriminator: + propertyName: role + mapping: + user: '#/components/schemas/OpenAIUserMessageParam' + system: '#/components/schemas/OpenAISystemMessageParam' + assistant: '#/components/schemas/OpenAIAssistantMessageParam' + tool: '#/components/schemas/OpenAIToolMessageParam' + developer: '#/components/schemas/OpenAIDeveloperMessageParam' OpenAIResponseFormatJSONObject: type: object properties: @@ -7826,6 +7056,93 @@ components: required: - type title: OpenAIResponseFormatText + OpenAISystemMessageParam: + type: object + properties: + role: + type: string + const: system + default: system + description: >- + Must be "system" to identify this as a system message + content: + oneOf: + - type: string + - type: array + items: + $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' + description: >- + The content of the "system prompt". If multiple system messages are provided, + they are concatenated. The underlying Llama Stack code may also add other + system messages (for example, for formatting tool definitions). + name: + type: string + description: >- + (Optional) The name of the system message participant. + additionalProperties: false + required: + - role + - content + title: OpenAISystemMessageParam + description: >- + A system message providing instructions or context to the model. + OpenAIToolMessageParam: + type: object + properties: + role: + type: string + const: tool + default: tool + description: >- + Must be "tool" to identify this as a tool response + tool_call_id: + type: string + description: >- + Unique identifier for the tool call this response is for + content: + oneOf: + - type: string + - type: array + items: + $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' + description: The response content from the tool + additionalProperties: false + required: + - role + - tool_call_id + - content + title: OpenAIToolMessageParam + description: >- + A message representing the result of a tool invocation in an OpenAI-compatible + chat completion request. + OpenAIUserMessageParam: + type: object + properties: + role: + type: string + const: user + default: user + description: >- + Must be "user" to identify this as a user message + content: + oneOf: + - type: string + - type: array + items: + $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' + description: >- + The content of the message, which can include text and other media + name: + type: string + description: >- + (Optional) The name of the user message participant. + additionalProperties: false + required: + - role + - content + title: OpenAIUserMessageParam + description: >- + A message from the user in an OpenAI-compatible chat completion request. OpenaiChatCompletionRequest: type: object properties: @@ -8039,6 +7356,30 @@ components: title: OpenAIChatCompletionChunk description: >- Chunk from a streaming response to an OpenAI-compatible chat completion request. 
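Taken together, these request and response schemas back the POST /v1/openai/v1/chat/completions route added earlier in this diff. A hedged sketch of exercising it with the stock `openai` client follows; the base URL, port, api_key, and model id are assumptions about a local Llama Stack deployment, not values from this spec.

```python
# Sketch only: call the OpenAI-compatible chat completions route with the `openai` client.
# base_url, api_key, and model are assumed values; adjust them to your deployment.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",  # assumed local server prefix
    api_key="none",                                 # placeholder; depends on your auth setup
)

resp = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical registered model id
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(resp.choices[0].message.content)  # one OpenAIChoice per the schemas above
```

A streaming variant of the same call (`stream=True`) would instead yield OpenAIChatCompletionChunk objects, matching the oneOf in the endpoint's 200 response.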
+ OpenAIChoice: + type: object + properties: + message: + $ref: '#/components/schemas/OpenAIMessageParam' + description: The message from the model + finish_reason: + type: string + description: The reason the model stopped generating + index: + type: integer + description: The index of the choice + logprobs: + $ref: '#/components/schemas/OpenAIChoiceLogprobs' + description: >- + (Optional) The log probabilities for the tokens in the message + additionalProperties: false + required: + - message + - finish_reason + - index + title: OpenAIChoice + description: >- + A choice from an OpenAI-compatible chat completion response. OpenAIChoiceDelta: type: object properties: @@ -8060,6 +7401,26 @@ components: title: OpenAIChoiceDelta description: >- A delta from an OpenAI-compatible chat completion streaming response. + OpenAIChoiceLogprobs: + type: object + properties: + content: + type: array + items: + $ref: '#/components/schemas/OpenAITokenLogProb' + description: >- + (Optional) The log probabilities for the tokens in the message + refusal: + type: array + items: + $ref: '#/components/schemas/OpenAITokenLogProb' + description: >- + (Optional) The log probabilities for the tokens in the message + additionalProperties: false + title: OpenAIChoiceLogprobs + description: >- + The log probabilities for the tokens in the message from an OpenAI-compatible + chat completion response. OpenAIChunkChoice: type: object properties: @@ -8084,6 +7445,49 @@ components: title: OpenAIChunkChoice description: >- A chunk choice from an OpenAI-compatible chat completion streaming response. + OpenAITokenLogProb: + type: object + properties: + token: + type: string + bytes: + type: array + items: + type: integer + logprob: + type: number + top_logprobs: + type: array + items: + $ref: '#/components/schemas/OpenAITopLogProb' + additionalProperties: false + required: + - token + - logprob + - top_logprobs + title: OpenAITokenLogProb + description: >- + The log probability for a token from an OpenAI-compatible chat completion + response. + OpenAITopLogProb: + type: object + properties: + token: + type: string + bytes: + type: array + items: + type: integer + logprob: + type: number + additionalProperties: false + required: + - token + - logprob + title: OpenAITopLogProb + description: >- + The top log probability for a token from an OpenAI-compatible chat completion + response. OpenaiCompletionRequest: type: object properties: @@ -8229,118 +7633,6 @@ components: title: OpenAICompletionChoice description: >- A choice from an OpenAI-compatible completion response. - OpenaiEmbeddingsRequest: - type: object - properties: - model: - type: string - description: >- - The identifier of the model to use. The model must be an embedding model - registered with Llama Stack and available via the /models endpoint. - input: - oneOf: - - type: string - - type: array - items: - type: string - description: >- - Input text to embed, encoded as a string or array of strings. To embed - multiple inputs in a single request, pass an array of strings. - encoding_format: - type: string - description: >- - (Optional) The format to return the embeddings in. Can be either "float" - or "base64". Defaults to "float". - dimensions: - type: integer - description: >- - (Optional) The number of dimensions the resulting output embeddings should - have. Only supported in text-embedding-3 and later models. - user: - type: string - description: >- - (Optional) A unique identifier representing your end-user, which can help - OpenAI to monitor and detect abuse. 
- additionalProperties: false - required: - - model - - input - title: OpenaiEmbeddingsRequest - OpenAIEmbeddingData: - type: object - properties: - object: - type: string - const: embedding - default: embedding - description: >- - The object type, which will be "embedding" - embedding: - oneOf: - - type: array - items: - type: number - - type: string - description: >- - The embedding vector as a list of floats (when encoding_format="float") - or as a base64-encoded string (when encoding_format="base64") - index: - type: integer - description: >- - The index of the embedding in the input list - additionalProperties: false - required: - - object - - embedding - - index - title: OpenAIEmbeddingData - description: >- - A single embedding data object from an OpenAI-compatible embeddings response. - OpenAIEmbeddingUsage: - type: object - properties: - prompt_tokens: - type: integer - description: The number of tokens in the input - total_tokens: - type: integer - description: The total number of tokens used - additionalProperties: false - required: - - prompt_tokens - - total_tokens - title: OpenAIEmbeddingUsage - description: >- - Usage information for an OpenAI-compatible embeddings response. - OpenAIEmbeddingsResponse: - type: object - properties: - object: - type: string - const: list - default: list - description: The object type, which will be "list" - data: - type: array - items: - $ref: '#/components/schemas/OpenAIEmbeddingData' - description: List of embedding data objects - model: - type: string - description: >- - The model that was used to generate the embeddings - usage: - $ref: '#/components/schemas/OpenAIEmbeddingUsage' - description: Usage information - additionalProperties: false - required: - - object - - data - - model - - usage - title: OpenAIEmbeddingsResponse - description: >- - Response from an OpenAI-compatible embeddings request. OpenAIModel: type: object properties: @@ -8608,10 +7900,6 @@ components: placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" - mode: - type: string - description: >- - Search mode for retrieval—either "vector" or "keyword". Default "vector". additionalProperties: false required: - query_generator_config @@ -8701,9 +7989,6 @@ components: properties: content: $ref: '#/components/schemas/InterleavedContent' - description: >- - The content of the chunk, which can be interleaved text, images, - or other types. metadata: type: object additionalProperties: @@ -8714,23 +7999,11 @@ components: - type: string - type: array - type: object - description: >- - Metadata associated with the chunk, such as document ID, source, - or other relevant information. - embedding: - type: array - items: - type: number - description: >- - Optional embedding for the chunk. If not provided, it will be computed - later. additionalProperties: false required: - content - metadata title: Chunk - description: >- - A chunk of content that can be inserted into a vector database. 
scores: type: array items: diff --git a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb index 93f78d268..413b693d1 100644 --- a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb +++ b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb @@ -38,8 +38,12 @@ "cell_type": "code", "execution_count": null, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "collapsed": true, - "id": "O9pGVlPIjpix" + "id": "O9pGVlPIjpix", + "outputId": "e1fbe723-ae31-4630-eb80-4c4f6476d56f" }, "outputs": [], "source": [ @@ -51,8 +55,12 @@ "cell_type": "code", "execution_count": null, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "collapsed": true, - "id": "JQpLUSNjlGAM" + "id": "JQpLUSNjlGAM", + "outputId": "2f7fec97-5511-4cae-d51e-6d262fbca19c" }, "outputs": [], "source": [ @@ -62,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -693,7 +701,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "id": "TC_IwIAQo4q-" }, @@ -706,10 +714,116 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 305, + "referenced_widgets": [ + "feb82e061ee44283b4a46be858ef4cd7", + "78a2d2d4ee3f42f3be42ef4baa298561", + "ba5e6ca09f174ef3a348453cf5cfc24a", + "74b58e4647644c9daf9af488942fdaf4", + "d56e218958a041e286e80f24e400ab0b", + "cab80632b7564a9eb59583e09573c1ee", + "10c0d50d7c204de0b4c8e8f4d3ec0af5", + "626ef2f811ae4e119a0e85cebe92b91d", + "aef4172d916f40b0ab4ed09104e10f24", + "25529e7fd57049d2816d31f696eab1fd", + "093bdcb608cf4b4fa37b0032a3915187", + "c788d4e9e1e24dca9b6503689df9b631", + "d1587e2144bf46299c1bdec3ea96e4e7", + "500a072c09da41759cb2c942a16d8429", + "9785009392934e3bbb229e8781667cbc", + "84570fe2c2a54a068fb9b8cbc8b041a1", + "f9e579c58e3f4ae0bbb721dffa33bf0a", + "737116977f474ec0b68d88a40fd1086c", + "e6d6e516cd03452297d80c36376855dd", + "6ae0fadb3aeb4be18a9ab3279fb23145", + "fa4800a506ac480984d58933580df086", + "117468099dbc42fdaafc08207eaac7ab", + "44f585990aa244d8ba61f892dc1ccc1c", + "4fc59928a0544f95a4438b37d19ca437", + "fb644d47049f495397d0e60597c86ea3", + "78632694ff694442bc3fefc2cac2cbf5", + "083fd2549abd4b03bd41d8b92ec28f42", + "611d6472a58d419583acc416767a4c90", + "98c5ce434cff454eaaa3f0fd3498183a", + "3d0344a9cc744e369da1b6b7ea1b3be8", + "c452ccbf47a44073aee710175f707a7d", + "0218397c573e4b28bfb4ffa66464d50f", + "9b01bcd6e5174be2af19f457047017c8", + "4fed5720f30b4b3cbbc606a4f25e223b", + "6fa866b9971542739b0ed26d90ceac80", + "fe7553b513954cc68c427b5d9d260b33", + "4bc266d49a6741a88350e029d101425b", + "da57445f98e7427589962836c2b4287e", + "ad1fb86cc1f94fd9911eda03cf4a3783", + "fdefb51ad4c4418b98c5826126558011", + "179d41b80dc841e8a440482516b8bca5", + "22b1ecd2eff14770bcfb0c62d3d4213f", + "47f876cf41484d55b645e1e99337423a", + "340fbbb4982c460992c88885e79b47db", + "9659140487ca4d3ea799196d2c1ecf61", + "52150fd494d24eea89b5232077509355", + "04acde771d0a46699e1de07d9733d1a3", + "7b98103300814f3caea84266263b95a2", + "75f06408071c494f934bb909b84110d1", + "b09b2690894749339a9172e5ad0a9b75", + "cbed38801163438d891879b756f5baab", + "399a6417b23e4593bb244ec3abb6b46d", + "53a321f36b0d4e08a74a5bcfbd04434b", + "b8c0c8aaac0d4032bf5c673a43d084ab", + "d1f32499fa3f4795b92361637e23a9bb", + "c06f9a090fb54c74b947634bf6d11fa8", + "82991dcc80f14af9bd2e95f705980676", + 
"cd832e3842b945aabbb327856053f261", + "93ee645d54f34acdb0d15092d4a6f0d1", + "b77fe05bbcf84cdc8ef85b264ccd35f6", + "e17d286a965a49cfb8d5bf885865cb1e", + "ca015c1a0c1449e68edb282462435a3f", + "2932b06afde9468a976eb6bfb072b80e", + "d027c807ddc04f89bec41dc05fde7718", + "4ff3a6aaf706460bbba01b248b93000e", + "bfd75a39f0154c30adbaad1e2ca0f1e2", + "4f788a7920c346f3b42900825bd6711a", + "8e9358ec7d474808bb96c13e13489c67", + "f0dfeee2a8d64dedbc8ef55ad4e69932", + "9437b707bf1a4847a50aafeb4252dab5", + "f255707788704a76bd1651f26a22402d", + "3b70fa4e43ef4951862e119378c3c501", + "6c0a6a7fa8ca4e1c961a36305f0e7638", + "201bd914f9884e46b8e6df9d9900a6e8", + "f53b7ada01084e73bba6e14a95e2a534", + "d2029292327b488db02fd123ee2b75af", + "3e26bc24a3e44b4582f57913bdf98de4", + "9d2b6eabf7e14436b72bbf374b4a2a0a", + "b5d7cb5a6157449a850ef0e12e3d3eb7", + "c245d316bf9e44dabe5bfd1e47fc8d2e", + "963cf422ca894d82b0dd94c6165d41bf", + "78d0e2aa93674bbeb42bff87a23cce9b", + "12c6f1180eeb4e9eb9037ea5dd24ec8e", + "017a81d7160240a398947545963856f5", + "1cf8eeb8d81c4e8a8e95dd43296a78b9", + "5b0b5a3f79e94c51aae48fe0dd34ba0e", + "f5b34a743ce54fb591f25b04a2651d65", + "dec6399e2c5341aead66e1674d3e6c72", + "24e48376a72940679989a39a40bbe7f6", + "484df732051540859bc7ac9cecadc83c", + "4b33b1db50c34a2fa957d81a71a2a47f", + "e51d501e2f994baba40345ad632eabee", + "631a85e420b64e8cb6915af59c5ce08a", + "70af9cb2838c4a92bd67f8cb5c98d97f", + "158115266c284c4f8dbce3586151cbf1", + "ce5019b36cde44c58c5f596dbb59a2f8", + "b90d660ca8584ba1815a3c66b420c079", + "7c4d1de626784a59a7e0a33c24086186", + "21cf0e35ecd845a8b5e7c5ce241cf177" + ] + }, "collapsed": true, - "id": "DJkmoG2kq1_P" + "id": "DJkmoG2kq1_P", + "outputId": "8493ee59-c6ff-4bb6-d787-f295944db1cf" }, "outputs": [], "source": [ @@ -734,7 +848,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -835,7 +949,7 @@ "\n", "client.benchmarks.register(\n", " benchmark_id=\"meta-reference::mmmu\",\n", - " # Note: we can use any value as `dataset_id` because we'll be using the `evaluate_rows` API which accepts the\n", + " # Note: we can use any value as `dataset_id` because we'll be using the `evaluate_rows` API which accepts the \n", " # `input_rows` argument and does not fetch data from the dataset.\n", " dataset_id=f\"mmmu-{subset}-{split}\",\n", " # Note: for the same reason as above, we can use any value as `scoring_functions`.\n", @@ -880,7 +994,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "id": "HXmZf3Ymw-aX" }, @@ -900,7 +1014,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": { "id": "Gc8azb4Rxr5J" }, @@ -914,7 +1028,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1054,7 +1168,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1179,9 +1293,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "lxc9-eXYK5Av" - }, + "metadata": {}, "outputs": [], "source": [] } @@ -1210,6 +1322,3088 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "017a81d7160240a398947545963856f5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": 
"DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0218397c573e4b28bfb4ffa66464d50f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "04acde771d0a46699e1de07d9733d1a3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_399a6417b23e4593bb244ec3abb6b46d", + "max": 453677660, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_53a321f36b0d4e08a74a5bcfbd04434b", + "value": 453677660 + } + }, + "083fd2549abd4b03bd41d8b92ec28f42": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + 
}, + "093bdcb608cf4b4fa37b0032a3915187": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "10c0d50d7c204de0b4c8e8f4d3ec0af5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "117468099dbc42fdaafc08207eaac7ab": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "12c6f1180eeb4e9eb9037ea5dd24ec8e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "158115266c284c4f8dbce3586151cbf1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "179d41b80dc841e8a440482516b8bca5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + 
"align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1cf8eeb8d81c4e8a8e95dd43296a78b9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "201bd914f9884e46b8e6df9d9900a6e8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "21cf0e35ecd845a8b5e7c5ce241cf177": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": 
"1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "22b1ecd2eff14770bcfb0c62d3d4213f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "24e48376a72940679989a39a40bbe7f6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_484df732051540859bc7ac9cecadc83c", + "IPY_MODEL_4b33b1db50c34a2fa957d81a71a2a47f", + "IPY_MODEL_e51d501e2f994baba40345ad632eabee" + ], + "layout": "IPY_MODEL_631a85e420b64e8cb6915af59c5ce08a" + } + }, + "25529e7fd57049d2816d31f696eab1fd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2932b06afde9468a976eb6bfb072b80e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, 
+ "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "340fbbb4982c460992c88885e79b47db": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "399a6417b23e4593bb244ec3abb6b46d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3b70fa4e43ef4951862e119378c3c501": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3d0344a9cc744e369da1b6b7ea1b3be8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3e26bc24a3e44b4582f57913bdf98de4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "44f585990aa244d8ba61f892dc1ccc1c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_4fc59928a0544f95a4438b37d19ca437", + "IPY_MODEL_fb644d47049f495397d0e60597c86ea3", + "IPY_MODEL_78632694ff694442bc3fefc2cac2cbf5" + ], + "layout": "IPY_MODEL_083fd2549abd4b03bd41d8b92ec28f42" + } + }, + "47f876cf41484d55b645e1e99337423a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "484df732051540859bc7ac9cecadc83c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_70af9cb2838c4a92bd67f8cb5c98d97f", + "placeholder": "​", + "style": "IPY_MODEL_158115266c284c4f8dbce3586151cbf1", + "value": "Generating test split: 100%" + } + }, + "4b33b1db50c34a2fa957d81a71a2a47f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ce5019b36cde44c58c5f596dbb59a2f8", + "max": 287, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b90d660ca8584ba1815a3c66b420c079", + "value": 287 + } + }, + "4bc266d49a6741a88350e029d101425b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_47f876cf41484d55b645e1e99337423a", + "placeholder": "​", + "style": "IPY_MODEL_340fbbb4982c460992c88885e79b47db", + "value": " 461M/461M [00:11<00:00, 31.2MB/s]" + } + }, + "4f788a7920c346f3b42900825bd6711a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_8e9358ec7d474808bb96c13e13489c67", + "IPY_MODEL_f0dfeee2a8d64dedbc8ef55ad4e69932", + "IPY_MODEL_9437b707bf1a4847a50aafeb4252dab5" + ], + "layout": "IPY_MODEL_f255707788704a76bd1651f26a22402d" + } + }, + "4fc59928a0544f95a4438b37d19ca437": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_611d6472a58d419583acc416767a4c90", + "placeholder": "​", + "style": "IPY_MODEL_98c5ce434cff454eaaa3f0fd3498183a", + "value": "validation-00000-of-00001.parquet: 100%" + } + }, + "4fed5720f30b4b3cbbc606a4f25e223b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6fa866b9971542739b0ed26d90ceac80", + "IPY_MODEL_fe7553b513954cc68c427b5d9d260b33", + "IPY_MODEL_4bc266d49a6741a88350e029d101425b" + ], + "layout": "IPY_MODEL_da57445f98e7427589962836c2b4287e" + } + }, + "4ff3a6aaf706460bbba01b248b93000e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "500a072c09da41759cb2c942a16d8429": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e6d6e516cd03452297d80c36376855dd", + "max": 29453850, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_6ae0fadb3aeb4be18a9ab3279fb23145", + "value": 29453850 + } + }, + "52150fd494d24eea89b5232077509355": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b09b2690894749339a9172e5ad0a9b75", + "placeholder": "​", + "style": "IPY_MODEL_cbed38801163438d891879b756f5baab", + "value": "test-00001-of-00003.parquet: 100%" + } + }, + "53a321f36b0d4e08a74a5bcfbd04434b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + 
"5b0b5a3f79e94c51aae48fe0dd34ba0e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "611d6472a58d419583acc416767a4c90": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "626ef2f811ae4e119a0e85cebe92b91d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "631a85e420b64e8cb6915af59c5ce08a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": 
null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6ae0fadb3aeb4be18a9ab3279fb23145": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "6c0a6a7fa8ca4e1c961a36305f0e7638": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6fa866b9971542739b0ed26d90ceac80": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ad1fb86cc1f94fd9911eda03cf4a3783", + "placeholder": "​", + "style": "IPY_MODEL_fdefb51ad4c4418b98c5826126558011", + "value": "test-00000-of-00003.parquet: 100%" + } + }, + "70af9cb2838c4a92bd67f8cb5c98d97f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"737116977f474ec0b68d88a40fd1086c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "74b58e4647644c9daf9af488942fdaf4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_25529e7fd57049d2816d31f696eab1fd", + "placeholder": "​", + "style": "IPY_MODEL_093bdcb608cf4b4fa37b0032a3915187", + "value": " 36.0k/36.0k [00:00<00:00, 1.29MB/s]" + } + }, + "75f06408071c494f934bb909b84110d1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "78632694ff694442bc3fefc2cac2cbf5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0218397c573e4b28bfb4ffa66464d50f", + "placeholder": "​", + "style": "IPY_MODEL_9b01bcd6e5174be2af19f457047017c8", + "value": " 165M/165M [00:03<00:00, 42.9MB/s]" + } + }, + "78a2d2d4ee3f42f3be42ef4baa298561": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, 
+ "layout": "IPY_MODEL_cab80632b7564a9eb59583e09573c1ee", + "placeholder": "​", + "style": "IPY_MODEL_10c0d50d7c204de0b4c8e8f4d3ec0af5", + "value": "README.md: 100%" + } + }, + "78d0e2aa93674bbeb42bff87a23cce9b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7b98103300814f3caea84266263b95a2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b8c0c8aaac0d4032bf5c673a43d084ab", + "placeholder": "​", + "style": "IPY_MODEL_d1f32499fa3f4795b92361637e23a9bb", + "value": " 454M/454M [00:11<00:00, 40.4MB/s]" + } + }, + "7c4d1de626784a59a7e0a33c24086186": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "82991dcc80f14af9bd2e95f705980676": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e17d286a965a49cfb8d5bf885865cb1e", + "placeholder": "​", + "style": "IPY_MODEL_ca015c1a0c1449e68edb282462435a3f", + "value": "test-00002-of-00003.parquet: 100%" + } + }, + "84570fe2c2a54a068fb9b8cbc8b041a1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8e9358ec7d474808bb96c13e13489c67": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3b70fa4e43ef4951862e119378c3c501", + "placeholder": "​", + "style": "IPY_MODEL_6c0a6a7fa8ca4e1c961a36305f0e7638", + "value": "Generating dev split: 100%" + } + }, + "93ee645d54f34acdb0d15092d4a6f0d1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4ff3a6aaf706460bbba01b248b93000e", + "placeholder": "​", + "style": "IPY_MODEL_bfd75a39f0154c30adbaad1e2ca0f1e2", + "value": " 471M/471M [00:11<00:00, 41.5MB/s]" + } + }, + "9437b707bf1a4847a50aafeb4252dab5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + 
"layout": "IPY_MODEL_d2029292327b488db02fd123ee2b75af", + "placeholder": "​", + "style": "IPY_MODEL_3e26bc24a3e44b4582f57913bdf98de4", + "value": " 5/5 [00:00<00:00,  8.03 examples/s]" + } + }, + "963cf422ca894d82b0dd94c6165d41bf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f5b34a743ce54fb591f25b04a2651d65", + "placeholder": "​", + "style": "IPY_MODEL_dec6399e2c5341aead66e1674d3e6c72", + "value": " 30/30 [00:03<00:00,  8.23 examples/s]" + } + }, + "9659140487ca4d3ea799196d2c1ecf61": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_52150fd494d24eea89b5232077509355", + "IPY_MODEL_04acde771d0a46699e1de07d9733d1a3", + "IPY_MODEL_7b98103300814f3caea84266263b95a2" + ], + "layout": "IPY_MODEL_75f06408071c494f934bb909b84110d1" + } + }, + "9785009392934e3bbb229e8781667cbc": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fa4800a506ac480984d58933580df086", + "placeholder": "​", + "style": "IPY_MODEL_117468099dbc42fdaafc08207eaac7ab", + "value": " 29.5M/29.5M [00:00<00:00, 36.5MB/s]" + } + }, + "98c5ce434cff454eaaa3f0fd3498183a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9b01bcd6e5174be2af19f457047017c8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9d2b6eabf7e14436b72bbf374b4a2a0a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + 
"_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b5d7cb5a6157449a850ef0e12e3d3eb7", + "IPY_MODEL_c245d316bf9e44dabe5bfd1e47fc8d2e", + "IPY_MODEL_963cf422ca894d82b0dd94c6165d41bf" + ], + "layout": "IPY_MODEL_78d0e2aa93674bbeb42bff87a23cce9b" + } + }, + "ad1fb86cc1f94fd9911eda03cf4a3783": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "aef4172d916f40b0ab4ed09104e10f24": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b09b2690894749339a9172e5ad0a9b75": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b5d7cb5a6157449a850ef0e12e3d3eb7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + 
"_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_12c6f1180eeb4e9eb9037ea5dd24ec8e", + "placeholder": "​", + "style": "IPY_MODEL_017a81d7160240a398947545963856f5", + "value": "Generating validation split: 100%" + } + }, + "b77fe05bbcf84cdc8ef85b264ccd35f6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b8c0c8aaac0d4032bf5c673a43d084ab": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b90d660ca8584ba1815a3c66b420c079": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "ba5e6ca09f174ef3a348453cf5cfc24a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + 
"_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_626ef2f811ae4e119a0e85cebe92b91d", + "max": 36030, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_aef4172d916f40b0ab4ed09104e10f24", + "value": 36030 + } + }, + "bfd75a39f0154c30adbaad1e2ca0f1e2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c06f9a090fb54c74b947634bf6d11fa8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_82991dcc80f14af9bd2e95f705980676", + "IPY_MODEL_cd832e3842b945aabbb327856053f261", + "IPY_MODEL_93ee645d54f34acdb0d15092d4a6f0d1" + ], + "layout": "IPY_MODEL_b77fe05bbcf84cdc8ef85b264ccd35f6" + } + }, + "c245d316bf9e44dabe5bfd1e47fc8d2e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1cf8eeb8d81c4e8a8e95dd43296a78b9", + "max": 30, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_5b0b5a3f79e94c51aae48fe0dd34ba0e", + "value": 30 + } + }, + "c452ccbf47a44073aee710175f707a7d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c788d4e9e1e24dca9b6503689df9b631": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_d1587e2144bf46299c1bdec3ea96e4e7", + "IPY_MODEL_500a072c09da41759cb2c942a16d8429", + "IPY_MODEL_9785009392934e3bbb229e8781667cbc" + ], + "layout": 
"IPY_MODEL_84570fe2c2a54a068fb9b8cbc8b041a1" + } + }, + "ca015c1a0c1449e68edb282462435a3f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "cab80632b7564a9eb59583e09573c1ee": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cbed38801163438d891879b756f5baab": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "cd832e3842b945aabbb327856053f261": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2932b06afde9468a976eb6bfb072b80e", + "max": 470745176, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d027c807ddc04f89bec41dc05fde7718", + "value": 470745176 + } + }, + "ce5019b36cde44c58c5f596dbb59a2f8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": 
null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d027c807ddc04f89bec41dc05fde7718": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "d1587e2144bf46299c1bdec3ea96e4e7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f9e579c58e3f4ae0bbb721dffa33bf0a", + "placeholder": "​", + "style": "IPY_MODEL_737116977f474ec0b68d88a40fd1086c", + "value": "dev-00000-of-00001.parquet: 100%" + } + }, + "d1f32499fa3f4795b92361637e23a9bb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d2029292327b488db02fd123ee2b75af": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"d56e218958a041e286e80f24e400ab0b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "da57445f98e7427589962836c2b4287e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "dec6399e2c5341aead66e1674d3e6c72": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e17d286a965a49cfb8d5bf885865cb1e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + 
"grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e51d501e2f994baba40345ad632eabee": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7c4d1de626784a59a7e0a33c24086186", + "placeholder": "​", + "style": "IPY_MODEL_21cf0e35ecd845a8b5e7c5ce241cf177", + "value": " 287/287 [00:23<00:00, 12.48 examples/s]" + } + }, + "e6d6e516cd03452297d80c36376855dd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f0dfeee2a8d64dedbc8ef55ad4e69932": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_201bd914f9884e46b8e6df9d9900a6e8", + "max": 5, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f53b7ada01084e73bba6e14a95e2a534", + "value": 5 + } + }, + "f255707788704a76bd1651f26a22402d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + 
"_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f53b7ada01084e73bba6e14a95e2a534": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "f5b34a743ce54fb591f25b04a2651d65": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f9e579c58e3f4ae0bbb721dffa33bf0a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + 
"height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fa4800a506ac480984d58933580df086": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fb644d47049f495397d0e60597c86ea3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3d0344a9cc744e369da1b6b7ea1b3be8", + "max": 165333397, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c452ccbf47a44073aee710175f707a7d", + "value": 165333397 + } + }, + "fdefb51ad4c4418b98c5826126558011": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fe7553b513954cc68c427b5d9d260b33": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_179d41b80dc841e8a440482516b8bca5", + "max": 461411018, + "min": 0, + "orientation": "horizontal", + "style": 
"IPY_MODEL_22b1ecd2eff14770bcfb0c62d3d4213f", + "value": 461411018 + } + }, + "feb82e061ee44283b4a46be858ef4cd7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_78a2d2d4ee3f42f3be42ef4baa298561", + "IPY_MODEL_ba5e6ca09f174ef3a348453cf5cfc24a", + "IPY_MODEL_74b58e4647644c9daf9af488942fdaf4" + ], + "layout": "IPY_MODEL_d56e218958a041e286e80f24e400ab0b" + } + } + } } }, "nbformat": 4, diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index 5b7a685c1..cc594d8d7 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -759,7 +759,7 @@ class Generator: ) return Operation( - tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)], + tags=[op.defining_class.__name__], summary=None, # summary=doc_string.short_description, description=description, @@ -805,8 +805,6 @@ class Generator: operation_tags: List[Tag] = [] for cls in endpoint_classes: doc_string = parse_type(cls) - if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__: - continue operation_tags.append( Tag( name=cls.__name__, diff --git a/docs/readme.md b/docs/readme.md index c238c4720..b88a4738d 100644 --- a/docs/readme.md +++ b/docs/readme.md @@ -3,10 +3,10 @@ Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html). ## Render locally - -From the llama-stack root directory, run the following command to render the docs locally: ```bash -uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all +pip install -r requirements.txt +cd docs +python -m sphinx_autobuild source _build ``` You can open up the docs in your browser at http://localhost:8000 diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000..6cd45c33b --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,16 @@ +linkify +myst-parser +-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme +sphinx==8.1.3 +sphinx-copybutton +sphinx-design +sphinx-pdj-theme +sphinx-rtd-theme>=1.0.0 +sphinx-tabs +sphinx_autobuild +sphinx_rtd_dark_mode +sphinxcontrib-mermaid +sphinxcontrib-openapi +sphinxcontrib-redoc +sphinxcontrib-video +tomli diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index 289c38991..dbe90a7fc 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -57,31 +57,6 @@ chunks = [ ] client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks) ``` - -#### Using Precomputed Embeddings -If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by -including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you -want to customize the ingestion process. 
-```python -chunks_with_embeddings = [ - { - "content": "First chunk of text", - "mime_type": "text/plain", - "embedding": [0.1, 0.2, 0.3, ...], # Your precomputed embedding vector - "metadata": {"document_id": "doc1", "section": "introduction"}, - }, - { - "content": "Second chunk of text", - "mime_type": "text/plain", - "embedding": [0.2, 0.3, 0.4, ...], # Your precomputed embedding vector - "metadata": {"document_id": "doc1", "section": "methodology"}, - }, -] -client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings) -``` -When providing precomputed embeddings, ensure the embedding dimension matches the embedding_dimension specified when -registering the vector database. - ### Retrieval You can query the vector database to retrieve documents based on their embeddings. ```python diff --git a/docs/source/conf.py b/docs/source/conf.py index 6e59dbdfb..501a923dd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,11 +22,7 @@ from docutils import nodes # Read version from pyproject.toml with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f: pypi_url = "https://pypi.org/pypi/llama-stack/json" - headers = { - 'User-Agent': 'pip/23.0.1 (python 3.11)', # Mimic pip's user agent - 'Accept': 'application/json' - } - version_tag = json.loads(requests.get(pypi_url, headers=headers).text)["info"]["version"] + version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"] print(f"{version_tag=}") # generate the full link including text and url here @@ -57,6 +53,14 @@ myst_enable_extensions = ["colon_fence"] html_theme = "sphinx_rtd_theme" html_use_relative_paths = True + +# html_theme = "sphinx_pdj_theme" +# html_theme_path = [sphinx_pdj_theme.get_html_theme_path()] + +# html_theme = "pytorch_sphinx_theme" +# html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] + + templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index 0dbabf8aa..d9b73c910 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -338,48 +338,6 @@ INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit) INFO: 2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK ``` -### Listing Distributions -Using the list command, you can view all existing Llama Stack distributions, including stacks built from templates, from scratch, or using custom configuration files. - -``` -llama stack list -h -usage: llama stack list [-h] - -list the build stacks - -options: - -h, --help show this help message and exit -``` - -Example Usage - -``` -llama stack list -``` - -### Removing a Distribution -Use the remove command to delete a distribution you've previously built. - -``` -llama stack rm -h -usage: llama stack rm [-h] [--all] [name] - -Remove the build stack - -positional arguments: - name Name of the stack to delete (default: None) - -options: - -h, --help show this help message and exit - --all, -a Delete all stacks (use with caution) (default: False) -``` - -Example -``` -llama stack rm llamastack-test -``` - -To keep your environment organized and avoid clutter, consider using `llama stack list` to review old or unused distributions and `llama stack rm ` to delete them when they’re no longer needed. 
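Further up in this hunk, the `### Retrieval` heading in `docs/source/building_applications/rag.md` opens a Python snippet that the hunk boundary cuts off. As a rough sketch of the kind of call that heading introduces (assuming the `llama-stack-client` package; the base URL, vector database id, and query text below are placeholders, not values taken from the diff):

```python
# Sketch only: assumes a running Llama Stack server and a vector DB that was
# registered and populated earlier in the guide; identifiers are placeholders.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
vector_db_id = "my_documents"

# The server embeds the query text and returns the closest chunks with scores.
results = client.vector_io.query(
    vector_db_id=vector_db_id,
    query="What does the document say about fine-tuning?",
)
for chunk, score in zip(results.chunks, results.scores):
    print(score, chunk.metadata, chunk.content)
```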
### Troubleshooting diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index de99b6576..b62227a84 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -118,6 +118,11 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS + auth: # Optional: Authentication configuration + provider_type: "kubernetes" # Type of auth provider + config: # Provider-specific configuration + api_server_url: "https://kubernetes.default.svc" + ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate ``` ### Authentication Configuration @@ -130,7 +135,7 @@ Authorization: Bearer The server supports multiple authentication providers: -#### OAuth 2.0/OpenID Connect Provider with Kubernetes +#### Kubernetes Provider The Kubernetes cluster must be configured to use a service account for authentication. @@ -141,67 +146,14 @@ kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --se kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token ``` -Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests -and that the correct RoleBinding is created to allow the service account to access the necessary -resources. If that is not the case, you can create a RoleBinding for the service account to access -the necessary resources: - -```yaml -# allow-anonymous-openid.yaml -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: allow-anonymous-openid -rules: -- nonResourceURLs: ["/openid/v1/jwks"] - verbs: ["get"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: allow-anonymous-openid -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: ClusterRole - name: allow-anonymous-openid -subjects: -- kind: User - name: system:anonymous - apiGroup: rbac.authorization.k8s.io -``` - -And then apply the configuration: -```bash -kubectl apply -f allow-anonymous-openid.yaml -``` - -Validates tokens against the Kubernetes API server through the OIDC provider: +Validates tokens against the Kubernetes API server: ```yaml server: auth: - provider_type: "oauth2_token" + provider_type: "kubernetes" config: - jwks: - uri: "https://kubernetes.default.svc" - key_recheck_period: 3600 - tls_cafile: "/path/to/ca.crt" - issuer: "https://kubernetes.default.svc" - audience: "https://kubernetes.default.svc" -``` - -To find your cluster's audience, run: -```bash -kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud -``` - -For the issuer, you can use the OIDC provider's URL: -```bash -kubectl get --raw /.well-known/openid-configuration| jq .issuer -``` - -For the tls_cafile, you can use the CA certificate of the OIDC provider: -```bash -kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}' + api_server_url: "https://kubernetes.default.svc" # URL of the Kubernetes API server + ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate ``` The provider extracts user information from the JWT token: @@ -256,80 +208,6 @@ And must respond with: If no access attributes are returned, the token is used as a namespace. -### Quota Configuration - -The `quota` section allows you to enable server-side request throttling for both -authenticated and anonymous clients. 
This is useful for preventing abuse, enforcing -fairness across tenants, and controlling infrastructure costs without requiring -client-side rate limiting or external proxies. - -Quotas are disabled by default. When enabled, each client is tracked using either: - -* Their authenticated `client_id` (derived from the Bearer token), or -* Their IP address (fallback for anonymous requests) - -Quota state is stored in a SQLite-backed key-value store, and rate limits are applied -within a configurable time window (currently only `day` is supported). - -#### Example - -```yaml -server: - quota: - kvstore: - type: sqlite - db_path: ./quotas.db - anonymous_max_requests: 100 - authenticated_max_requests: 1000 - period: day -``` - -#### Configuration Options - -| Field | Description | -| ---------------------------- | -------------------------------------------------------------------------- | -| `kvstore` | Required. Backend storage config for tracking request counts. | -| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. | -| `kvstore.db_path` | File path to the SQLite database. | -| `anonymous_max_requests` | Max requests per period for unauthenticated clients. | -| `authenticated_max_requests` | Max requests per period for authenticated clients. | -| `period` | Time window for quota enforcement. Only `"day"` is supported. | - -> Note: if `authenticated_max_requests` is set but no authentication provider is -configured, the server will fall back to applying `anonymous_max_requests` to all -clients. - -#### Example with Authentication Enabled - -```yaml -server: - port: 8321 - auth: - provider_type: custom - config: - endpoint: https://auth.example.com/validate - quota: - kvstore: - type: sqlite - db_path: ./quotas.db - anonymous_max_requests: 100 - authenticated_max_requests: 1000 - period: day -``` - -If a client exceeds their limit, the server responds with: - -```http -HTTP/1.1 429 Too Many Requests -Content-Type: application/json - -{ - "error": { - "message": "Quota exceeded" - } -} -``` - ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. 
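The `## Extending to handle Safety` note above says a worked example is instructive; a minimal client-side sketch follows. It assumes a distribution whose run.yaml enables the `inline::llama-guard` safety provider and uses a placeholder shield id; none of these values come from the diff itself.

```python
# Sketch: register a Llama Guard shield and run it against a user message.
# Assumes llama-stack-client and a server with a safety provider configured.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Placeholder shield id; use the Llama Guard model your distribution serves.
shield_id = "meta-llama/Llama-Guard-3-8B"
client.shields.register(shield_id=shield_id)

result = client.safety.run_shield(
    shield_id=shield_id,
    messages=[{"role": "user", "content": "Tell me how to do something unsafe."}],
    params={},
)
if result.violation:
    print("Blocked:", result.violation.user_message)
else:
    print("Message allowed")
```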
diff --git a/docs/source/distributions/kubernetes_deployment.md b/docs/source/distributions/kubernetes_deployment.md index f43039824..21ec02012 100644 --- a/docs/source/distributions/kubernetes_deployment.md +++ b/docs/source/distributions/kubernetes_deployment.md @@ -172,7 +172,7 @@ spec: - name: llama-stack image: localhost/llama-stack-run-k8s:latest imagePullPolicy: IfNotPresent - command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/app/config.yaml"] + command: ["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"] ports: - containerPort: 5000 volumeMounts: diff --git a/docs/source/distributions/remote_hosted_distro/watsonx.md b/docs/source/distributions/remote_hosted_distro/watsonx.md index ec1b98059..d8d327bb5 100644 --- a/docs/source/distributions/remote_hosted_distro/watsonx.md +++ b/docs/source/distributions/remote_hosted_distro/watsonx.md @@ -70,7 +70,7 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./run.yaml:/root/my-run.yaml \ llamastack/distribution-watsonx \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env WATSONX_API_KEY=$WATSONX_API_KEY \ --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \ diff --git a/docs/source/distributions/self_hosted_distro/cerebras.md b/docs/source/distributions/self_hosted_distro/cerebras.md index 3c4db1b75..329c9b802 100644 --- a/docs/source/distributions/self_hosted_distro/cerebras.md +++ b/docs/source/distributions/self_hosted_distro/cerebras.md @@ -52,7 +52,7 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./run.yaml:/root/my-run.yaml \ llamastack/distribution-cerebras \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY ``` diff --git a/docs/source/distributions/self_hosted_distro/dell.md b/docs/source/distributions/self_hosted_distro/dell.md index eded3bdc4..2e987985c 100644 --- a/docs/source/distributions/self_hosted_distro/dell.md +++ b/docs/source/distributions/self_hosted_distro/dell.md @@ -155,7 +155,7 @@ docker run \ -v $HOME/.llama:/root/.llama \ -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-dell \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env DEH_URL=$DEH_URL \ diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md index e84b5c525..a5bbbfdee 100644 --- a/docs/source/distributions/self_hosted_distro/nvidia.md +++ b/docs/source/distributions/self_hosted_distro/nvidia.md @@ -143,7 +143,7 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./run.yaml:/root/my-run.yaml \ llamastack/distribution-nvidia \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env NVIDIA_API_KEY=$NVIDIA_API_KEY ``` diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md index 4d148feda..5d8935fe2 100644 --- a/docs/source/distributions/self_hosted_distro/ollama.md +++ b/docs/source/distributions/self_hosted_distro/ollama.md @@ -19,7 +19,6 @@ The `llamastack/distribution-ollama` distribution consists of the following prov | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | | inference | `remote::ollama` | -| post_training | `inline::huggingface` | | safety | 
`inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | @@ -98,7 +97,7 @@ docker run \ -v ~/.llama:/root/.llama \ -v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-ollama \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env SAFETY_MODEL=$SAFETY_MODEL \ diff --git a/docs/source/distributions/self_hosted_distro/remote-vllm.md b/docs/source/distributions/self_hosted_distro/remote-vllm.md index 6e7cf410d..2ff4bad5b 100644 --- a/docs/source/distributions/self_hosted_distro/remote-vllm.md +++ b/docs/source/distributions/self_hosted_distro/remote-vllm.md @@ -233,7 +233,7 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \ llamastack/distribution-remote-vllm \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 @@ -255,7 +255,7 @@ docker run \ -v ~/.llama:/root/.llama \ -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-remote-vllm \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \ diff --git a/docs/source/distributions/self_hosted_distro/sambanova.md b/docs/source/distributions/self_hosted_distro/sambanova.md index bb4842362..aaa8fd3cc 100644 --- a/docs/source/distributions/self_hosted_distro/sambanova.md +++ b/docs/source/distributions/self_hosted_distro/sambanova.md @@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p |-----|-------------| | agents | `inline::meta-reference` | | inference | `remote::sambanova`, `inline::sentence-transformers` | -| safety | `remote::sambanova` | +| safety | `inline::llama-guard` | | telemetry | `inline::meta-reference` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | @@ -48,44 +48,33 @@ The following models are available by default: ### Prerequisite: API Keys -Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup). +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/). ## Running Llama Stack with SambaNova You can do this via Conda (build code) or Docker which has a pre-built image. - ### Via Docker +This method allows you to get started quickly without having to build the distribution code. 
+ ```bash LLAMA_STACK_PORT=8321 -llama stack build --template sambanova --image-type container docker run \ -it \ + --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ~/.llama:/root/.llama \ - distribution-sambanova \ + llamastack/distribution-sambanova \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` - -### Via Venv - -```bash -llama stack build --template sambanova --image-type venv -llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \ - --port $LLAMA_STACK_PORT \ - --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY -``` - - ### Via Conda ```bash llama stack build --template sambanova --image-type conda -llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \ +llama stack run ./run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` diff --git a/docs/source/distributions/self_hosted_distro/tgi.md b/docs/source/distributions/self_hosted_distro/tgi.md index 24f9d03ec..7a75aa559 100644 --- a/docs/source/distributions/self_hosted_distro/tgi.md +++ b/docs/source/distributions/self_hosted_distro/tgi.md @@ -117,7 +117,7 @@ docker run \ -v ~/.llama:/root/.llama \ -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-tgi \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \ diff --git a/docs/source/providers/index.md b/docs/source/providers/index.md index 1f5026479..1d1a6e081 100644 --- a/docs/source/providers/index.md +++ b/docs/source/providers/index.md @@ -30,18 +30,6 @@ Runs inference with an LLM. ## Post Training Fine-tunes a model. -#### Post Training Providers -The following providers are available for Post Training: - -```{toctree} -:maxdepth: 1 - -external -post_training/huggingface -post_training/torchtune -post_training/nvidia_nemo -``` - ## Safety Applies safety policies to the output at a Systems (not only model) level. diff --git a/docs/source/providers/post_training/huggingface.md b/docs/source/providers/post_training/huggingface.md deleted file mode 100644 index c342203a8..000000000 --- a/docs/source/providers/post_training/huggingface.md +++ /dev/null @@ -1,122 +0,0 @@ ---- -orphan: true ---- -# HuggingFace SFTTrainer - -[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine tuning on a variety of models using many datasets - -## Features - -- Simple access through the post_training API -- Fully integrated with Llama Stack -- GPU support, CPU support, and MPS support (MacOS Metal Performance Shaders) - -## Usage - -To use the HF SFTTrainer in your Llama Stack project, follow these steps: - -1. Configure your Llama Stack project to use this provider. -2. Kick off a SFT job using the Llama Stack post_training API. 
- -## Setup - -You can access the HuggingFace trainer via the `ollama` distribution: - -```bash -llama stack build --template ollama --image-type venv -llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml -``` - -## Run Training - -You can access the provider and the `supervised_fine_tune` method via the post_training API: - -```python -import time -import uuid - - -from llama_stack_client.types import ( - post_training_supervised_fine_tune_params, - algorithm_config_param, -) - - -def create_http_client(): - from llama_stack_client import LlamaStackClient - - return LlamaStackClient(base_url="http://localhost:8321") - - -client = create_http_client() - -# Example Dataset -client.datasets.register( - purpose="post-training/messages", - source={ - "type": "uri", - "uri": "huggingface://datasets/llamastack/simpleqa?split=train", - }, - dataset_id="simpleqa", -) - -training_config = post_training_supervised_fine_tune_params.TrainingConfig( - data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig( - batch_size=32, - data_format="instruct", - dataset_id="simpleqa", - shuffle=True, - ), - gradient_accumulation_steps=1, - max_steps_per_epoch=0, - max_validation_steps=1, - n_epochs=4, -) - -algorithm_config = algorithm_config_param.LoraFinetuningConfig( # this config is also currently mandatory but should not be - alpha=1, - apply_lora_to_mlp=True, - apply_lora_to_output=False, - lora_attn_modules=["q_proj"], - rank=1, - type="LoRA", -) - -job_uuid = f"test-job{uuid.uuid4()}" - -# Example Model -training_model = "ibm-granite/granite-3.3-8b-instruct" - -start_time = time.time() -response = client.post_training.supervised_fine_tune( - job_uuid=job_uuid, - logger_config={}, - model=training_model, - hyperparam_search_config={}, - training_config=training_config, - algorithm_config=algorithm_config, - checkpoint_dir="output", -) -print("Job: ", job_uuid) - - -# Wait for the job to complete! -while True: - status = client.post_training.job.status(job_uuid=job_uuid) - if not status: - print("Job not found") - break - - print(status) - if status.status == "completed": - break - - print("Waiting for job to complete...") - time.sleep(5) - -end_time = time.time() -print("Job completed in", end_time - start_time, "seconds!") - -print("Artifacts:") -print(client.post_training.job.artifacts(job_uuid=job_uuid)) -``` diff --git a/docs/source/providers/post_training/nvidia_nemo.md b/docs/source/providers/post_training/nvidia_nemo.md deleted file mode 100644 index 1a7adbe16..000000000 --- a/docs/source/providers/post_training/nvidia_nemo.md +++ /dev/null @@ -1,163 +0,0 @@ ---- -orphan: true ---- -# NVIDIA NEMO - -[NVIDIA NEMO](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service. - -## Features - -- Enterprise-grade fine-tuning capabilities -- Support for LoRA and SFT fine-tuning -- Integration with NVIDIA's NeMo Customizer service -- Support for various NVIDIA-optimized models -- Efficient training with NVIDIA hardware acceleration - -## Usage - -To use NVIDIA NEMO in your Llama Stack project, follow these steps: - -1. Configure your Llama Stack project to use this provider. -2. Set up your NVIDIA API credentials. -3. Kick off a fine-tuning job using the Llama Stack post_training API. 
- -## Setup - -You'll need to set the following environment variables: - -```bash -export NVIDIA_API_KEY="your-api-key" -export NVIDIA_DATASET_NAMESPACE="default" -export NVIDIA_CUSTOMIZER_URL="your-customizer-url" -export NVIDIA_PROJECT_ID="your-project-id" -export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir" -``` - -## Run Training - -You can access the provider and the `supervised_fine_tune` method via the post_training API: - -```python -import time -import uuid - -from llama_stack_client.types import ( - post_training_supervised_fine_tune_params, - algorithm_config_param, -) - - -def create_http_client(): - from llama_stack_client import LlamaStackClient - - return LlamaStackClient(base_url="http://localhost:8321") - - -client = create_http_client() - -# Example Dataset -client.datasets.register( - purpose="post-training/messages", - source={ - "type": "uri", - "uri": "huggingface://datasets/llamastack/simpleqa?split=train", - }, - dataset_id="simpleqa", -) - -training_config = post_training_supervised_fine_tune_params.TrainingConfig( - data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig( - batch_size=8, # Default batch size for NEMO - data_format="instruct", - dataset_id="simpleqa", - shuffle=True, - ), - n_epochs=50, # Default epochs for NEMO - optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig( - lr=0.0001, # Default learning rate - weight_decay=0.01, # NEMO-specific parameter - ), - # NEMO-specific parameters - log_every_n_steps=None, - val_check_interval=0.25, - sequence_packing_enabled=False, - hidden_dropout=None, - attention_dropout=None, - ffn_dropout=None, -) - -algorithm_config = algorithm_config_param.LoraFinetuningConfig( - alpha=16, # Default alpha for NEMO - type="LoRA", -) - -job_uuid = f"test-job{uuid.uuid4()}" - -# Example Model - must be a supported NEMO model -training_model = "meta/llama-3.1-8b-instruct" - -start_time = time.time() -response = client.post_training.supervised_fine_tune( - job_uuid=job_uuid, - logger_config={}, - model=training_model, - hyperparam_search_config={}, - training_config=training_config, - algorithm_config=algorithm_config, - checkpoint_dir="output", -) -print("Job: ", job_uuid) - -# Wait for the job to complete! -while True: - status = client.post_training.job.status(job_uuid=job_uuid) - if not status: - print("Job not found") - break - - print(status) - if status.status == "completed": - break - - print("Waiting for job to complete...") - time.sleep(5) - -end_time = time.time() -print("Job completed in", end_time - start_time, "seconds!") - -print("Artifacts:") -print(client.post_training.job.artifacts(job_uuid=job_uuid)) -``` - -## Supported Models - -Currently supports the following models: -- meta/llama-3.1-8b-instruct -- meta/llama-3.2-1b-instruct - -## Supported Parameters - -### TrainingConfig -- n_epochs (default: 50) -- data_config -- optimizer_config -- log_every_n_steps -- val_check_interval (default: 0.25) -- sequence_packing_enabled (default: False) -- hidden_dropout (0.0-1.0) -- attention_dropout (0.0-1.0) -- ffn_dropout (0.0-1.0) - -### DataConfig -- dataset_id -- batch_size (default: 8) - -### OptimizerConfig -- lr (default: 0.0001) -- weight_decay (default: 0.01) - -### LoRA Config -- alpha (default: 16) -- type (must be "LoRA") - -Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning. 
diff --git a/docs/source/providers/post_training/torchtune.md b/docs/source/providers/post_training/torchtune.md deleted file mode 100644 index ef72505b1..000000000 --- a/docs/source/providers/post_training/torchtune.md +++ /dev/null @@ -1,125 +0,0 @@ ---- -orphan: true ---- -# TorchTune - -[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch. - -## Features - -- Simple access through the post_training API -- Fully integrated with Llama Stack -- GPU support and single device capabilities. -- Support for LoRA - -## Usage - -To use TorchTune in your Llama Stack project, follow these steps: - -1. Configure your Llama Stack project to use this provider. -2. Kick off a fine-tuning job using the Llama Stack post_training API. - -## Setup - -You can access the TorchTune trainer by writing your own yaml pointing to the provider: - -```yaml -post_training: - - provider_id: torchtune - provider_type: inline::torchtune - config: {} -``` - -you can then build and run your own stack with this provider. - -## Run Training - -You can access the provider and the `supervised_fine_tune` method via the post_training API: - -```python -import time -import uuid - -from llama_stack_client.types import ( - post_training_supervised_fine_tune_params, - algorithm_config_param, -) - - -def create_http_client(): - from llama_stack_client import LlamaStackClient - - return LlamaStackClient(base_url="http://localhost:8321") - - -client = create_http_client() - -# Example Dataset -client.datasets.register( - purpose="post-training/messages", - source={ - "type": "uri", - "uri": "huggingface://datasets/llamastack/simpleqa?split=train", - }, - dataset_id="simpleqa", -) - -training_config = post_training_supervised_fine_tune_params.TrainingConfig( - data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig( - batch_size=32, - data_format="instruct", - dataset_id="simpleqa", - shuffle=True, - ), - gradient_accumulation_steps=1, - max_steps_per_epoch=0, - max_validation_steps=1, - n_epochs=4, -) - -algorithm_config = algorithm_config_param.LoraFinetuningConfig( - alpha=1, - apply_lora_to_mlp=True, - apply_lora_to_output=False, - lora_attn_modules=["q_proj"], - rank=1, - type="LoRA", -) - -job_uuid = f"test-job{uuid.uuid4()}" - -# Example Model -training_model = "meta-llama/Llama-2-7b-hf" - -start_time = time.time() -response = client.post_training.supervised_fine_tune( - job_uuid=job_uuid, - logger_config={}, - model=training_model, - hyperparam_search_config={}, - training_config=training_config, - algorithm_config=algorithm_config, - checkpoint_dir="output", -) -print("Job: ", job_uuid) - -# Wait for the job to complete! -while True: - status = client.post_training.job.status(job_uuid=job_uuid) - if not status: - print("Job not found") - break - - print(status) - if status.status == "completed": - break - - print("Waiting for job to complete...") - time.sleep(5) - -end_time = time.time() -print("Job completed in", end_time - start_time, "seconds!") - -print("Artifacts:") -print(client.post_training.job.artifacts(job_uuid=job_uuid)) -``` diff --git a/docs/source/providers/vector_io/sqlite-vec.md b/docs/source/providers/vector_io/sqlite-vec.md index 49ba659f7..43d10c751 100644 --- a/docs/source/providers/vector_io/sqlite-vec.md +++ b/docs/source/providers/vector_io/sqlite-vec.md @@ -66,25 +66,6 @@ To use sqlite-vec in your Llama Stack project, follow these steps: 2. 
Configure your Llama Stack project to use SQLite-Vec. 3. Start storing and querying vectors. -## Supported Search Modes - -The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes. - -When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in -`RAGQueryConfig`. For example: - -```python -from llama_stack.apis.tool_runtime.rag import RAGQueryConfig - -query_config = RAGQueryConfig(max_chunks=6, mode="vector") - -results = client.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], - content="what is torchtune", - query_config=query_config, -) -``` - ## Installation You can install SQLite-Vec using pip: diff --git a/kvant_build_local.sh b/kvant_build_local.sh deleted file mode 100755 index 9701c57dc..000000000 --- a/kvant_build_local.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env bash - -export USE_COPY_NOT_MOUNT=true -export LLAMA_STACK_DIR=. - -uvx --from . llama stack build --template kvant --image-type container --image-name kvant diff --git a/kvant_start_local.sh b/kvant_start_local.sh deleted file mode 100755 index db5bff84a..000000000 --- a/kvant_start_local.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash - -export LLAMA_STACK_PORT=8321 -# VLLM_API_TOKEN= env file -# KEYCLOAK_CLIENT_SECRET= env file - - -docker run -it \ - -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v $(pwd)/data:/root/.llama \ - --mount type=bind,source="$(pwd)"/llama_stack/templates/kvant/run.yaml,target=/root/.llama/config.yaml,readonly \ - --entrypoint python \ - --env-file ./.env \ - distribution-kvant:dev \ - -m llama_stack.distribution.server.server --config /root/.llama/config.yaml \ - --port $LLAMA_STACK_PORT \ - diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index b79c512b8..b2f85336c 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent -from llama_stack.apis.common.responses import Order, PaginatedResponse +from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.inference import ( CompletionMessage, ResponseFormat, @@ -31,8 +31,6 @@ from llama_stack.apis.tools import ToolDef from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from .openai_responses import ( - ListOpenAIResponseInputItem, - ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, @@ -581,14 +579,14 @@ class Agents(Protocol): # # Both of these APIs are inherently stateful. - @webmethod(route="/openai/v1/responses/{response_id}", method="GET") + @webmethod(route="/openai/v1/responses/{id}", method="GET") async def get_openai_response( self, - response_id: str, + id: str, ) -> OpenAIResponseObject: """Retrieve an OpenAI response by its ID. - :param response_id: The ID of the OpenAI response to retrieve. + :param id: The ID of the OpenAI response to retrieve. :returns: An OpenAIResponseObject. """ ... @@ -598,7 +596,6 @@ class Agents(Protocol): self, input: str | list[OpenAIResponseInput], model: str, - instructions: str | None = None, previous_response_id: str | None = None, store: bool | None = True, stream: bool | None = False, @@ -613,43 +610,3 @@ class Agents(Protocol): :returns: An OpenAIResponseObject. """ ... 
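# Illustration only (not part of this patch): the routes above are OpenAI-compatible,
# so one way to exercise POST /openai/v1/responses and GET /openai/v1/responses/{id}
# is the OpenAI Python SDK pointed at a Llama Stack server. The base URL, api_key,
# model id, and prompt are assumptions, not values taken from this diff.
from openai import OpenAI

openai_client = OpenAI(base_url="http://localhost:8321/openai/v1", api_key="placeholder")
created = openai_client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model id
    input="Say hello in one sentence.",
)
fetched = openai_client.responses.retrieve(created.id)  # GET /openai/v1/responses/{id}
print(fetched.output_text)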
- - @webmethod(route="/openai/v1/responses", method="GET") - async def list_openai_responses( - self, - after: str | None = None, - limit: int | None = 50, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseObject: - """List all OpenAI responses. - - :param after: The ID of the last response to return. - :param limit: The number of responses to return. - :param model: The model to filter responses by. - :param order: The order to sort responses by when sorted by created_at ('asc' or 'desc'). - :returns: A ListOpenAIResponseObject. - """ - ... - - @webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET") - async def list_openai_response_input_items( - self, - response_id: str, - after: str | None = None, - before: str | None = None, - include: list[str] | None = None, - limit: int | None = 20, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseInputItem: - """List input items for a given OpenAI response. - - :param response_id: The ID of the response to retrieve input items for. - :param after: An item ID to list items after, used for pagination. - :param before: An item ID to list items before, used for pagination. - :param include: Additional fields to include in the response. - :param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20. - :param order: The order to return the input items in. Default is desc. - :returns: An ListOpenAIResponseInputItem. - """ - ... diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 6806e1d3f..dcf0c7f9c 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -10,9 +10,6 @@ from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type, register_schema -# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably -# take their YAML and generate this file automatically. Their YAML is available. 
- @json_schema_type class OpenAIResponseError(BaseModel): @@ -82,45 +79,16 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): @json_schema_type class OpenAIResponseOutputMessageFunctionToolCall(BaseModel): + arguments: str call_id: str name: str - arguments: str type: Literal["function_call"] = "function_call" - id: str | None = None - status: str | None = None - - -@json_schema_type -class OpenAIResponseOutputMessageMCPCall(BaseModel): id: str - type: Literal["mcp_call"] = "mcp_call" - arguments: str - name: str - server_label: str - error: str | None = None - output: str | None = None - - -class MCPListToolsTool(BaseModel): - input_schema: dict[str, Any] - name: str - description: str | None = None - - -@json_schema_type -class OpenAIResponseOutputMessageMCPListTools(BaseModel): - id: str - type: Literal["mcp_list_tools"] = "mcp_list_tools" - server_label: str - tools: list[MCPListToolsTool] + status: str OpenAIResponseOutput = Annotated[ - OpenAIResponseMessage - | OpenAIResponseOutputMessageWebSearchToolCall - | OpenAIResponseOutputMessageFunctionToolCall - | OpenAIResponseOutputMessageMCPCall - | OpenAIResponseOutputMessageMCPListTools, + OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall, Field(discriminator="type"), ] register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput") @@ -149,16 +117,6 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel): type: Literal["response.created"] = "response.created" -@json_schema_type -class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel): - content_index: int - delta: str - item_id: str - output_index: int - sequence_number: int - type: Literal["response.output_text.delta"] = "response.output_text.delta" - - @json_schema_type class OpenAIResponseObjectStreamResponseCompleted(BaseModel): response: OpenAIResponseObject @@ -166,9 +124,7 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel): OpenAIResponseObjectStream = Annotated[ - OpenAIResponseObjectStreamResponseCreated - | OpenAIResponseObjectStreamResponseOutputTextDelta - | OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted, Field(discriminator="type"), ] register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream") @@ -230,50 +186,13 @@ class OpenAIResponseInputToolFileSearch(BaseModel): # TODO: add filters -class ApprovalFilter(BaseModel): - always: list[str] | None = None - never: list[str] | None = None - - -class AllowedToolsFilter(BaseModel): - tool_names: list[str] | None = None - - -@json_schema_type -class OpenAIResponseInputToolMCP(BaseModel): - type: Literal["mcp"] = "mcp" - server_label: str - server_url: str - headers: dict[str, Any] | None = None - - require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never" - allowed_tools: list[str] | AllowedToolsFilter | None = None - - OpenAIResponseInputTool = Annotated[ - OpenAIResponseInputToolWebSearch - | OpenAIResponseInputToolFileSearch - | OpenAIResponseInputToolFunction - | OpenAIResponseInputToolMCP, + OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction, Field(discriminator="type"), ] register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool") -class ListOpenAIResponseInputItem(BaseModel): +class OpenAIResponseInputItemList(BaseModel): data: list[OpenAIResponseInput] object: Literal["list"] = "list" - - -@json_schema_type -class 
OpenAIResponseObjectWithInput(OpenAIResponseObject): - input: list[OpenAIResponseInput] - - -@json_schema_type -class ListOpenAIResponseObject(BaseModel): - data: list[OpenAIResponseObjectWithInput] - has_more: bool - first_id: str - last_id: str - object: Literal["list"] = "list" diff --git a/llama_stack/apis/common/deployment_types.py b/llama_stack/apis/common/deployment_types.py new file mode 100644 index 000000000..4d01d7ad1 --- /dev/null +++ b/llama_stack/apis/common/deployment_types.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any + +from pydantic import BaseModel + +from llama_stack.apis.common.content_types import URL +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class RestAPIMethod(Enum): + GET = "GET" + POST = "POST" + PUT = "PUT" + DELETE = "DELETE" + + +@json_schema_type +class RestAPIExecutionConfig(BaseModel): + url: URL + method: RestAPIMethod + params: dict[str, Any] | None = None + headers: dict[str, Any] | None = None + body: dict[str, Any] | None = None diff --git a/llama_stack/apis/common/responses.py b/llama_stack/apis/common/responses.py index 5cb41e23d..b3bb5cb6b 100644 --- a/llama_stack/apis/common/responses.py +++ b/llama_stack/apis/common/responses.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum from typing import Any from pydantic import BaseModel @@ -12,11 +11,6 @@ from pydantic import BaseModel from llama_stack.schema_utils import json_schema_type -class Order(Enum): - asc = "asc" - desc = "desc" - - @json_schema_type class PaginatedResponse(BaseModel): """A generic paginated response that follows a simple format. diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 74697dd18..3c91b5a6e 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -19,7 +19,6 @@ from pydantic import BaseModel, Field, field_validator from typing_extensions import TypedDict from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem -from llama_stack.apis.common.responses import Order from llama_stack.apis.models import Model from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.models.llama.datatypes import ( @@ -783,48 +782,6 @@ class OpenAICompletion(BaseModel): object: Literal["text_completion"] = "text_completion" -@json_schema_type -class OpenAIEmbeddingData(BaseModel): - """A single embedding data object from an OpenAI-compatible embeddings response. - - :param object: The object type, which will be "embedding" - :param embedding: The embedding vector as a list of floats (when encoding_format="float") or as a base64-encoded string (when encoding_format="base64") - :param index: The index of the embedding in the input list - """ - - object: Literal["embedding"] = "embedding" - embedding: list[float] | str - index: int - - -@json_schema_type -class OpenAIEmbeddingUsage(BaseModel): - """Usage information for an OpenAI-compatible embeddings response. 
- - :param prompt_tokens: The number of tokens in the input - :param total_tokens: The total number of tokens used - """ - - prompt_tokens: int - total_tokens: int - - -@json_schema_type -class OpenAIEmbeddingsResponse(BaseModel): - """Response from an OpenAI-compatible embeddings request. - - :param object: The object type, which will be "list" - :param data: List of embedding data objects - :param model: The model that was used to generate the embeddings - :param usage: Usage information - """ - - object: Literal["list"] = "list" - data: list[OpenAIEmbeddingData] - model: str - usage: OpenAIEmbeddingUsage - - class ModelStore(Protocol): async def get_model(self, identifier: str) -> Model: ... @@ -863,27 +820,15 @@ class BatchChatCompletionResponse(BaseModel): batch: list[ChatCompletionResponse] -class OpenAICompletionWithInputMessages(OpenAIChatCompletion): - input_messages: list[OpenAIMessageParam] - - -@json_schema_type -class ListOpenAIChatCompletionResponse(BaseModel): - data: list[OpenAICompletionWithInputMessages] - has_more: bool - first_id: str - last_id: str - object: Literal["list"] = "list" - - @runtime_checkable @trace_protocol -class InferenceProvider(Protocol): - """ - This protocol defines the interface that should be implemented by all inference providers. - """ +class Inference(Protocol): + """Llama Stack Inference API for generating completions, chat completions, and embeddings. - API_NAMESPACE: str = "Inference" + This API provides the raw interface to the underlying models. Two kinds of models are supported: + - LLM models: these models generate "raw" and "chat" (conversational) completions. + - Embedding models: these models generate embeddings to be used for semantic search. + """ model_store: ModelStore | None = None @@ -1117,59 +1062,3 @@ class InferenceProvider(Protocol): :returns: An OpenAIChatCompletion. """ ... - - @webmethod(route="/openai/v1/embeddings", method="POST") - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - """Generate OpenAI-compatible embeddings for the given input using the specified model. - - :param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint. - :param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings. - :param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float". - :param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. - :param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. - :returns: An OpenAIEmbeddingsResponse containing the embeddings. - """ - ... - - -class Inference(InferenceProvider): - """Llama Stack Inference API for generating completions, chat completions, and embeddings. - - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate "raw" and "chat" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search. 
- """ - - @webmethod(route="/openai/v1/chat/completions", method="GET") - async def list_chat_completions( - self, - after: str | None = None, - limit: int | None = 20, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIChatCompletionResponse: - """List all chat completions. - - :param after: The ID of the last chat completion to return. - :param limit: The maximum number of chat completions to return. - :param model: The model to filter by. - :param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc". - :returns: A ListOpenAIChatCompletionResponse. - """ - raise NotImplementedError("List chat completions is not implemented") - - @webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET") - async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: - """Describe a chat completion by its ID. - - :param completion_id: ID of the chat completion. - :returns: A OpenAICompletionWithInputMessages. - """ - raise NotImplementedError("Get chat completion is not implemented") diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 1e3542f74..de3e4c62c 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -76,7 +76,6 @@ class RAGQueryConfig(BaseModel): :param chunk_template: Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n" - :param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector". """ # This config defines how a query is generated using the messages @@ -85,7 +84,6 @@ class RAGQueryConfig(BaseModel): max_tokens_in_context: int = 4096 max_chunks: int = 5 chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" - mode: str | None = None @field_validator("chunk_template") def validate_chunk_template(cls, v: str) -> str: diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 0c8d47edf..2f62b0ba1 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -27,10 +27,18 @@ class ToolParameter(BaseModel): default: Any | None = None +@json_schema_type +class ToolHost(Enum): + distribution = "distribution" + client = "client" + model_context_protocol = "model_context_protocol" + + @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool] = ResourceType.tool toolgroup_id: str + tool_host: ToolHost description: str parameters: list[ToolParameter] metadata: dict[str, Any] | None = None @@ -68,8 +76,8 @@ class ToolInvocationResult(BaseModel): class ToolStore(Protocol): - async def get_tool(self, tool_name: str) -> Tool: ... - async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ... + def get_tool(self, tool_name: str) -> Tool: ... + def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ... class ListToolGroupsResponse(BaseModel): diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 44cc8f904..3ac62d42c 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -19,16 +19,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod class Chunk(BaseModel): - """ - A chunk of content that can be inserted into a vector database. 
- :param content: The content of the chunk, which can be interleaved text, images, or other types. - :param embedding: Optional embedding for the chunk. If not provided, it will be computed later. - :param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information. - """ - content: InterleavedContent metadata: dict[str, Any] = Field(default_factory=dict) - embedding: list[float] | None = None @json_schema_type @@ -58,10 +50,7 @@ class VectorIO(Protocol): """Insert chunks into a vector database. :param vector_db_id: The identifier of the vector database to insert the chunks into. - :param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types. - `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional. - If `metadata` is provided, you configure how Llama Stack formats the chunk during generation. - If `embedding` is not provided, it will be computed later. + :param chunks: The chunks to insert. :param ttl_seconds: The time to live of the chunks. """ ... diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index b96842119..09c753776 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -9,7 +9,6 @@ import asyncio import json import os import shutil -import sys from dataclasses import dataclass from datetime import datetime, timezone from functools import partial @@ -378,15 +377,14 @@ def _meta_download( downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads) asyncio.run(downloader.download_all(tasks)) - cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr) + cprint(f"\nSuccessfully downloaded model to {output_dir}", "green") cprint( f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}", - file=sys.stderr, + "white", ) cprint( f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}", - color="yellow", - file=sys.stderr, + "yellow", ) diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index f6f72946a..37147e905 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -12,7 +12,6 @@ import shutil import sys import textwrap from functools import lru_cache -from importlib.abc import Traversable from pathlib import Path import yaml @@ -79,7 +78,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates", color="red", - file=sys.stderr, ) sys.exit(1) build_config = available_templates[args.template] @@ -89,7 +87,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}", color="red", - file=sys.stderr, ) sys.exit(1) elif args.providers: @@ -99,7 +96,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( "Could not parse `--providers`. 
Please ensure the list is in the format api1=provider1,api2=provider2", color="red", - file=sys.stderr, ) sys.exit(1) api, provider = api_provider.split("=") @@ -108,7 +104,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"{api} is not a valid API.", color="red", - file=sys.stderr, ) sys.exit(1) if provider in providers_for_api: @@ -117,7 +112,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"{provider} is not a valid provider for the {api} API.", color="red", - file=sys.stderr, ) sys.exit(1) distribution_spec = DistributionSpec( @@ -128,7 +122,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"Please specify a image-type (container | conda | venv) for {args.template}", color="red", - file=sys.stderr, ) sys.exit(1) @@ -157,14 +150,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`", color="yellow", - file=sys.stderr, ) image_name = f"llamastack-{name}" else: cprint( f"Using conda environment {image_name}", color="green", - file=sys.stderr, ) else: image_name = f"llamastack-{name}" @@ -177,10 +168,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None: """, ), color="green", - file=sys.stderr, ) - cprint("Tip: use to see options for the providers.\n", color="green", file=sys.stderr) + print("Tip: use to see options for the providers.\n") providers = dict() for api, providers_for_api in get_provider_registry().items(): @@ -216,13 +206,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None: contents = yaml.safe_load(f) contents = replace_env_vars(contents) build_config = BuildConfig(**contents) - if args.image_type: - build_config.image_type = args.image_type except Exception as e: cprint( f"Could not parse config file {args.config}: {e}", color="red", - file=sys.stderr, ) sys.exit(1) @@ -249,27 +236,25 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"Error building stack: {exc}", color="red", - file=sys.stderr, ) - cprint("Stack trace:", color="red", file=sys.stderr) + cprint("Stack trace:", color="red") traceback.print_exc() sys.exit(1) - if run_config is None: cprint( "Run config path is empty", color="red", - file=sys.stderr, ) sys.exit(1) if args.run: + run_config = Path(run_config) config_dict = yaml.safe_load(run_config.read_text()) config = parse_and_maybe_upgrade_config(config_dict) - if config.external_providers_dir and not config.external_providers_dir.exists(): - config.external_providers_dir.mkdir(exist_ok=True) + if not os.path.exists(str(config.external_providers_dir)): + os.makedirs(str(config.external_providers_dir), exist_ok=True) run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) - run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config]) + run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))]) run_command(run_args) @@ -277,7 +262,7 @@ def _generate_run_config( build_config: BuildConfig, build_dir: Path, image_name: str, -) -> Path: +) -> str: """ Generate a run.yaml template file for user to edit from a build.yaml file """ @@ -317,7 +302,6 @@ def _generate_run_config( cprint( f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping", color="yellow", - file=sys.stderr, ) # Set config_type to None to avoid UnboundLocalError config_type = None @@ -345,7 +329,10 @@ def 
_generate_run_config( # For non-container builds, the run.yaml is generated at the very end of the build process so it # makes sense to display this message if build_config.image_type != LlamaStackImageType.CONTAINER.value: - cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr) + cprint( + f"You can now run your stack with `llama stack run {run_config_file}`", + color="green", + ) return run_config_file @@ -354,7 +341,7 @@ def _run_stack_build_command_from_build_config( image_name: str | None = None, template_name: str | None = None, config_path: str | None = None, -) -> Path | Traversable: +) -> str: image_name = image_name or build_config.image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value: if template_name: @@ -383,7 +370,7 @@ def _run_stack_build_command_from_build_config( # Generate the run.yaml so it can be included in the container image with the proper entrypoint # Only do this if we're building a container image and we're not using a template if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path: - cprint("Generating run.yaml file", color="yellow", file=sys.stderr) + cprint("Generating run.yaml file", color="green") run_config_file = _generate_run_config(build_config, build_dir, image_name) with open(build_file_path, "w") as f: @@ -407,13 +394,11 @@ def _run_stack_build_command_from_build_config( run_config_file = build_dir / f"{template_name}-run.yaml" shutil.copy(path, run_config_file) - cprint("Build Successful!", color="green", file=sys.stderr) - cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr) + cprint("Build Successful!", color="green") + cprint("You can find the newly-built template here: " + colored(template_path, "light_blue")) cprint( "You can run the new Llama Stack distro via: " - + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"), - color="green", - file=sys.stderr, + + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue") ) return template_path else: diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 2c402beeb..93e7d9b22 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -49,7 +49,7 @@ class StackBuild(Subcommand): type=str, help="Image Type to use for the build. If not specified, will use the image type from the template config.", choices=[e.value for e in ImageType], - default=None, # no default so we can detect if a user specified --image-type and override image_type in the config + default=ImageType.CONDA.value, ) self.parser.add_argument( diff --git a/llama_stack/cli/stack/list_stacks.py b/llama_stack/cli/stack/list_stacks.py deleted file mode 100644 index 2ea0fdeea..000000000 --- a/llama_stack/cli/stack/list_stacks.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import argparse -from pathlib import Path - -from llama_stack.cli.subcommand import Subcommand -from llama_stack.cli.table import print_table - - -class StackListBuilds(Subcommand): - """List built stacks in .llama/distributions directory""" - - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "list", - prog="llama stack list", - description="list the build stacks", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._list_stack_command) - - def _get_distribution_dirs(self) -> dict[str, Path]: - """Return a dictionary of distribution names and their paths""" - distributions = {} - dist_dir = Path.home() / ".llama" / "distributions" - - if dist_dir.exists(): - for stack_dir in dist_dir.iterdir(): - if stack_dir.is_dir(): - distributions[stack_dir.name] = stack_dir - return distributions - - def _list_stack_command(self, args: argparse.Namespace) -> None: - distributions = self._get_distribution_dirs() - - if not distributions: - print("No stacks found in ~/.llama/distributions") - return - - headers = ["Stack Name", "Path"] - headers.extend(["Build Config", "Run Config"]) - rows = [] - for name, path in distributions.items(): - row = [name, str(path)] - # Check for build and run config files - build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No" - run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No" - row.extend([build_config, run_config]) - rows.append(row) - print_table(rows, headers, separate_rows=True) diff --git a/llama_stack/cli/stack/remove.py b/llama_stack/cli/stack/remove.py deleted file mode 100644 index a1796941e..000000000 --- a/llama_stack/cli/stack/remove.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import argparse -import shutil -import sys -from pathlib import Path - -from termcolor import cprint - -from llama_stack.cli.subcommand import Subcommand -from llama_stack.cli.table import print_table - - -class StackRemove(Subcommand): - """Remove the build stack""" - - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "rm", - prog="llama stack rm", - description="Remove the build stack", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._remove_stack_build_command) - - def _add_arguments(self) -> None: - self.parser.add_argument( - "name", - type=str, - nargs="?", - help="Name of the stack to delete", - ) - self.parser.add_argument( - "--all", - "-a", - action="store_true", - help="Delete all stacks (use with caution)", - ) - - def _get_distribution_dirs(self) -> dict[str, Path]: - """Return a dictionary of distribution names and their paths""" - distributions = {} - dist_dir = Path.home() / ".llama" / "distributions" - - if dist_dir.exists(): - for stack_dir in dist_dir.iterdir(): - if stack_dir.is_dir(): - distributions[stack_dir.name] = stack_dir - return distributions - - def _list_stacks(self) -> None: - """Display available stacks in a table""" - distributions = self._get_distribution_dirs() - if not distributions: - cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr) - sys.exit(1) - - headers = ["Stack Name", "Path"] - rows = [[name, str(path)] for name, path in distributions.items()] - print_table(rows, headers, separate_rows=True) - - def _remove_stack_build_command(self, args: argparse.Namespace) -> None: - distributions = self._get_distribution_dirs() - - if args.all: - confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower() - if confirm != "yes-i-really-want": - cprint("Deletion cancelled.", color="green", file=sys.stderr) - return - - for name, path in distributions.items(): - try: - shutil.rmtree(path) - cprint(f"Deleted stack: {name}", color="green", file=sys.stderr) - except Exception as e: - cprint( - f"Failed to delete stack {name}: {e}", - color="red", - file=sys.stderr, - ) - sys.exit(1) - - if not args.name: - self._list_stacks() - if not args.name: - return - - if args.name not in distributions: - self._list_stacks() - cprint( - f"Stack not found: {args.name}", - color="red", - file=sys.stderr, - ) - sys.exit(1) - - stack_path = distributions[args.name] - - confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower() - if confirm != "y": - cprint("Deletion cancelled.", color="green", file=sys.stderr) - return - - try: - shutil.rmtree(stack_path) - cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr) - except Exception as e: - cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr) - sys.exit(1) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 27745edac..4a44e0366 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -6,7 +6,6 @@ import argparse import os -import subprocess from pathlib import Path from llama_stack.cli.stack.utils import ImageType @@ -61,11 +60,6 @@ class StackRun(Subcommand): help="Image Type used during the build. 
This can be either conda or container or venv.", choices=[e.value for e in ImageType], ) - self.parser.add_argument( - "--enable-ui", - action="store_true", - help="Start the UI server", - ) # If neither image type nor image name is provided, but at the same time # the current environment has conda breadcrumbs, then assume what the user @@ -89,8 +83,6 @@ class StackRun(Subcommand): from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.exec import formulate_run_args, run_command - if args.enable_ui: - self._start_ui_development_server(args.port) image_type, image_name = self._get_image_type_and_name(args) # Check if config is required based on image type @@ -178,44 +170,3 @@ class StackRun(Subcommand): run_args.extend(["--env", f"{key}={value}"]) run_command(run_args) - - def _start_ui_development_server(self, stack_server_port: int): - logger.info("Attempting to start UI development server...") - # Check if npm is available - npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False) - if npm_check.returncode != 0: - logger.warning( - f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}" - ) - return - - ui_dir = REPO_ROOT / "llama_stack" / "ui" - logs_dir = Path("~/.llama/ui/logs").expanduser() - try: - # Create logs directory if it doesn't exist - logs_dir.mkdir(parents=True, exist_ok=True) - - ui_stdout_log_path = logs_dir / "stdout.log" - ui_stderr_log_path = logs_dir / "stderr.log" - - # Open log files in append mode - stdout_log_file = open(ui_stdout_log_path, "a") - stderr_log_file = open(ui_stderr_log_path, "a") - - process = subprocess.Popen( - ["npm", "run", "dev"], - cwd=str(ui_dir), - stdout=stdout_log_file, - stderr=stderr_log_file, - env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"}, - ) - logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.") - logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}") - logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}") - - except FileNotFoundError: - logger.error( - "Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH." 
- ) - except Exception as e: - logger.error(f"Failed to start UI development server in {ui_dir}: {e}") diff --git a/llama_stack/cli/stack/stack.py b/llama_stack/cli/stack/stack.py index 3aff78e23..ccf1a5ffc 100644 --- a/llama_stack/cli/stack/stack.py +++ b/llama_stack/cli/stack/stack.py @@ -7,14 +7,12 @@ import argparse from importlib.metadata import version -from llama_stack.cli.stack.list_stacks import StackListBuilds from llama_stack.cli.stack.utils import print_subcommand_description from llama_stack.cli.subcommand import Subcommand from .build import StackBuild from .list_apis import StackListApis from .list_providers import StackListProviders -from .remove import StackRemove from .run import StackRun @@ -43,6 +41,5 @@ class StackParser(Subcommand): StackListApis.create(subparsers) StackListProviders.create(subparsers) StackRun.create(subparsers) - StackRemove.create(subparsers) - StackListBuilds.create(subparsers) + print_subcommand_description(self.parser, subparsers) diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 072f9c425..1d39063f0 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -6,7 +6,6 @@ import importlib.resources import logging -import sys from pathlib import Path from pydantic import BaseModel @@ -44,20 +43,8 @@ def get_provider_dependencies( # Extract providers based on config type if isinstance(config, DistributionTemplate): providers = config.providers - - # TODO: This is a hack to get the dependencies for internal APIs into build - # We should have a better way to do this by formalizing the concept of "internal" APIs - # and providers, with a way to specify dependencies for them. - run_configs = config.run_configs - additional_pip_packages: list[str] = [] - if run_configs: - for run_config in run_configs.values(): - run_config_ = run_config.run_config(name="", providers={}, container_image=None) - if run_config_.inference_store: - additional_pip_packages.extend(run_config_.inference_store.pip_packages) elif isinstance(config, BuildConfig): providers = config.distribution_spec.providers - additional_pip_packages = config.additional_pip_packages deps = [] registry = get_provider_registry(config) for api_str, provider_or_providers in providers.items(): @@ -85,9 +72,6 @@ def get_provider_dependencies( else: normal_deps.append(package) - if additional_pip_packages: - normal_deps.extend(additional_pip_packages) - return list(set(normal_deps)), list(set(special_deps)) @@ -96,11 +80,10 @@ def print_pip_install_help(config: BuildConfig): cprint( f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}", - color="yellow", - file=sys.stderr, + "yellow", ) for special_dep in special_deps: - cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr) + cprint(f"uv pip install {special_dep}", "yellow") print() diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index 03e4fb051..9fde8a157 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -6,7 +6,6 @@ import inspect import json -import sys from collections.abc import AsyncIterator from enum import Enum from typing import Any, Union, get_args, get_origin @@ -97,13 +96,13 @@ def create_api_client_class(protocol) -> type: try: data = json.loads(data) if "error" in data: - cprint(data, color="red", file=sys.stderr) + cprint(data, "red") continue yield parse_obj_as(return_type, data) except Exception as e: - 
cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr) - cprint(data, color="red", file=sys.stderr) + print(f"Error with parsing or validation: {e}") + print(data) def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict: webmethod, sig = self.routes[method_name] diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index def7048c0..783a48de3 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -25,8 +25,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.datatypes import Api, ProviderSpec -from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig -from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig +from llama_stack.providers.utils.kvstore.config import KVStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" @@ -221,38 +220,21 @@ class LoggingConfig(BaseModel): class AuthProviderType(str, Enum): """Supported authentication provider types.""" - OAUTH2_TOKEN = "oauth2_token" + KUBERNETES = "kubernetes" CUSTOM = "custom" class AuthenticationConfig(BaseModel): provider_type: AuthProviderType = Field( ..., - description="Type of authentication provider", + description="Type of authentication provider (e.g., 'kubernetes', 'custom')", ) - config: dict[str, Any] = Field( + config: dict[str, str] = Field( ..., description="Provider-specific configuration", ) -class AuthenticationRequiredError(Exception): - pass - - -class QuotaPeriod(str, Enum): - DAY = "day" - - -class QuotaConfig(BaseModel): - kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)") - anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period") - authenticated_max_requests: int = Field( - default=1000, description="Max requests for authenticated clients per period" - ) - period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") - - class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -280,10 +262,6 @@ class ServerConfig(BaseModel): default=None, description="The host the server should listen on", ) - quota: QuotaConfig | None = Field( - default=None, - description="Per client quota request configuration", - ) class StackRunConfig(BaseModel): @@ -319,13 +297,6 @@ Configuration for the persistence store used by the distribution registry. If no a default SQLite store will be used.""", ) - inference_store: SqlStoreConfig | None = Field( - default=None, - description=""" -Configuration for the persistence store used by the inference API. If not specified, -a default SQLite store will be used.""", - ) - # registry of "resources" in the distribution models: list[ModelInput] = Field(default_factory=list) shields: list[ShieldInput] = Field(default_factory=list) @@ -369,21 +340,8 @@ class BuildConfig(BaseModel): default=None, description="Name of the distribution to build", ) - external_providers_dir: Path | None = Field( + external_providers_dir: str | None = Field( default=None, description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. 
" "pip_packages MUST contain the provider package name.", ) - additional_pip_packages: list[str] = Field( - default_factory=list, - description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.", - ) - - @field_validator("external_providers_dir") - @classmethod - def validate_external_providers_dir(cls, v): - if v is None: - return None - if isinstance(v, str): - return Path(v) - return v diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 5822070ad..23f644ec6 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -16,7 +16,7 @@ from llama_stack.apis.inspect import ( VersionInfo, ) from llama_stack.distribution.datatypes import StackRunConfig -from llama_stack.distribution.server.routes import get_all_api_routes +from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.providers.datatypes import HealthStatus @@ -31,7 +31,7 @@ async def get_provider_impl(config, deps): class DistributionInspectImpl(Inspect): - def __init__(self, config: DistributionInspectConfig, deps): + def __init__(self, config, deps): self.config = config self.deps = deps @@ -39,36 +39,22 @@ class DistributionInspectImpl(Inspect): pass async def list_routes(self) -> ListRoutesResponse: - run_config: StackRunConfig = self.config.run_config + run_config = self.config.run_config ret = [] - all_endpoints = get_all_api_routes() + all_endpoints = get_all_api_endpoints() for api, endpoints in all_endpoints.items(): - # Always include provider and inspect APIs, filter others based on run config - if api.value in ["providers", "inspect"]: - ret.extend( - [ - RouteInfo( - route=e.path, - method=next(iter([m for m in e.methods if m != "HEAD"])), - provider_types=[], # These APIs don't have "real" providers - they're internal to the stack - ) - for e in endpoints - ] - ) - else: - providers = run_config.providers.get(api.value, []) - if providers: # Only process if there are providers for this API - ret.extend( - [ - RouteInfo( - route=e.path, - method=next(iter([m for m in e.methods if m != "HEAD"])), - provider_types=[p.provider_type for p in providers], - ) - for e in endpoints - ] + providers = run_config.providers.get(api.value, []) + ret.extend( + [ + RouteInfo( + route=e.route, + method=e.method, + provider_types=[p.provider_type for p in providers], ) + for e in endpoints + ] + ) return ListRoutesResponse(data=ret) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index f32130cf9..8e5445874 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -9,7 +9,6 @@ import inspect import json import logging import os -import sys from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path @@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import ( request_provider_data_context, ) from llama_stack.distribution.resolver import ProviderRegistry -from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls +from llama_stack.distribution.server.endpoints import ( + find_matching_endpoint, + initialize_endpoint_impls, +) from llama_stack.distribution.stack import ( construct_stack, get_stack_run_config_from_template, @@ -205,14 +207,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): async def initialize(self) -> bool: try: - self.route_impls = None + 
self.endpoint_impls = None self.impls = await construct_stack(self.config, self.custom_provider_registry) except ModuleNotFoundError as _e: - cprint(_e.msg, color="red", file=sys.stderr) + cprint(_e.msg, "red") cprint( "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n", - color="yellow", - file=sys.stderr, + "yellow", ) if self.config_path_or_template_name.endswith(".yaml"): # Convert Provider objects to their types @@ -225,7 +226,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): distribution_spec=DistributionSpec( providers=provider_types, ), - external_providers_dir=self.config.external_providers_dir, ) print_pip_install_help(build_config) else: @@ -233,13 +233,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): cprint( f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n", "yellow", - file=sys.stderr, ) - cprint( - "Please check your internet connection and try again.", - "red", - file=sys.stderr, - ) raise _e if Api.telemetry in self.impls: @@ -251,7 +245,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): safe_config = redact_sensitive_fields(self.config.model_dump()) console.print(yaml.dump(safe_config, indent=2)) - self.route_impls = initialize_route_impls(self.impls) + self.endpoint_impls = initialize_endpoint_impls(self.impls) return True async def request( @@ -262,15 +256,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): stream=False, stream_cls=None, ): - if not self.route_impls: + if not self.endpoint_impls: raise ValueError("Client not initialized") # Create headers with provider data if available - headers = options.headers or {} + headers = {} if self.provider_data: - keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"] - if all(key not in headers for key in keys): - headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data) + headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data) # Use context manager for provider data with request_provider_data_context(headers): @@ -293,14 +285,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): cast_to: Any, options: Any, ): - if self.route_impls is None: - raise ValueError("Client not initialized") - path = options.url body = options.params or {} body |= options.json_data or {} - matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls) + matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) body |= path_params body = self._convert_body(path, options.method, body) await start_trace(route, {"__location__": "library_client"}) @@ -342,13 +331,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): options: Any, stream_cls: Any, ): - if self.route_impls is None: - raise ValueError("Client not initialized") - path = options.url body = options.params or {} body |= options.json_data or {} - func, path_params, route = find_matching_route(options.method, path, self.route_impls) + func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) body |= path_params body = self._convert_body(path, options.method, body) @@ -400,10 +386,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if not body: return {} - if self.route_impls is None: - raise ValueError("Client not initialized") - - func, _, _ = find_matching_route(method, path, self.route_impls) + func, _, _ = 
find_matching_endpoint(method, path, self.endpoint_impls) sig = inspect.signature(func) # Strip NOT_GIVENs to use the defaults in signature diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index b7c7cb87f..37588ea64 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval import Eval from llama_stack.apis.files import Files -from llama_stack.apis.inference import Inference, InferenceProvider +from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.models import Models from llama_stack.apis.post_training import PostTraining @@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import ( RemoteProviderSpec, ScoringFunctionsProtocolPrivate, ShieldsProtocolPrivate, - ToolGroupsProtocolPrivate, + ToolsProtocolPrivate, VectorDBsProtocolPrivate, ) @@ -83,17 +83,10 @@ def api_protocol_map() -> dict[Api, Any]: } -def api_protocol_map_for_compliance_check() -> dict[Api, Any]: - return { - **api_protocol_map(), - Api.inference: InferenceProvider, - } - - def additional_protocols_map() -> dict[Api, Any]: return { Api.inference: (ModelsProtocolPrivate, Models, Api.models), - Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups), + Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups), Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs), Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields), Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets), @@ -140,7 +133,7 @@ async def resolve_impls( sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) - return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config) + return await instantiate_providers(sorted_providers, router_apis, dist_registry) def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: @@ -243,10 +236,7 @@ def sort_providers_by_deps( async def instantiate_providers( - sorted_providers: list[tuple[str, ProviderWithSpec]], - router_apis: set[Api], - dist_registry: DistributionRegistry, - run_config: StackRunConfig, + sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry ) -> dict: """Instantiates providers asynchronously while managing dependencies.""" impls: dict[Api, Any] = {} @@ -261,7 +251,7 @@ async def instantiate_providers( if isinstance(provider.spec, RoutingTableProviderSpec): inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] - impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config) + impl = await instantiate_provider(provider, deps, inner_impls, dist_registry) if api_str.startswith("inner-"): inner_impls_by_provider_id[api_str][provider.provider_id] = impl @@ -311,8 +301,10 @@ async def instantiate_provider( deps: dict[Api, Any], inner_impls: dict[str, Any], dist_registry: DistributionRegistry, - run_config: StackRunConfig, ): + protocols = api_protocol_map() + additional_protocols = additional_protocols_map() + provider_spec = provider.spec if not hasattr(provider_spec, "module"): raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") @@ -331,7 +323,7 @@ async def instantiate_provider( method = "get_auto_router_impl" config 
= None - args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config] + args = [provider_spec.api, deps[provider_spec.routing_table_api], deps] elif isinstance(provider_spec, RoutingTableProviderSpec): method = "get_routing_table_impl" @@ -350,8 +342,6 @@ async def instantiate_provider( impl.__provider_spec__ = provider_spec impl.__provider_config__ = config - protocols = api_protocol_map_for_compliance_check() - additional_protocols = additional_protocols_map() # TODO: check compliance for special tool groups # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol check_protocol_compliance(impl, protocols[provider_spec.api]) diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 1358d5812..cd2a296f2 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -7,10 +7,18 @@ from typing import Any from llama_stack.distribution.datatypes import RoutedProtocol -from llama_stack.distribution.stack import StackRunConfig from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable -from llama_stack.providers.utils.inference.inference_store import InferenceStore + +from .routing_tables import ( + BenchmarksRoutingTable, + DatasetsRoutingTable, + ModelsRoutingTable, + ScoringFunctionsRoutingTable, + ShieldsRoutingTable, + ToolGroupsRoutingTable, + VectorDBsRoutingTable, +) async def get_routing_table_impl( @@ -19,14 +27,6 @@ async def get_routing_table_impl( _deps, dist_registry: DistributionRegistry, ) -> Any: - from ..routing_tables.benchmarks import BenchmarksRoutingTable - from ..routing_tables.datasets import DatasetsRoutingTable - from ..routing_tables.models import ModelsRoutingTable - from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable - from ..routing_tables.shields import ShieldsRoutingTable - from ..routing_tables.toolgroups import ToolGroupsRoutingTable - from ..routing_tables.vector_dbs import VectorDBsRoutingTable - api_to_tables = { "vector_dbs": VectorDBsRoutingTable, "models": ModelsRoutingTable, @@ -45,15 +45,16 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl( - api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig -) -> Any: - from .datasets import DatasetIORouter - from .eval_scoring import EvalRouter, ScoringRouter - from .inference import InferenceRouter - from .safety import SafetyRouter - from .tool_runtime import ToolRuntimeRouter - from .vector_io import VectorIORouter +async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any: + from .routers import ( + DatasetIORouter, + EvalRouter, + InferenceRouter, + SafetyRouter, + ScoringRouter, + ToolRuntimeRouter, + VectorIORouter, + ) api_to_routers = { "vector_io": VectorIORouter, @@ -75,12 +76,6 @@ async def get_auto_router_impl( if dep_api in deps: api_to_dep_impl[dep_name] = deps[dep_api] - # TODO: move pass configs to routers instead - if api == Api.inference and run_config.inference_store: - inference_store = InferenceStore(run_config.inference_store) - await inference_store.initialize() - api_to_dep_impl["store"] = inference_store - impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/datasets.py 
b/llama_stack/distribution/routers/datasets.py deleted file mode 100644 index 6f28756c9..000000000 --- a/llama_stack/distribution/routers/datasets.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from llama_stack.apis.common.responses import PaginatedResponse -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import DatasetPurpose, DataSource -from llama_stack.log import get_logger -from llama_stack.providers.datatypes import RoutingTable - -logger = get_logger(name=__name__, category="core") - - -class DatasetIORouter(DatasetIO): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing DatasetIORouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("DatasetIORouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("DatasetIORouter.shutdown") - pass - - async def register_dataset( - self, - purpose: DatasetPurpose, - source: DataSource, - metadata: dict[str, Any] | None = None, - dataset_id: str | None = None, - ) -> None: - logger.debug( - f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", - ) - await self.routing_table.register_dataset( - purpose=purpose, - source=source, - metadata=metadata, - dataset_id=dataset_id, - ) - - async def iterrows( - self, - dataset_id: str, - start_index: int | None = None, - limit: int | None = None, - ) -> PaginatedResponse: - logger.debug( - f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", - ) - return await self.routing_table.get_provider_impl(dataset_id).iterrows( - dataset_id=dataset_id, - start_index=start_index, - limit=limit, - ) - - async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: - logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") - return await self.routing_table.get_provider_impl(dataset_id).append_rows( - dataset_id=dataset_id, - rows=rows, - ) diff --git a/llama_stack/distribution/routers/eval_scoring.py b/llama_stack/distribution/routers/eval_scoring.py deleted file mode 100644 index fd0bb90a7..000000000 --- a/llama_stack/distribution/routers/eval_scoring.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job -from llama_stack.apis.scoring import ( - ScoreBatchResponse, - ScoreResponse, - Scoring, - ScoringFnParams, -) -from llama_stack.log import get_logger -from llama_stack.providers.datatypes import RoutingTable - -logger = get_logger(name=__name__, category="core") - - -class ScoringRouter(Scoring): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing ScoringRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("ScoringRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("ScoringRouter.shutdown") - pass - - async def score_batch( - self, - dataset_id: str, - scoring_functions: dict[str, ScoringFnParams | None] = None, - save_results_dataset: bool = False, - ) -> ScoreBatchResponse: - logger.debug(f"ScoringRouter.score_batch: {dataset_id}") - res = {} - for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( - dataset_id=dataset_id, - scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, - ) - res.update(score_response.results) - - if save_results_dataset: - raise NotImplementedError("Save results dataset not implemented yet") - - return ScoreBatchResponse( - results=res, - ) - - async def score( - self, - input_rows: list[dict[str, Any]], - scoring_functions: dict[str, ScoringFnParams | None] = None, - ) -> ScoreResponse: - logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") - res = {} - # look up and map each scoring function to its provider impl - for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score( - input_rows=input_rows, - scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, - ) - res.update(score_response.results) - - return ScoreResponse(results=res) - - -class EvalRouter(Eval): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing EvalRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("EvalRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("EvalRouter.shutdown") - pass - - async def run_eval( - self, - benchmark_id: str, - benchmark_config: BenchmarkConfig, - ) -> Job: - logger.debug(f"EvalRouter.run_eval: {benchmark_id}") - return await self.routing_table.get_provider_impl(benchmark_id).run_eval( - benchmark_id=benchmark_id, - benchmark_config=benchmark_config, - ) - - async def evaluate_rows( - self, - benchmark_id: str, - input_rows: list[dict[str, Any]], - scoring_functions: list[str], - benchmark_config: BenchmarkConfig, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") - return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( - benchmark_id=benchmark_id, - input_rows=input_rows, - scoring_functions=scoring_functions, - benchmark_config=benchmark_config, - ) - - async def job_status( - self, - benchmark_id: str, - job_id: str, - ) -> Job: - logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) - - async def job_cancel( - self, - benchmark_id: str, - job_id: str, - ) -> None: - logger.debug(f"EvalRouter.job_cancel: 
{benchmark_id}, {job_id}") - await self.routing_table.get_provider_impl(benchmark_id).job_cancel( - benchmark_id, - job_id, - ) - - async def job_result( - self, - benchmark_id: str, - job_id: str, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_result( - benchmark_id, - job_id, - ) diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/routers.py similarity index 65% rename from llama_stack/distribution/routers/inference.py rename to llama_stack/distribution/routers/routers.py index 763bd9105..371d34904 100644 --- a/llama_stack/distribution/routers/inference.py +++ b/llama_stack/distribution/routers/routers.py @@ -14,9 +14,14 @@ from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToo from pydantic import Field, TypeAdapter from llama_stack.apis.common.content_types import ( + URL, InterleavedContent, InterleavedContentItem, ) +from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import DatasetPurpose, DataSource +from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job from llama_stack.apis.inference import ( BatchChatCompletionResponse, BatchCompletionResponse, @@ -27,11 +32,8 @@ from llama_stack.apis.inference import ( EmbeddingsResponse, EmbeddingTaskType, Inference, - ListOpenAIChatCompletionResponse, LogProbConfig, Message, - OpenAICompletionWithInputMessages, - Order, ResponseFormat, SamplingParams, StopReason, @@ -45,23 +47,93 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, - OpenAIEmbeddingsResponse, OpenAIMessageParam, OpenAIResponseFormatParam, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.safety import RunShieldResponse, Safety +from llama_stack.apis.scoring import ( + ScoreBatchResponse, + ScoreResponse, + Scoring, + ScoringFnParams, +) +from llama_stack.apis.shields import Shield from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry +from llama_stack.apis.tools import ( + ListToolDefsResponse, + RAGDocument, + RAGQueryConfig, + RAGQueryResult, + RAGToolRuntime, + ToolRuntime, +) +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.log import get_logger from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable -from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="core") +class VectorIORouter(VectorIO): + """Routes to an provider based on the vector db identifier""" + + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing VectorIORouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("VectorIORouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("VectorIORouter.shutdown") + pass + + async def register_vector_db( + self, + vector_db_id: str, + embedding_model: str, + embedding_dimension: int | None = 384, + 
provider_id: str | None = None, + provider_vector_db_id: str | None = None, + ) -> None: + logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") + await self.routing_table.register_vector_db( + vector_db_id, + embedding_model, + embedding_dimension, + provider_id, + provider_vector_db_id, + ) + + async def insert_chunks( + self, + vector_db_id: str, + chunks: list[Chunk], + ttl_seconds: int | None = None, + ) -> None: + logger.debug( + f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", + ) + return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) + + async def query_chunks( + self, + vector_db_id: str, + query: InterleavedContent, + params: dict[str, Any] | None = None, + ) -> QueryChunksResponse: + logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") + return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) + + class InferenceRouter(Inference): """Routes to an provider based on the model""" @@ -69,12 +141,10 @@ class InferenceRouter(Inference): self, routing_table: RoutingTable, telemetry: Telemetry | None = None, - store: InferenceStore | None = None, ) -> None: logger.debug("Initializing InferenceRouter") self.routing_table = routing_table self.telemetry = telemetry - self.store = store if self.telemetry: self.tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(self.tokenizer) @@ -537,59 +607,9 @@ class InferenceRouter(Inference): provider = self.routing_table.get_provider_impl(model_obj.identifier) if stream: - response_stream = await provider.openai_chat_completion(**params) - if self.store: - return stream_and_store_openai_completion(response_stream, model, self.store, messages) - return response_stream + return await provider.openai_chat_completion(**params) else: - response = await self._nonstream_openai_chat_completion(provider, params) - if self.store: - await self.store.store_chat_completion(response, messages) - return response - - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - logger.debug( - f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}", - ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ValueError(f"Model '{model}' not found") - if model_obj.model_type != ModelType.embedding: - raise ValueError(f"Model '{model}' is not an embedding model") - - params = dict( - model=model_obj.identifier, - input=input, - encoding_format=encoding_format, - dimensions=dimensions, - user=user, - ) - - provider = self.routing_table.get_provider_impl(model_obj.identifier) - return await provider.openai_embeddings(**params) - - async def list_chat_completions( - self, - after: str | None = None, - limit: int | None = 20, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIChatCompletionResponse: - if self.store: - return await self.store.list_chat_completions(after, limit, model, order) - raise NotImplementedError("List chat completions is not supported: inference store is not configured.") - - async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: - if 
self.store: - return await self.store.get_chat_completion(completion_id) - raise NotImplementedError("Get chat completion is not supported: inference store is not configured.") + return await self._nonstream_openai_chat_completion(provider, params) async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion: response = await provider.openai_chat_completion(**params) @@ -622,3 +642,295 @@ class InferenceRouter(Inference): status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" ) return health_statuses + + +class SafetyRouter(Safety): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing SafetyRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("SafetyRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("SafetyRouter.shutdown") + pass + + async def register_shield( + self, + shield_id: str, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, + ) -> Shield: + logger.debug(f"SafetyRouter.register_shield: {shield_id}") + return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + + async def run_shield( + self, + shield_id: str, + messages: list[Message], + params: dict[str, Any] = None, + ) -> RunShieldResponse: + logger.debug(f"SafetyRouter.run_shield: {shield_id}") + return await self.routing_table.get_provider_impl(shield_id).run_shield( + shield_id=shield_id, + messages=messages, + params=params, + ) + + +class DatasetIORouter(DatasetIO): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing DatasetIORouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("DatasetIORouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("DatasetIORouter.shutdown") + pass + + async def register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + metadata: dict[str, Any] | None = None, + dataset_id: str | None = None, + ) -> None: + logger.debug( + f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", + ) + await self.routing_table.register_dataset( + purpose=purpose, + source=source, + metadata=metadata, + dataset_id=dataset_id, + ) + + async def iterrows( + self, + dataset_id: str, + start_index: int | None = None, + limit: int | None = None, + ) -> PaginatedResponse: + logger.debug( + f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", + ) + return await self.routing_table.get_provider_impl(dataset_id).iterrows( + dataset_id=dataset_id, + start_index=start_index, + limit=limit, + ) + + async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: + logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") + return await self.routing_table.get_provider_impl(dataset_id).append_rows( + dataset_id=dataset_id, + rows=rows, + ) + + +class ScoringRouter(Scoring): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing ScoringRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("ScoringRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("ScoringRouter.shutdown") + pass + + async def score_batch( + self, + dataset_id: str, + scoring_functions: dict[str, ScoringFnParams | None] = None, + save_results_dataset: 
bool = False, + ) -> ScoreBatchResponse: + logger.debug(f"ScoringRouter.score_batch: {dataset_id}") + res = {} + for fn_identifier in scoring_functions.keys(): + score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( + dataset_id=dataset_id, + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, + ) + res.update(score_response.results) + + if save_results_dataset: + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res, + ) + + async def score( + self, + input_rows: list[dict[str, Any]], + scoring_functions: dict[str, ScoringFnParams | None] = None, + ) -> ScoreResponse: + logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") + res = {} + # look up and map each scoring function to its provider impl + for fn_identifier in scoring_functions.keys(): + score_response = await self.routing_table.get_provider_impl(fn_identifier).score( + input_rows=input_rows, + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, + ) + res.update(score_response.results) + + return ScoreResponse(results=res) + + +class EvalRouter(Eval): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing EvalRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("EvalRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("EvalRouter.shutdown") + pass + + async def run_eval( + self, + benchmark_id: str, + benchmark_config: BenchmarkConfig, + ) -> Job: + logger.debug(f"EvalRouter.run_eval: {benchmark_id}") + return await self.routing_table.get_provider_impl(benchmark_id).run_eval( + benchmark_id=benchmark_id, + benchmark_config=benchmark_config, + ) + + async def evaluate_rows( + self, + benchmark_id: str, + input_rows: list[dict[str, Any]], + scoring_functions: list[str], + benchmark_config: BenchmarkConfig, + ) -> EvaluateResponse: + logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") + return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( + benchmark_id=benchmark_id, + input_rows=input_rows, + scoring_functions=scoring_functions, + benchmark_config=benchmark_config, + ) + + async def job_status( + self, + benchmark_id: str, + job_id: str, + ) -> Job: + logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") + return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) + + async def job_cancel( + self, + benchmark_id: str, + job_id: str, + ) -> None: + logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") + await self.routing_table.get_provider_impl(benchmark_id).job_cancel( + benchmark_id, + job_id, + ) + + async def job_result( + self, + benchmark_id: str, + job_id: str, + ) -> EvaluateResponse: + logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") + return await self.routing_table.get_provider_impl(benchmark_id).job_result( + benchmark_id, + job_id, + ) + + +class ToolRuntimeRouter(ToolRuntime): + class RagToolImpl(RAGToolRuntime): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") + self.routing_table = routing_table + + async def query( + self, + content: InterleavedContent, + vector_db_ids: list[str], + query_config: RAGQueryConfig | None = None, + ) -> RAGQueryResult: + logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") + 
return await self.routing_table.get_provider_impl("knowledge_search").query( + content, vector_db_ids, query_config + ) + + async def insert( + self, + documents: list[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + logger.debug( + f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" + ) + return await self.routing_table.get_provider_impl("insert_into_memory").insert( + documents, vector_db_id, chunk_size_in_tokens + ) + + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing ToolRuntimeRouter") + self.routing_table = routing_table + + # HACK ALERT this should be in sync with "get_all_api_endpoints()" + self.rag_tool = self.RagToolImpl(routing_table) + for method in ("query", "insert"): + setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) + + async def initialize(self) -> None: + logger.debug("ToolRuntimeRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("ToolRuntimeRouter.shutdown") + pass + + async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: + logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") + return await self.routing_table.get_provider_impl(tool_name).invoke_tool( + tool_name=tool_name, + kwargs=kwargs, + ) + + async def list_runtime_tools( + self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + ) -> ListToolDefsResponse: + logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") + return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py new file mode 100644 index 000000000..c04562197 --- /dev/null +++ b/llama_stack/distribution/routers/routing_tables.py @@ -0,0 +1,634 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import logging +import time +import uuid +from typing import Any + +from pydantic import TypeAdapter + +from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.type_system import ParamType +from llama_stack.apis.datasets import ( + Dataset, + DatasetPurpose, + Datasets, + DatasetType, + DataSource, + ListDatasetsResponse, + RowsDataSource, + URIDataSource, +) +from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.scoring_functions import ( + ListScoringFunctionsResponse, + ScoringFn, + ScoringFnParams, + ScoringFunctions, +) +from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields +from llama_stack.apis.tools import ( + ListToolGroupsResponse, + ListToolsResponse, + Tool, + ToolGroup, + ToolGroups, + ToolHost, +) +from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs +from llama_stack.distribution.access_control import check_access +from llama_stack.distribution.datatypes import ( + AccessAttributes, + BenchmarkWithACL, + DatasetWithACL, + ModelWithACL, + RoutableObject, + RoutableObjectWithProvider, + RoutedProtocol, + ScoringFnWithACL, + ShieldWithACL, + ToolGroupWithACL, + ToolWithACL, + VectorDBWithACL, +) +from llama_stack.distribution.request_headers import get_auth_attributes +from llama_stack.distribution.store import DistributionRegistry +from llama_stack.providers.datatypes import Api, RoutingTable + +logger = logging.getLogger(__name__) + + +def get_impl_api(p: Any) -> Api: + return p.__provider_spec__.api + + +# TODO: this should return the registered object for all APIs +async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: + api = get_impl_api(p) + + assert obj.provider_id != "remote", "Remote provider should not be registered" + + if api == Api.inference: + return await p.register_model(obj) + elif api == Api.safety: + return await p.register_shield(obj) + elif api == Api.vector_io: + return await p.register_vector_db(obj) + elif api == Api.datasetio: + return await p.register_dataset(obj) + elif api == Api.scoring: + return await p.register_scoring_function(obj) + elif api == Api.eval: + return await p.register_benchmark(obj) + elif api == Api.tool_runtime: + return await p.register_tool(obj) + else: + raise ValueError(f"Unknown API {api} for registering object with provider") + + +async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: + api = get_impl_api(p) + if api == Api.vector_io: + return await p.unregister_vector_db(obj.identifier) + elif api == Api.inference: + return await p.unregister_model(obj.identifier) + elif api == Api.datasetio: + return await p.unregister_dataset(obj.identifier) + elif api == Api.tool_runtime: + return await p.unregister_tool(obj.identifier) + else: + raise ValueError(f"Unregister not supported for {api}") + + +Registry = dict[str, list[RoutableObjectWithProvider]] + + +class CommonRoutingTableImpl(RoutingTable): + def __init__( + self, + impls_by_provider_id: dict[str, RoutedProtocol], + dist_registry: DistributionRegistry, + ) -> None: + self.impls_by_provider_id = impls_by_provider_id + self.dist_registry = dist_registry + + async def initialize(self) -> None: + async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: + for obj in objs: + if 
cls is None: + obj.provider_id = provider_id + else: + # Create a copy of the model data and explicitly set provider_id + model_data = obj.model_dump() + model_data["provider_id"] = provider_id + obj = cls(**model_data) + await self.dist_registry.register(obj) + + # Register all objects from providers + for pid, p in self.impls_by_provider_id.items(): + api = get_impl_api(p) + if api == Api.inference: + p.model_store = self + elif api == Api.safety: + p.shield_store = self + elif api == Api.vector_io: + p.vector_db_store = self + elif api == Api.datasetio: + p.dataset_store = self + elif api == Api.scoring: + p.scoring_function_store = self + scoring_functions = await p.list_scoring_functions() + await add_objects(scoring_functions, pid, ScoringFn) + elif api == Api.eval: + p.benchmark_store = self + elif api == Api.tool_runtime: + p.tool_store = self + + async def shutdown(self) -> None: + for p in self.impls_by_provider_id.values(): + await p.shutdown() + + def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + def apiname_object(): + if isinstance(self, ModelsRoutingTable): + return ("Inference", "model") + elif isinstance(self, ShieldsRoutingTable): + return ("Safety", "shield") + elif isinstance(self, VectorDBsRoutingTable): + return ("VectorIO", "vector_db") + elif isinstance(self, DatasetsRoutingTable): + return ("DatasetIO", "dataset") + elif isinstance(self, ScoringFunctionsRoutingTable): + return ("Scoring", "scoring_function") + elif isinstance(self, BenchmarksRoutingTable): + return ("Eval", "benchmark") + elif isinstance(self, ToolGroupsRoutingTable): + return ("Tools", "tool") + else: + raise ValueError("Unknown routing table type") + + apiname, objtype = apiname_object() + + # Get objects from disk registry + obj = self.dist_registry.get_cached(objtype, routing_key) + if not obj: + provider_ids = list(self.impls_by_provider_id.keys()) + if len(provider_ids) > 1: + provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" + else: + provider_ids_str = f"provider: `{provider_ids[0]}`" + raise ValueError( + f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." 
+ ) + + if not provider_id or provider_id == obj.provider_id: + return self.impls_by_provider_id[obj.provider_id] + + raise ValueError(f"Provider not found for `{routing_key}`") + + async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + # Get from disk registry + obj = await self.dist_registry.get(type, identifier) + if not obj: + return None + + # Check if user has permission to access this object + if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): + logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") + return None + + return obj + + async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: + await self.dist_registry.delete(obj.type, obj.identifier) + await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + + async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: + # if provider_id is not specified, pick an arbitrary one from existing entries + if not obj.provider_id and len(self.impls_by_provider_id) > 0: + obj.provider_id = list(self.impls_by_provider_id.keys())[0] + + if obj.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider `{obj.provider_id}` not found") + + p = self.impls_by_provider_id[obj.provider_id] + + # If object supports access control but no attributes set, use creator's attributes + if not obj.access_attributes: + creator_attributes = get_auth_attributes() + if creator_attributes: + obj.access_attributes = AccessAttributes(**creator_attributes) + logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") + + registered_obj = await register_object_with_provider(obj, p) + # TODO: This needs to be fixed for all APIs once they return the registered object + if obj.type == ResourceType.model.value: + await self.dist_registry.register(registered_obj) + return registered_obj + + else: + await self.dist_registry.register(obj) + return obj + + async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: + objs = await self.dist_registry.get_all() + filtered_objs = [obj for obj in objs if obj.type == type] + + # Apply attribute-based access control filtering + if filtered_objs: + filtered_objs = [ + obj + for obj in filtered_objs + if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) + ] + + return filtered_objs + + +class ModelsRoutingTable(CommonRoutingTableImpl, Models): + async def list_models(self) -> ListModelsResponse: + return ListModelsResponse(data=await self.get_all_with_type("model")) + + async def openai_list_models(self) -> OpenAIListModelsResponse: + models = await self.get_all_with_type("model") + openai_models = [ + OpenAIModel( + id=model.identifier, + object="model", + created=int(time.time()), + owned_by="llama_stack", + ) + for model in models + ] + return OpenAIListModelsResponse(data=openai_models) + + async def get_model(self, model_id: str) -> Model: + model = await self.get_object_by_identifier("model", model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + return model + + async def register_model( + self, + model_id: str, + provider_model_id: str | None = None, + provider_id: str | None = None, + metadata: dict[str, Any] | None = None, + model_type: ModelType | None = None, + ) -> Model: + if provider_model_id is None: + provider_model_id = model_id + if provider_id is 
None: + # If provider_id not specified, use the only provider if it supports this model + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" + ) + if metadata is None: + metadata = {} + if model_type is None: + model_type = ModelType.llm + if "embedding_dimension" not in metadata and model_type == ModelType.embedding: + raise ValueError("Embedding model must have an embedding dimension in its metadata") + model = ModelWithACL( + identifier=model_id, + provider_resource_id=provider_model_id, + provider_id=provider_id, + metadata=metadata, + model_type=model_type, + ) + registered_model = await self.register_object(model) + return registered_model + + async def unregister_model(self, model_id: str) -> None: + existing_model = await self.get_model(model_id) + if existing_model is None: + raise ValueError(f"Model {model_id} not found") + await self.unregister_object(existing_model) + + +class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): + async def list_shields(self) -> ListShieldsResponse: + return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + + async def get_shield(self, identifier: str) -> Shield: + shield = await self.get_object_by_identifier("shield", identifier) + if shield is None: + raise ValueError(f"Shield '{identifier}' not found") + return shield + + async def register_shield( + self, + shield_id: str, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, + ) -> Shield: + if provider_shield_id is None: + provider_shield_id = shield_id + if provider_id is None: + # If provider_id not specified, use the only provider if it supports this shield type + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + if params is None: + params = {} + shield = ShieldWithACL( + identifier=shield_id, + provider_resource_id=provider_shield_id, + provider_id=provider_id, + params=params, + ) + await self.register_object(shield) + return shield + + +class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): + async def list_vector_dbs(self) -> ListVectorDBsResponse: + return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) + + async def get_vector_db(self, vector_db_id: str) -> VectorDB: + vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) + if vector_db is None: + raise ValueError(f"Vector DB '{vector_db_id}' not found") + return vector_db + + async def register_vector_db( + self, + vector_db_id: str, + embedding_model: str, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, + ) -> VectorDB: + if provider_vector_db_id is None: + provider_vector_db_id = vector_db_id + if provider_id is None: + if len(self.impls_by_provider_id) > 0: + provider_id = list(self.impls_by_provider_id.keys())[0] + if len(self.impls_by_provider_id) > 1: + logger.warning( + f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." + ) + else: + raise ValueError("No provider available. 
Please configure a vector_io provider.") + model = await self.get_object_by_identifier("model", embedding_model) + if model is None: + raise ValueError(f"Model {embedding_model} not found") + if model.model_type != ModelType.embedding: + raise ValueError(f"Model {embedding_model} is not an embedding model") + if "embedding_dimension" not in model.metadata: + raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + vector_db_data = { + "identifier": vector_db_id, + "type": ResourceType.vector_db.value, + "provider_id": provider_id, + "provider_resource_id": provider_vector_db_id, + "embedding_model": embedding_model, + "embedding_dimension": model.metadata["embedding_dimension"], + } + vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) + await self.register_object(vector_db) + return vector_db + + async def unregister_vector_db(self, vector_db_id: str) -> None: + existing_vector_db = await self.get_vector_db(vector_db_id) + if existing_vector_db is None: + raise ValueError(f"Vector DB {vector_db_id} not found") + await self.unregister_object(existing_vector_db) + + +class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): + async def list_datasets(self) -> ListDatasetsResponse: + return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + + async def get_dataset(self, dataset_id: str) -> Dataset: + dataset = await self.get_object_by_identifier("dataset", dataset_id) + if dataset is None: + raise ValueError(f"Dataset '{dataset_id}' not found") + return dataset + + async def register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + metadata: dict[str, Any] | None = None, + dataset_id: str | None = None, + ) -> Dataset: + if isinstance(source, dict): + if source["type"] == "uri": + source = URIDataSource.parse_obj(source) + elif source["type"] == "rows": + source = RowsDataSource.parse_obj(source) + + if not dataset_id: + dataset_id = f"dataset-{str(uuid.uuid4())}" + + provider_dataset_id = dataset_id + + # infer provider from source + if metadata: + if metadata.get("provider_id"): + provider_id = metadata.get("provider_id") # pass through from nvidia datasetio + elif source.type == DatasetType.rows.value: + provider_id = "localfs" + elif source.type == DatasetType.uri.value: + # infer provider from uri + if source.uri.startswith("huggingface"): + provider_id = "huggingface" + else: + provider_id = "localfs" + else: + raise ValueError(f"Unknown data source type: {source.type}") + + if metadata is None: + metadata = {} + + dataset = DatasetWithACL( + identifier=dataset_id, + provider_resource_id=provider_dataset_id, + provider_id=provider_id, + purpose=purpose, + source=source, + metadata=metadata, + ) + + await self.register_object(dataset) + return dataset + + async def unregister_dataset(self, dataset_id: str) -> None: + dataset = await self.get_dataset(dataset_id) + if dataset is None: + raise ValueError(f"Dataset {dataset_id} not found") + await self.unregister_object(dataset) + + +class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): + async def list_scoring_functions(self) -> ListScoringFunctionsResponse: + return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + + async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: + scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) + if scoring_fn is None: + raise ValueError(f"Scoring function '{scoring_fn_id}' not 
found") + return scoring_fn + + async def register_scoring_function( + self, + scoring_fn_id: str, + description: str, + return_type: ParamType, + provider_scoring_fn_id: str | None = None, + provider_id: str | None = None, + params: ScoringFnParams | None = None, + ) -> None: + if provider_scoring_fn_id is None: + provider_scoring_fn_id = scoring_fn_id + if provider_id is None: + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + scoring_fn = ScoringFnWithACL( + identifier=scoring_fn_id, + description=description, + return_type=return_type, + provider_resource_id=provider_scoring_fn_id, + provider_id=provider_id, + params=params, + ) + scoring_fn.provider_id = provider_id + await self.register_object(scoring_fn) + + +class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): + async def list_benchmarks(self) -> ListBenchmarksResponse: + return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) + + async def get_benchmark(self, benchmark_id: str) -> Benchmark: + benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) + if benchmark is None: + raise ValueError(f"Benchmark '{benchmark_id}' not found") + return benchmark + + async def register_benchmark( + self, + benchmark_id: str, + dataset_id: str, + scoring_functions: list[str], + metadata: dict[str, Any] | None = None, + provider_benchmark_id: str | None = None, + provider_id: str | None = None, + ) -> None: + if metadata is None: + metadata = {} + if provider_id is None: + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." 
+ ) + if provider_benchmark_id is None: + provider_benchmark_id = benchmark_id + benchmark = BenchmarkWithACL( + identifier=benchmark_id, + dataset_id=dataset_id, + scoring_functions=scoring_functions, + metadata=metadata, + provider_id=provider_id, + provider_resource_id=provider_benchmark_id, + ) + await self.register_object(benchmark) + + +class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): + async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: + tools = await self.get_all_with_type("tool") + if toolgroup_id: + tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] + return ListToolsResponse(data=tools) + + async def list_tool_groups(self) -> ListToolGroupsResponse: + return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) + + async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: + tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) + if tool_group is None: + raise ValueError(f"Tool group '{toolgroup_id}' not found") + return tool_group + + async def get_tool(self, tool_name: str) -> Tool: + return await self.get_object_by_identifier("tool", tool_name) + + async def register_tool_group( + self, + toolgroup_id: str, + provider_id: str, + mcp_endpoint: URL | None = None, + args: dict[str, Any] | None = None, + ) -> None: + tools = [] + tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) + tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + + for tool_def in tool_defs.data: + tools.append( + ToolWithACL( + identifier=tool_def.name, + toolgroup_id=toolgroup_id, + description=tool_def.description or "", + parameters=tool_def.parameters or [], + provider_id=provider_id, + provider_resource_id=tool_def.name, + metadata=tool_def.metadata, + tool_host=tool_host, + ) + ) + for tool in tools: + existing_tool = await self.get_tool(tool.identifier) + # Compare existing and new object if one exists + if existing_tool: + existing_dict = existing_tool.model_dump() + new_dict = tool.model_dump() + + if existing_dict != new_dict: + raise ValueError( + f"Object {tool.identifier} already exists in registry. Please use a different identifier." + ) + await self.register_object(tool) + + await self.dist_registry.register( + ToolGroupWithACL( + identifier=toolgroup_id, + provider_id=provider_id, + provider_resource_id=toolgroup_id, + mcp_endpoint=mcp_endpoint, + args=args, + ) + ) + + async def unregister_toolgroup(self, toolgroup_id: str) -> None: + tool_group = await self.get_tool_group(toolgroup_id) + if tool_group is None: + raise ValueError(f"Tool group {toolgroup_id} not found") + tools = await self.list_tools(toolgroup_id) + for tool in getattr(tools, "data", []): + await self.unregister_object(tool) + await self.unregister_object(tool_group) + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py deleted file mode 100644 index 9761d2db0..000000000 --- a/llama_stack/distribution/routers/safety.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from llama_stack.apis.inference import ( - Message, -) -from llama_stack.apis.safety import RunShieldResponse, Safety -from llama_stack.apis.shields import Shield -from llama_stack.log import get_logger -from llama_stack.providers.datatypes import RoutingTable - -logger = get_logger(name=__name__, category="core") - - -class SafetyRouter(Safety): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing SafetyRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("SafetyRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("SafetyRouter.shutdown") - pass - - async def register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) - - async def run_shield( - self, - shield_id: str, - messages: list[Message], - params: dict[str, Any] = None, - ) -> RunShieldResponse: - logger.debug(f"SafetyRouter.run_shield: {shield_id}") - return await self.routing_table.get_provider_impl(shield_id).run_shield( - shield_id=shield_id, - messages=messages, - params=params, - ) diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py deleted file mode 100644 index 285843dbc..000000000 --- a/llama_stack/distribution/routers/tool_runtime.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from llama_stack.apis.common.content_types import ( - URL, - InterleavedContent, -) -from llama_stack.apis.tools import ( - ListToolsResponse, - RAGDocument, - RAGQueryConfig, - RAGQueryResult, - RAGToolRuntime, - ToolRuntime, -) -from llama_stack.log import get_logger - -from ..routing_tables.toolgroups import ToolGroupsRoutingTable - -logger = get_logger(name=__name__, category="core") - - -class ToolRuntimeRouter(ToolRuntime): - class RagToolImpl(RAGToolRuntime): - def __init__( - self, - routing_table: ToolGroupsRoutingTable, - ) -> None: - logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") - self.routing_table = routing_table - - async def query( - self, - content: InterleavedContent, - vector_db_ids: list[str], - query_config: RAGQueryConfig | None = None, - ) -> RAGQueryResult: - logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") - return await self.routing_table.get_provider_impl("knowledge_search").query( - content, vector_db_ids, query_config - ) - - async def insert( - self, - documents: list[RAGDocument], - vector_db_id: str, - chunk_size_in_tokens: int = 512, - ) -> None: - logger.debug( - f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" - ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) - - def __init__( - self, - routing_table: ToolGroupsRoutingTable, - ) -> None: - logger.debug("Initializing ToolRuntimeRouter") - self.routing_table = routing_table - - # HACK ALERT this should be in sync with "get_all_api_endpoints()" - self.rag_tool = self.RagToolImpl(routing_table) - for method in ("query", "insert"): - setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) - - async def initialize(self) -> None: - logger.debug("ToolRuntimeRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("ToolRuntimeRouter.shutdown") - pass - - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: - logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") - return await self.routing_table.get_provider_impl(tool_name).invoke_tool( - tool_name=tool_name, - kwargs=kwargs, - ) - - async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None - ) -> ListToolsResponse: - logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.list_tools(tool_group_id) diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py deleted file mode 100644 index 8c17aa890..000000000 --- a/llama_stack/distribution/routers/vector_io.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from llama_stack.apis.common.content_types import ( - InterleavedContent, -) -from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO -from llama_stack.log import get_logger -from llama_stack.providers.datatypes import RoutingTable - -logger = get_logger(name=__name__, category="core") - - -class VectorIORouter(VectorIO): - """Routes to an provider based on the vector db identifier""" - - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing VectorIORouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("VectorIORouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("VectorIORouter.shutdown") - pass - - async def register_vector_db( - self, - vector_db_id: str, - embedding_model: str, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> None: - logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") - await self.routing_table.register_vector_db( - vector_db_id, - embedding_model, - embedding_dimension, - provider_id, - provider_vector_db_id, - ) - - async def insert_chunks( - self, - vector_db_id: str, - chunks: list[Chunk], - ttl_seconds: int | None = None, - ) -> None: - logger.debug( - f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", - ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) - - async def query_chunks( - self, - vector_db_id: str, - query: InterleavedContent, - params: dict[str, Any] | None = None, - ) -> QueryChunksResponse: - logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) diff --git a/llama_stack/distribution/routing_tables/__init__.py b/llama_stack/distribution/routing_tables/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/distribution/routing_tables/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/distribution/routing_tables/benchmarks.py b/llama_stack/distribution/routing_tables/benchmarks.py deleted file mode 100644 index 589a00c02..000000000 --- a/llama_stack/distribution/routing_tables/benchmarks.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse -from llama_stack.distribution.datatypes import ( - BenchmarkWithACL, -) -from llama_stack.log import get_logger - -from .common import CommonRoutingTableImpl - -logger = get_logger(name=__name__, category="core") - - -class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): - async def list_benchmarks(self) -> ListBenchmarksResponse: - return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) - - async def get_benchmark(self, benchmark_id: str) -> Benchmark: - benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) - if benchmark is None: - raise ValueError(f"Benchmark '{benchmark_id}' not found") - return benchmark - - async def register_benchmark( - self, - benchmark_id: str, - dataset_id: str, - scoring_functions: list[str], - metadata: dict[str, Any] | None = None, - provider_benchmark_id: str | None = None, - provider_id: str | None = None, - ) -> None: - if metadata is None: - metadata = {} - if provider_id is None: - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) - if provider_benchmark_id is None: - provider_benchmark_id = benchmark_id - benchmark = BenchmarkWithACL( - identifier=benchmark_id, - dataset_id=dataset_id, - scoring_functions=scoring_functions, - metadata=metadata, - provider_id=provider_id, - provider_resource_id=provider_benchmark_id, - ) - await self.register_object(benchmark) diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py deleted file mode 100644 index 8ec87ca50..000000000 --- a/llama_stack/distribution/routing_tables/common.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from llama_stack.apis.resource import ResourceType -from llama_stack.apis.scoring_functions import ScoringFn -from llama_stack.distribution.access_control import check_access -from llama_stack.distribution.datatypes import ( - AccessAttributes, - RoutableObject, - RoutableObjectWithProvider, - RoutedProtocol, -) -from llama_stack.distribution.request_headers import get_auth_attributes -from llama_stack.distribution.store import DistributionRegistry -from llama_stack.log import get_logger -from llama_stack.providers.datatypes import Api, RoutingTable - -logger = get_logger(name=__name__, category="core") - - -def get_impl_api(p: Any) -> Api: - return p.__provider_spec__.api - - -# TODO: this should return the registered object for all APIs -async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: - api = get_impl_api(p) - - assert obj.provider_id != "remote", "Remote provider should not be registered" - - if api == Api.inference: - return await p.register_model(obj) - elif api == Api.safety: - return await p.register_shield(obj) - elif api == Api.vector_io: - return await p.register_vector_db(obj) - elif api == Api.datasetio: - return await p.register_dataset(obj) - elif api == Api.scoring: - return await p.register_scoring_function(obj) - elif api == Api.eval: - return await p.register_benchmark(obj) - elif api == Api.tool_runtime: - return await p.register_toolgroup(obj) - else: - raise ValueError(f"Unknown API {api} for registering object with provider") - - -async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: - api = get_impl_api(p) - if api == Api.vector_io: - return await p.unregister_vector_db(obj.identifier) - elif api == Api.inference: - return await p.unregister_model(obj.identifier) - elif api == Api.datasetio: - return await p.unregister_dataset(obj.identifier) - elif api == Api.tool_runtime: - return await p.unregister_toolgroup(obj.identifier) - else: - raise ValueError(f"Unregister not supported for {api}") - - -Registry = dict[str, list[RoutableObjectWithProvider]] - - -class CommonRoutingTableImpl(RoutingTable): - def __init__( - self, - impls_by_provider_id: dict[str, RoutedProtocol], - dist_registry: DistributionRegistry, - ) -> None: - self.impls_by_provider_id = impls_by_provider_id - self.dist_registry = dist_registry - - async def initialize(self) -> None: - async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: - for obj in objs: - if cls is None: - obj.provider_id = provider_id - else: - # Create a copy of the model data and explicitly set provider_id - model_data = obj.model_dump() - model_data["provider_id"] = provider_id - obj = cls(**model_data) - await self.dist_registry.register(obj) - - # Register all objects from providers - for pid, p in self.impls_by_provider_id.items(): - api = get_impl_api(p) - if api == Api.inference: - p.model_store = self - elif api == Api.safety: - p.shield_store = self - elif api == Api.vector_io: - p.vector_db_store = self - elif api == Api.datasetio: - p.dataset_store = self - elif api == Api.scoring: - p.scoring_function_store = self - scoring_functions = await p.list_scoring_functions() - await add_objects(scoring_functions, pid, ScoringFn) - elif api == Api.eval: - p.benchmark_store = self - elif api == Api.tool_runtime: - p.tool_store = self - - async def shutdown(self) -> None: - for p in self.impls_by_provider_id.values(): - await p.shutdown() - - def get_provider_impl(self, routing_key: str, 
provider_id: str | None = None) -> Any: - from .benchmarks import BenchmarksRoutingTable - from .datasets import DatasetsRoutingTable - from .models import ModelsRoutingTable - from .scoring_functions import ScoringFunctionsRoutingTable - from .shields import ShieldsRoutingTable - from .toolgroups import ToolGroupsRoutingTable - from .vector_dbs import VectorDBsRoutingTable - - def apiname_object(): - if isinstance(self, ModelsRoutingTable): - return ("Inference", "model") - elif isinstance(self, ShieldsRoutingTable): - return ("Safety", "shield") - elif isinstance(self, VectorDBsRoutingTable): - return ("VectorIO", "vector_db") - elif isinstance(self, DatasetsRoutingTable): - return ("DatasetIO", "dataset") - elif isinstance(self, ScoringFunctionsRoutingTable): - return ("Scoring", "scoring_function") - elif isinstance(self, BenchmarksRoutingTable): - return ("Eval", "benchmark") - elif isinstance(self, ToolGroupsRoutingTable): - return ("ToolGroups", "tool_group") - else: - raise ValueError("Unknown routing table type") - - apiname, objtype = apiname_object() - - # Get objects from disk registry - obj = self.dist_registry.get_cached(objtype, routing_key) - if not obj: - provider_ids = list(self.impls_by_provider_id.keys()) - if len(provider_ids) > 1: - provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" - else: - provider_ids_str = f"provider: `{provider_ids[0]}`" - raise ValueError( - f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." - ) - - if not provider_id or provider_id == obj.provider_id: - return self.impls_by_provider_id[obj.provider_id] - - raise ValueError(f"Provider not found for `{routing_key}`") - - async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: - # Get from disk registry - obj = await self.dist_registry.get(type, identifier) - if not obj: - return None - - # Check if user has permission to access this object - if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): - logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") - return None - - return obj - - async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: - await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) - - async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: - # if provider_id is not specified, pick an arbitrary one from existing entries - if not obj.provider_id and len(self.impls_by_provider_id) > 0: - obj.provider_id = list(self.impls_by_provider_id.keys())[0] - - if obj.provider_id not in self.impls_by_provider_id: - raise ValueError(f"Provider `{obj.provider_id}` not found") - - p = self.impls_by_provider_id[obj.provider_id] - - # If object supports access control but no attributes set, use creator's attributes - if not obj.access_attributes: - creator_attributes = get_auth_attributes() - if creator_attributes: - obj.access_attributes = AccessAttributes(**creator_attributes) - logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") - - registered_obj = await register_object_with_provider(obj, p) - # TODO: This needs to be fixed for all APIs once they return the registered object - if obj.type == ResourceType.model.value: - await 
self.dist_registry.register(registered_obj) - return registered_obj - - else: - await self.dist_registry.register(obj) - return obj - - async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: - objs = await self.dist_registry.get_all() - filtered_objs = [obj for obj in objs if obj.type == type] - - # Apply attribute-based access control filtering - if filtered_objs: - filtered_objs = [ - obj - for obj in filtered_objs - if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) - ] - - return filtered_objs diff --git a/llama_stack/distribution/routing_tables/datasets.py b/llama_stack/distribution/routing_tables/datasets.py deleted file mode 100644 index 4401ad47e..000000000 --- a/llama_stack/distribution/routing_tables/datasets.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import uuid -from typing import Any - -from llama_stack.apis.datasets import ( - Dataset, - DatasetPurpose, - Datasets, - DatasetType, - DataSource, - ListDatasetsResponse, - RowsDataSource, - URIDataSource, -) -from llama_stack.apis.resource import ResourceType -from llama_stack.distribution.datatypes import ( - DatasetWithACL, -) -from llama_stack.log import get_logger - -from .common import CommonRoutingTableImpl - -logger = get_logger(name=__name__, category="core") - - -class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): - async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) - - async def get_dataset(self, dataset_id: str) -> Dataset: - dataset = await self.get_object_by_identifier("dataset", dataset_id) - if dataset is None: - raise ValueError(f"Dataset '{dataset_id}' not found") - return dataset - - async def register_dataset( - self, - purpose: DatasetPurpose, - source: DataSource, - metadata: dict[str, Any] | None = None, - dataset_id: str | None = None, - ) -> Dataset: - if isinstance(source, dict): - if source["type"] == "uri": - source = URIDataSource.parse_obj(source) - elif source["type"] == "rows": - source = RowsDataSource.parse_obj(source) - - if not dataset_id: - dataset_id = f"dataset-{str(uuid.uuid4())}" - - provider_dataset_id = dataset_id - - # infer provider from source - if metadata: - if metadata.get("provider_id"): - provider_id = metadata.get("provider_id") # pass through from nvidia datasetio - elif source.type == DatasetType.rows.value: - provider_id = "localfs" - elif source.type == DatasetType.uri.value: - # infer provider from uri - if source.uri.startswith("huggingface"): - provider_id = "huggingface" - else: - provider_id = "localfs" - else: - raise ValueError(f"Unknown data source type: {source.type}") - - if metadata is None: - metadata = {} - - dataset = DatasetWithACL( - identifier=dataset_id, - provider_resource_id=provider_dataset_id, - provider_id=provider_id, - purpose=purpose, - source=source, - metadata=metadata, - ) - - await self.register_object(dataset) - return dataset - - async def unregister_dataset(self, dataset_id: str) -> None: - dataset = await self.get_dataset(dataset_id) - if dataset is None: - raise ValueError(f"Dataset {dataset_id} not found") - await self.unregister_object(dataset) diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py deleted 
file mode 100644 index 7216d9935..000000000 --- a/llama_stack/distribution/routing_tables/models.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import time -from typing import Any - -from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel -from llama_stack.distribution.datatypes import ( - ModelWithACL, -) -from llama_stack.log import get_logger - -from .common import CommonRoutingTableImpl - -logger = get_logger(name=__name__, category="core") - - -class ModelsRoutingTable(CommonRoutingTableImpl, Models): - async def list_models(self) -> ListModelsResponse: - return ListModelsResponse(data=await self.get_all_with_type("model")) - - async def openai_list_models(self) -> OpenAIListModelsResponse: - models = await self.get_all_with_type("model") - openai_models = [ - OpenAIModel( - id=model.identifier, - object="model", - created=int(time.time()), - owned_by="llama_stack", - ) - for model in models - ] - return OpenAIListModelsResponse(data=openai_models) - - async def get_model(self, model_id: str) -> Model: - model = await self.get_object_by_identifier("model", model_id) - if model is None: - raise ValueError(f"Model '{model_id}' not found") - return model - - async def register_model( - self, - model_id: str, - provider_model_id: str | None = None, - provider_id: str | None = None, - metadata: dict[str, Any] | None = None, - model_type: ModelType | None = None, - ) -> Model: - if provider_model_id is None: - provider_model_id = model_id - if provider_id is None: - # If provider_id not specified, use the only provider if it supports this model - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" - ) - if metadata is None: - metadata = {} - if model_type is None: - model_type = ModelType.llm - if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError("Embedding model must have an embedding dimension in its metadata") - model = ModelWithACL( - identifier=model_id, - provider_resource_id=provider_model_id, - provider_id=provider_id, - metadata=metadata, - model_type=model_type, - ) - registered_model = await self.register_object(model) - return registered_model - - async def unregister_model(self, model_id: str) -> None: - existing_model = await self.get_model(model_id) - if existing_model is None: - raise ValueError(f"Model {model_id} not found") - await self.unregister_object(existing_model) diff --git a/llama_stack/distribution/routing_tables/scoring_functions.py b/llama_stack/distribution/routing_tables/scoring_functions.py deleted file mode 100644 index d85f64b57..000000000 --- a/llama_stack/distribution/routing_tables/scoring_functions.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.resource import ResourceType -from llama_stack.apis.scoring_functions import ( - ListScoringFunctionsResponse, - ScoringFn, - ScoringFnParams, - ScoringFunctions, -) -from llama_stack.distribution.datatypes import ( - ScoringFnWithACL, -) -from llama_stack.log import get_logger - -from .common import CommonRoutingTableImpl - -logger = get_logger(name=__name__, category="core") - - -class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): - async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) - - async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: - scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) - if scoring_fn is None: - raise ValueError(f"Scoring function '{scoring_fn_id}' not found") - return scoring_fn - - async def register_scoring_function( - self, - scoring_fn_id: str, - description: str, - return_type: ParamType, - provider_scoring_fn_id: str | None = None, - provider_id: str | None = None, - params: ScoringFnParams | None = None, - ) -> None: - if provider_scoring_fn_id is None: - provider_scoring_fn_id = scoring_fn_id - if provider_id is None: - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) - scoring_fn = ScoringFnWithACL( - identifier=scoring_fn_id, - description=description, - return_type=return_type, - provider_resource_id=provider_scoring_fn_id, - provider_id=provider_id, - params=params, - ) - scoring_fn.provider_id = provider_id - await self.register_object(scoring_fn) diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py deleted file mode 100644 index 7f62596c9..000000000 --- a/llama_stack/distribution/routing_tables/shields.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from llama_stack.apis.resource import ResourceType -from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields -from llama_stack.distribution.datatypes import ( - ShieldWithACL, -) -from llama_stack.log import get_logger - -from .common import CommonRoutingTableImpl - -logger = get_logger(name=__name__, category="core") - - -class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): - async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) - - async def get_shield(self, identifier: str) -> Shield: - shield = await self.get_object_by_identifier("shield", identifier) - if shield is None: - raise ValueError(f"Shield '{identifier}' not found") - return shield - - async def register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - if provider_shield_id is None: - provider_shield_id = shield_id - if provider_id is None: - # If provider_id not specified, use the only provider if it supports this shield type - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) - if params is None: - params = {} - shield = ShieldWithACL( - identifier=shield_id, - provider_resource_id=provider_shield_id, - provider_id=provider_id, - params=params, - ) - await self.register_object(shield) - return shield diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py deleted file mode 100644 index 2f7dc3e06..000000000 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from llama_stack.apis.common.content_types import URL -from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups -from llama_stack.distribution.datatypes import ToolGroupWithACL -from llama_stack.log import get_logger - -from .common import CommonRoutingTableImpl - -logger = get_logger(name=__name__, category="core") - - -def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: - # handle the funny case like "builtin::rag/knowledge_search" - parts = toolgroup_name_with_maybe_tool_name.split("/") - if len(parts) == 2: - return parts[0] - else: - return None - - -class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): - toolgroups_to_tools: dict[str, list[Tool]] = {} - tool_to_toolgroup: dict[str, str] = {} - - # overridden - def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: - # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id - # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while? 
- - toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key) - if toolgroup_id: - routing_key = toolgroup_id - - if routing_key in self.tool_to_toolgroup: - routing_key = self.tool_to_toolgroup[routing_key] - return super().get_provider_impl(routing_key, provider_id) - - async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: - if toolgroup_id: - if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id): - toolgroup_id = group_id - toolgroups = [await self.get_tool_group(toolgroup_id)] - else: - toolgroups = await self.get_all_with_type("tool_group") - - all_tools = [] - for toolgroup in toolgroups: - if toolgroup.identifier not in self.toolgroups_to_tools: - await self._index_tools(toolgroup) - all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier]) - - return ListToolsResponse(data=all_tools) - - async def _index_tools(self, toolgroup: ToolGroup): - provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) - tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint) - - # TODO: kill this Tool vs ToolDef distinction - tooldefs = tooldefs_response.data - tools = [] - for t in tooldefs: - tools.append( - Tool( - identifier=t.name, - toolgroup_id=toolgroup.identifier, - description=t.description or "", - parameters=t.parameters or [], - metadata=t.metadata, - provider_id=toolgroup.provider_id, - ) - ) - - self.toolgroups_to_tools[toolgroup.identifier] = tools - for tool in tools: - self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier - - async def list_tool_groups(self) -> ListToolGroupsResponse: - return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) - - async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: - tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) - if tool_group is None: - raise ValueError(f"Tool group '{toolgroup_id}' not found") - return tool_group - - async def get_tool(self, tool_name: str) -> Tool: - if tool_name in self.tool_to_toolgroup: - toolgroup_id = self.tool_to_toolgroup[tool_name] - tools = self.toolgroups_to_tools[toolgroup_id] - for tool in tools: - if tool.identifier == tool_name: - return tool - raise ValueError(f"Tool '{tool_name}' not found") - - async def register_tool_group( - self, - toolgroup_id: str, - provider_id: str, - mcp_endpoint: URL | None = None, - args: dict[str, Any] | None = None, - ) -> None: - toolgroup = ToolGroupWithACL( - identifier=toolgroup_id, - provider_id=provider_id, - provider_resource_id=toolgroup_id, - mcp_endpoint=mcp_endpoint, - args=args, - ) - await self.register_object(toolgroup) - - # ideally, indexing of the tools should not be necessary because anyone using - # the tools should first list the tools and then use them. but there are assumptions - # baked in some of the code and tests right now. 
- if not toolgroup.mcp_endpoint: - await self._index_tools(toolgroup) - return toolgroup - - async def unregister_toolgroup(self, toolgroup_id: str) -> None: - tool_group = await self.get_tool_group(toolgroup_id) - if tool_group is None: - raise ValueError(f"Tool group {toolgroup_id} not found") - await self.unregister_object(tool_group) - - async def shutdown(self) -> None: - pass diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py deleted file mode 100644 index dc6c0d0ef..000000000 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from pydantic import TypeAdapter - -from llama_stack.apis.models import ModelType -from llama_stack.apis.resource import ResourceType -from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs -from llama_stack.distribution.datatypes import ( - VectorDBWithACL, -) -from llama_stack.log import get_logger - -from .common import CommonRoutingTableImpl - -logger = get_logger(name=__name__, category="core") - - -class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): - async def list_vector_dbs(self) -> ListVectorDBsResponse: - return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) - - async def get_vector_db(self, vector_db_id: str) -> VectorDB: - vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) - if vector_db is None: - raise ValueError(f"Vector DB '{vector_db_id}' not found") - return vector_db - - async def register_vector_db( - self, - vector_db_id: str, - embedding_model: str, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorDB: - if provider_vector_db_id is None: - provider_vector_db_id = vector_db_id - if provider_id is None: - if len(self.impls_by_provider_id) > 0: - provider_id = list(self.impls_by_provider_id.keys())[0] - if len(self.impls_by_provider_id) > 1: - logger.warning( - f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." - ) - else: - raise ValueError("No provider available. 
Please configure a vector_io provider.") - model = await self.get_object_by_identifier("model", embedding_model) - if model is None: - raise ValueError(f"Model {embedding_model} not found") - if model.model_type != ModelType.embedding: - raise ValueError(f"Model {embedding_model} is not an embedding model") - if "embedding_dimension" not in model.metadata: - raise ValueError(f"Model {embedding_model} does not have an embedding dimension") - vector_db_data = { - "identifier": vector_db_id, - "type": ResourceType.vector_db.value, - "provider_id": provider_id, - "provider_resource_id": provider_vector_db_id, - "embedding_model": embedding_model, - "embedding_dimension": model.metadata["embedding_dimension"], - } - vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) - await self.register_object(vector_db) - return vector_db - - async def unregister_vector_db(self, vector_db_id: str) -> None: - existing_vector_db = await self.get_vector_db(vector_db_id) - if existing_vector_db is None: - raise ValueError(f"Vector DB {vector_db_id} not found") - await self.unregister_object(existing_vector_db) diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index fb26b49a7..429232ece 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -8,8 +8,7 @@ import json import httpx -from llama_stack.distribution.datatypes import AuthenticationConfig -from llama_stack.distribution.server.auth_providers import create_auth_provider +from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -78,7 +77,7 @@ class AuthenticationMiddleware: access resources that don't have access_attributes defined. """ - def __init__(self, app, auth_config: AuthenticationConfig): + def __init__(self, app, auth_config: AuthProviderConfig): self.app = app self.auth_provider = create_auth_provider(auth_config) @@ -94,7 +93,7 @@ class AuthenticationMiddleware: # Validate token and get access attributes try: - validation_result = await self.auth_provider.validate_token(token, scope) + access_attributes = await self.auth_provider.validate_token(token, scope) except httpx.TimeoutException: logger.exception("Authentication request timed out") return await self._send_auth_error(send, "Authentication service timeout") @@ -106,24 +105,17 @@ class AuthenticationMiddleware: return await self._send_auth_error(send, "Authentication service error") # Store attributes in request scope for access control - if validation_result.access_attributes: - user_attributes = validation_result.access_attributes.model_dump(exclude_none=True) + if access_attributes: + user_attributes = access_attributes.model_dump(exclude_none=True) else: logger.warning("No access attributes, setting namespace to token by default") user_attributes = { - "roles": [token], + "namespaces": [token], } - # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) - # can identify the requester and enforce per-client rate limits. 
- scope["authenticated_client_id"] = token - # Store attributes in request scope scope["user_attributes"] = user_attributes - scope["principal"] = validation_result.principal - logger.debug( - f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes" - ) + logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes") return await self.app(scope, receive, send) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 723a65b77..1b19f8923 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -4,29 +4,23 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import ssl -import time +import json from abc import ABC, abstractmethod -from asyncio import Lock -from pathlib import Path +from enum import Enum from urllib.parse import parse_qs import httpx -from jose import jwt -from pydantic import BaseModel, Field, field_validator, model_validator -from typing_extensions import Self +from pydantic import BaseModel, Field -from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType +from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") -class TokenValidationResult(BaseModel): - principal: str | None = Field( - default=None, - description="The principal (username or persistent identifier) of the authenticated user", - ) +class AuthResponse(BaseModel): + """The format of the authentication response from the auth endpoint.""" + access_attributes: AccessAttributes | None = Field( default=None, description=""" @@ -49,10 +43,6 @@ class TokenValidationResult(BaseModel): """, ) - -class AuthResponse(TokenValidationResult): - """The format of the authentication response from the auth endpoint.""" - message: str | None = Field( default=None, description="Optional message providing additional context about the authentication result." 
) @@ -74,11 +64,25 @@ class AuthRequest(BaseModel): request: AuthRequestContext = Field(description="Context information about the request being authenticated") +class AuthProviderType(str, Enum): + """Supported authentication provider types.""" + + KUBERNETES = "kubernetes" + CUSTOM = "custom" + + +class AuthProviderConfig(BaseModel): + """Base configuration for authentication providers.""" + + provider_type: AuthProviderType = Field(..., description="Type of authentication provider") + config: dict[str, str] = Field(..., description="Provider-specific configuration") + + class AuthProvider(ABC): """Abstract base class for authentication providers.""" @abstractmethod - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: """Validate a token and return access attributes.""" pass @@ -88,219 +92,88 @@ class AuthProvider(ABC): pass -def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes: - attributes = AccessAttributes() - for claim_key, attribute_key in mapping.items(): - if claim_key not in claims or not hasattr(attributes, attribute_key): - continue - claim = claims[claim_key] - if isinstance(claim, list): - values = claim - else: - values = claim.split() +class KubernetesAuthProvider(AuthProvider): + """Kubernetes authentication provider that validates tokens against the Kubernetes API server.""" - current = getattr(attributes, attribute_key) - if current: - current.extend(values) - else: - setattr(attributes, attribute_key, values) - return attributes + def __init__(self, config: dict[str, str]): + self.api_server_url = config["api_server_url"] + self.ca_cert_path = config.get("ca_cert_path") + self._client = None + async def _get_client(self): + """Get or create a Kubernetes client.""" + if self._client is None: + # kubernetes-client has not async support, see: + # https://github.com/kubernetes-client/python/issues/323 + from kubernetes import client + from kubernetes.client import ApiClient -class OAuth2JWKSConfig(BaseModel): - # The JWKS URI for collecting public keys - uri: str - key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates") + # Configure the client + configuration = client.Configuration() + configuration.host = self.api_server_url + if self.ca_cert_path: + configuration.ssl_ca_cert = self.ca_cert_path + configuration.verify_ssl = bool(self.ca_cert_path) + # Create API client + self._client = ApiClient(configuration) + return self._client -class OAuth2IntrospectionConfig(BaseModel): - url: str - client_id: str - client_secret: str - send_secret_in_body: bool = False - - -class OAuth2TokenAuthProviderConfig(BaseModel): - audience: str = "llama-stack" - verify_tls: bool = True - tls_cafile: Path | None = None - issuer: str | None = Field(default=None, description="The OIDC issuer URL.") - claims_mapping: dict[str, str] = Field( - default_factory=lambda: { - "sub": "roles", - "username": "roles", - "groups": "teams", - "team": "teams", - "project": "projects", - "tenant": "namespaces", - "namespace": "namespaces", - }, - ) - jwks: OAuth2JWKSConfig | None - introspection: OAuth2IntrospectionConfig | None = None - - @classmethod - @field_validator("claims_mapping") - def validate_claims_mapping(cls, v): - for key, value in v.items(): - if not value: - raise ValueError(f"claims_mapping value cannot be empty: {key}") - if value not in 
AccessAttributes.model_fields: - raise ValueError(f"claims_mapping value is not a valid attribute: {value}") - return v - - @model_validator(mode="after") - def validate_mode(self) -> Self: - if not self.jwks and not self.introspection: - raise ValueError("One of jwks or introspection must be configured") - if self.jwks and self.introspection: - raise ValueError("At present only one of jwks or introspection should be configured") - return self - - -class OAuth2TokenAuthProvider(AuthProvider): - """ - JWT token authentication provider that validates a JWT token and extracts access attributes. - - This should be the standard authentication provider for most use cases. - """ - - def __init__(self, config: OAuth2TokenAuthProviderConfig): - self.config = config - self._jwks_at: float = 0.0 - self._jwks: dict[str, str] = {} - self._jwks_lock = Lock() - - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: - if self.config.jwks: - return await self.validate_jwt_token(token, scope) - if self.config.introspection: - return await self.introspect_token(token, scope) - raise ValueError("One of jwks or introspection must be configured") - - async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: - """Validate a token using the JWT token.""" - await self._refresh_jwks() - + async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: + """Validate a Kubernetes token and return access attributes.""" try: - header = jwt.get_unverified_header(token) - kid = header["kid"] - if kid not in self._jwks: - raise ValueError(f"Unknown key ID: {kid}") - key_data = self._jwks[kid] - algorithm = header.get("alg", "RS256") - claims = jwt.decode( - token, - key_data, - algorithms=[algorithm], - audience=self.config.audience, - issuer=self.config.issuer, + client = await self._get_client() + + # Set the token in the client + client.set_default_header("Authorization", f"Bearer {token}") + + # Make a request to validate the token + # We use the /api endpoint which requires authentication + from kubernetes.client import CoreV1Api + + api = CoreV1Api(client) + api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request + + # If we get here, the token is valid + # Extract user info from the token claims + import base64 + + # Decode the token (without verification since we've already validated it) + token_parts = token.split(".") + payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4))) + + # Extract user information from the token + username = payload.get("sub", "") + groups = payload.get("groups", []) + + return AccessAttributes( + roles=[username], # Use username as a role + teams=groups, # Use Kubernetes groups as teams ) - except Exception as exc: - raise ValueError(f"Invalid JWT token: {token}") from exc - # There are other standard claims, the most relevant of which is `scope`. - # We should incorporate these into the access attributes. 
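A small self-contained sketch of the padding fix-up that the restored KubernetesAuthProvider.validate_token uses when it decodes the token payload. The claims in the closing comment are illustrative only; note that JWT segments are base64url-encoded, so urlsafe_b64decode is the safer decoder for a standalone helper, while the provider above calls base64.b64decode with the same padding arithmetic.

import base64
import json


def decode_jwt_payload(token: str) -> dict:
    # A JWT is header.payload.signature; the payload is the middle segment.
    payload_b64 = token.split(".")[1]
    # JWT segments drop base64 padding, so restore it before decoding.
    padded = payload_b64 + "=" * (-len(payload_b64) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))


# With a token whose payload is {"sub": "system:serviceaccount:default:demo", "groups": ["dev"]},
# decode_jwt_payload(token) yields the dict that the provider above maps onto
# AccessAttributes(roles=[username], teams=groups).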
- principal = claims["sub"] - access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping) - return TokenValidationResult( - principal=principal, - access_attributes=access_attributes, - ) - - async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: - """Validate a token using token introspection as defined by RFC 7662.""" - form = { - "token": token, - } - if self.config.introspection is None: - raise ValueError("Introspection is not configured") - - if self.config.introspection.send_secret_in_body: - form["client_id"] = self.config.introspection.client_id - form["client_secret"] = self.config.introspection.client_secret - auth = None - else: - auth = (self.config.introspection.client_id, self.config.introspection.client_secret) - ssl_ctxt = None - if self.config.tls_cafile: - ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix()) - try: - async with httpx.AsyncClient(verify=ssl_ctxt) as client: - response = await client.post( - self.config.introspection.url, - data=form, - auth=auth, - timeout=10.0, # Add a reasonable timeout - ) - if response.status_code != 200: - logger.warning(f"Token introspection failed with status code: {response.status_code}") - raise ValueError(f"Token introspection failed: {response.status_code}") - - fields = response.json() - if not fields["active"]: - raise ValueError("Token not active") - principal = fields["sub"] or fields["username"] - access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping) - return TokenValidationResult( - principal=principal, - access_attributes=access_attributes, - ) - except httpx.TimeoutException: - logger.exception("Token introspection request timed out") - raise - except ValueError: - # Re-raise ValueError exceptions to preserve their message - raise except Exception as e: - logger.exception("Error during token introspection") - raise ValueError("Token introspection error") from e + logger.exception("Failed to validate Kubernetes token") + raise ValueError("Invalid or expired token") from e async def close(self): - pass - - async def _refresh_jwks(self) -> None: - """ - Refresh the JWKS cache. - - This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`). - If the cache is expired, we refresh the JWKS from the JWKS URI. 
- - Notes: for Kubernetes which doesn't fully implement the OIDC protocol: - * It doesn't have user authentication flows - * It doesn't have refresh tokens - """ - async with self._jwks_lock: - if self.config.jwks is None: - raise ValueError("JWKS is not configured") - if time.time() - self._jwks_at > self.config.jwks.key_recheck_period: - verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls - async with httpx.AsyncClient(verify=verify) as client: - res = await client.get(self.config.jwks.uri, timeout=5) - res.raise_for_status() - jwks_data = res.json()["keys"] - updated = {} - for k in jwks_data: - kid = k["kid"] - # Store the entire key object as it may be needed for different algorithms - updated[kid] = k - self._jwks = updated - self._jwks_at = time.time() - - -class CustomAuthProviderConfig(BaseModel): - endpoint: str + """Close the HTTP client.""" + if self._client: + self._client.close() + self._client = None class CustomAuthProvider(AuthProvider): """Custom authentication provider that uses an external endpoint.""" - def __init__(self, config: CustomAuthProviderConfig): - self.config = config + def __init__(self, config: dict[str, str]): + self.endpoint = config["endpoint"] self._client = None - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: """Validate a token using the custom authentication endpoint.""" + if not self.endpoint: + raise ValueError("Authentication endpoint not configured") + if scope is None: scope = {} @@ -329,7 +202,7 @@ class CustomAuthProvider(AuthProvider): try: async with httpx.AsyncClient() as client: response = await client.post( - self.config.endpoint, + self.endpoint, json=auth_request.model_dump(), timeout=10.0, # Add a reasonable timeout ) @@ -341,7 +214,19 @@ class CustomAuthProvider(AuthProvider): try: response_data = response.json() auth_response = AuthResponse(**response_data) - return auth_response + + # Store attributes in request scope for access control + if auth_response.access_attributes: + return auth_response.access_attributes + else: + logger.warning("No access attributes, setting namespace to api_key by default") + user_attributes = { + "namespaces": [token], + } + + scope["user_attributes"] = user_attributes + logger.debug(f"Authentication successful: {len(user_attributes)} attributes") + return auth_response.access_attributes except Exception as e: logger.exception("Error parsing authentication response") raise ValueError("Invalid authentication response format") from e @@ -363,14 +248,14 @@ class CustomAuthProvider(AuthProvider): self._client = None -def create_auth_provider(config: AuthenticationConfig) -> AuthProvider: +def create_auth_provider(config: AuthProviderConfig) -> AuthProvider: """Factory function to create the appropriate auth provider.""" provider_type = config.provider_type.lower() - if provider_type == "custom": - return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config)) - elif provider_type == "oauth2_token": - return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config)) + if provider_type == "kubernetes": + return KubernetesAuthProvider(config.config) + elif provider_type == "custom": + return CustomAuthProvider(config.config) else: supported_providers = ", ".join([t.value for t in AuthProviderType]) raise ValueError(f"Unsupported auth provider type: {provider_type}. 
Supported types are: {supported_providers}") diff --git a/llama_stack/distribution/server/routes.py b/llama_stack/distribution/server/endpoints.py similarity index 55% rename from llama_stack/distribution/server/routes.py rename to llama_stack/distribution/server/endpoints.py index ea66fec5a..ec1f7e083 100644 --- a/llama_stack/distribution/server/routes.py +++ b/llama_stack/distribution/server/endpoints.py @@ -6,23 +6,20 @@ import inspect import re -from collections.abc import Callable -from typing import Any -from aiohttp import hdrs -from starlette.routing import Route +from pydantic import BaseModel from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.distribution.resolver import api_protocol_map from llama_stack.providers.datatypes import Api -EndpointFunc = Callable[..., Any] -PathParams = dict[str, str] -RouteInfo = tuple[EndpointFunc, str] -PathImpl = dict[str, RouteInfo] -RouteImpls = dict[str, PathImpl] -RouteMatch = tuple[EndpointFunc, PathParams, str] + +class ApiEndpoint(BaseModel): + route: str + method: str + name: str + descriptive_name: str | None = None def toolgroup_protocol_map(): @@ -31,13 +28,13 @@ def toolgroup_protocol_map(): } -def get_all_api_routes() -> dict[Api, list[Route]]: +def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]: apis = {} protocols = api_protocol_map() toolgroup_protocols = toolgroup_protocol_map() for api, protocol in protocols.items(): - routes = [] + endpoints = [] protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) # HACK ALERT @@ -54,28 +51,26 @@ def get_all_api_routes() -> dict[Api, list[Route]]: if not hasattr(method, "__webmethod__"): continue - # The __webmethod__ attribute is dynamically added by the @webmethod decorator - # mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error - webmethod = method.__webmethod__ # type: ignore[attr-defined] - path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}" - if webmethod.method == hdrs.METH_GET: - http_method = hdrs.METH_GET - elif webmethod.method == hdrs.METH_DELETE: - http_method = hdrs.METH_DELETE + webmethod = method.__webmethod__ + route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}" + if webmethod.method == "GET": + method = "get" + elif webmethod.method == "DELETE": + method = "delete" else: - http_method = hdrs.METH_POST - routes.append( - Route(path=path, methods=[http_method], name=name, endpoint=None) - ) # setting endpoint to None since don't use a Router object + method = "post" + endpoints.append( + ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name) + ) - apis[api] = routes + apis[api] = endpoints return apis -def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: - routes = get_all_api_routes() - route_impls: RouteImpls = {} +def initialize_endpoint_impls(impls): + endpoints = get_all_api_endpoints() + endpoint_impls = {} def _convert_path_to_regex(path: str) -> str: # Convert {param} to named capture groups @@ -88,34 +83,29 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: return f"^{pattern}$" - for api, api_routes in routes.items(): + for api, api_endpoints in endpoints.items(): if api not in impls: continue - for route in api_routes: + for endpoint in api_endpoints: impl = impls[api] - func = getattr(impl, route.name) - # Get the first (and typically only) method from the set, filtering out HEAD - available_methods = [m for 
m in route.methods if m != "HEAD"] - if not available_methods: - continue # Skip if only HEAD method is available - method = available_methods[0].lower() - if method not in route_impls: - route_impls[method] = {} - route_impls[method][_convert_path_to_regex(route.path)] = ( + func = getattr(impl, endpoint.name) + if endpoint.method not in endpoint_impls: + endpoint_impls[endpoint.method] = {} + endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = ( func, - route.path, + endpoint.descriptive_name or endpoint.route, ) - return route_impls + return endpoint_impls -def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch: +def find_matching_endpoint(method, path, endpoint_impls): """Find the matching endpoint implementation for a given method and path. Args: method: HTTP method (GET, POST, etc.) path: URL path to match against - route_impls: A dictionary of endpoint implementations + endpoint_impls: A dictionary of endpoint implementations Returns: A tuple of (endpoint_function, path_params, descriptive_name) @@ -123,7 +113,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout Raises: ValueError: If no matching endpoint is found """ - impls = route_impls.get(method.lower()) + impls = endpoint_impls.get(method.lower()) if not impls: raise ValueError(f"No endpoint found for {path}") diff --git a/llama_stack/distribution/server/quota.py b/llama_stack/distribution/server/quota.py deleted file mode 100644 index ddbffae64..000000000 --- a/llama_stack/distribution/server/quota.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import time -from datetime import datetime, timedelta, timezone - -from starlette.types import ASGIApp, Receive, Scope, Send - -from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore.api import KVStore -from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl - -logger = get_logger(name=__name__, category="quota") - - -class QuotaMiddleware: - """ - ASGI middleware that enforces separate quotas for authenticated and anonymous clients - within a configurable time window. - - - For authenticated requests, it reads the client ID from the - `Authorization: Bearer ` header. - - For anonymous requests, it falls back to the IP address of the client. - Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned - once a client exceeds its quota. - """ - - def __init__( - self, - app: ASGIApp, - kv_config: KVStoreConfig, - anonymous_max_requests: int, - authenticated_max_requests: int, - window_seconds: int = 86400, - ): - self.app = app - self.kv_config = kv_config - self.kv: KVStore | None = None - self.anonymous_max_requests = anonymous_max_requests - self.authenticated_max_requests = authenticated_max_requests - self.window_seconds = window_seconds - - if isinstance(self.kv_config, SqliteKVStoreConfig): - logger.warning( - "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. 
" - f"window_seconds={self.window_seconds}" - ) - - async def _get_kv(self) -> KVStore: - if self.kv is None: - self.kv = await kvstore_impl(self.kv_config) - return self.kv - - async def __call__(self, scope: Scope, receive: Receive, send: Send): - if scope["type"] == "http": - # pick key & limit based on auth - auth_id = scope.get("authenticated_client_id") - if auth_id: - key_id = auth_id - limit = self.authenticated_max_requests - else: - # fallback to IP - client = scope.get("client") - key_id = client[0] if client else "anonymous" - limit = self.anonymous_max_requests - - current_window = int(time.time() // self.window_seconds) - key = f"quota:{key_id}:{current_window}" - - try: - kv = await self._get_kv() - prev = await kv.get(key) or "0" - count = int(prev) + 1 - - if int(prev) == 0: - # Set with expiration datetime when it is the first request in the window. - expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds) - await kv.set(key, str(count), expiration=expiration) - else: - await kv.set(key, str(count)) - except Exception: - logger.exception("Failed to access KV store for quota") - return await self._send_error(send, 500, "Quota service error") - - if count > limit: - logger.warning( - "Quota exceeded for client %s: %d/%d", - key_id, - count, - limit, - ) - return await self._send_error(send, 429, "Quota exceeded") - - return await self.app(scope, receive, send) - - async def _send_error(self, send: Send, status: int, message: str): - await send( - { - "type": "http.response.start", - "status": status, - "headers": [[b"content-type", b"application/json"]], - } - ) - body = json.dumps({"error": {"message": message}}).encode() - await send({"type": "http.response.body", "body": body}) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 6c88bbfe9..ff0775dd6 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -6,7 +6,6 @@ import argparse import asyncio -import functools import inspect import json import os @@ -14,7 +13,6 @@ import ssl import sys import traceback import warnings -from collections.abc import Callable from contextlib import asynccontextmanager from importlib.metadata import version as parse_version from pathlib import Path @@ -22,26 +20,23 @@ from typing import Annotated, Any import rich.pretty import yaml -from aiohttp import hdrs from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError -from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError -from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig +from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import ( PROVIDER_DATA_VAR, request_provider_data_context, ) from llama_stack.distribution.resolver import InvalidProviderError -from llama_stack.distribution.server.routes import ( - find_matching_route, - get_all_api_routes, - initialize_route_impls, +from llama_stack.distribution.server.endpoints import ( + find_matching_endpoint, + initialize_endpoint_impls, ) from llama_stack.distribution.stack import ( construct_stack, @@ -64,7 +59,7 @@ from 
llama_stack.providers.utils.telemetry.tracing import ( ) from .auth import AuthenticationMiddleware -from .quota import QuotaMiddleware +from .endpoints import get_all_api_endpoints REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -125,8 +120,6 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") elif isinstance(exc, NotImplementedError): return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}") - elif isinstance(exc, AuthenticationRequiredError): - return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}") else: return HTTPException( status_code=500, @@ -212,9 +205,8 @@ async def log_request_pre_validation(request: Request): logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}") -def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: - @functools.wraps(func) - async def route_handler(request: Request, **kwargs): +def create_dynamic_typed_route(func: Any, method: str, route: str): + async def endpoint(request: Request, **kwargs): # Get auth attributes from the request scope user_attributes = request.scope.get("user_attributes", {}) @@ -254,9 +246,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: for param in new_params[1:] ] - route_handler.__signature__ = sig.replace(parameters=new_params) + endpoint.__signature__ = sig.replace(parameters=new_params) - return route_handler + return endpoint class TracingMiddleware: @@ -278,28 +270,17 @@ class TracingMiddleware: logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") return await self.app(scope, receive, send) - if not hasattr(self, "route_impls"): - self.route_impls = initialize_route_impls(self.impls) + if not hasattr(self, "endpoint_impls"): + self.endpoint_impls = initialize_endpoint_impls(self.impls) try: - _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls) + _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls) except ValueError: # If no matching endpoint is found, pass through to FastAPI - logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") + logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI") return await self.app(scope, receive, send) - trace_attributes = {"__location__": "server", "raw_path": path} - - # Extract W3C trace context headers and store as trace attributes - headers = dict(scope.get("headers", [])) - traceparent = headers.get(b"traceparent", b"").decode() - if traceparent: - trace_attributes["traceparent"] = traceparent - tracestate = headers.get(b"tracestate", b"").decode() - if tracestate: - trace_attributes["tracestate"] = tracestate - - trace_context = await start_trace(trace_path, trace_attributes) + trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path}) async def send_with_trace_id(message): if message["type"] == "http.response.start": @@ -389,6 +370,14 @@ def main(args: argparse.Namespace | None = None): if args is None: args = parser.parse_args() + # Check for deprecated argument usage + if "--yaml-config" in sys.argv: + warnings.warn( + "The '--yaml-config' argument is deprecated and will be removed in a future version. 
Use '--config' instead.", + DeprecationWarning, + stacklevel=2, + ) + log_line = "" if args.config: # if the user provided a config file, use it, even if template was specified @@ -402,7 +391,7 @@ def main(args: argparse.Namespace | None = None): raise ValueError(f"Template {args.template} does not exist") log_line = f"Using template {args.template} config file: {config_file}" else: - raise ValueError("Either --config or --template must be provided") + raise ValueError("Either --yaml-config or --template must be provided") logger_config = None with open(config_file) as fp: @@ -442,46 +431,6 @@ def main(args: argparse.Namespace | None = None): if config.server.auth: logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}") app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth) - else: - if config.server.quota: - quota = config.server.quota - logger.warning( - "Configured authenticated_max_requests (%d) but no auth is enabled; " - "falling back to anonymous_max_requests (%d) for all the requests", - quota.authenticated_max_requests, - quota.anonymous_max_requests, - ) - - if config.server.quota: - logger.info("Enabling quota middleware for authenticated and anonymous clients") - - quota = config.server.quota - anonymous_max_requests = quota.anonymous_max_requests - # if auth is disabled, use the anonymous max requests - authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests - - kv_config = quota.kvstore - window_map = {"day": 86400} - window_seconds = window_map[quota.period.value] - - app.add_middleware( - QuotaMiddleware, - kv_config=kv_config, - anonymous_max_requests=anonymous_max_requests, - authenticated_max_requests=authenticated_max_requests, - window_seconds=window_seconds, - ) - - # --- CORS middleware for local development --- - # TODO: move to reverse proxy - ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322) - app.add_middleware( - CORSMiddleware, - allow_origins=[f"http://localhost:{ui_port}"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) try: impls = asyncio.run(construct_stack(config)) @@ -494,7 +443,7 @@ def main(args: argparse.Namespace | None = None): else: setup_logger(TelemetryAdapter(TelemetryConfig(), {})) - all_routes = get_all_api_routes() + all_endpoints = get_all_api_endpoints() if config.apis: apis_to_serve = set(config.apis) @@ -512,29 +461,24 @@ def main(args: argparse.Namespace | None = None): for api_str in apis_to_serve: api = Api(api_str) - routes = all_routes[api] + endpoints = all_endpoints[api] impl = impls[api] - for route in routes: - if not hasattr(impl, route.name): + for endpoint in endpoints: + if not hasattr(impl, endpoint.name): # ideally this should be a typing violation already - raise ValueError(f"Could not find method {route.name} on {impl}!") + raise ValueError(f"Could not find method {endpoint.name} on {impl}!!") - impl_method = getattr(impl, route.name) - # Filter out HEAD method since it's automatically handled by FastAPI for GET routes - available_methods = [m for m in route.methods if m != "HEAD"] - if not available_methods: - raise ValueError(f"No methods found for {route.name} on {impl}") - method = available_methods[0] - logger.debug(f"{method} {route.path}") + impl_method = getattr(impl, endpoint.name) + logger.debug(f"{endpoint.method.upper()} {endpoint.route}") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") 
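For orientation, a minimal sketch of the registration pattern in the hunk just below, with a hypothetical route and handler: looking up the FastAPI decorator by HTTP method name and applying it to the dynamically built handler has the same effect as writing the decorator by hand.

from fastapi import FastAPI

app = FastAPI()


async def chat_completion_handler() -> dict:
    # Hypothetical stand-in for the handler produced by create_dynamic_typed_route().
    return {"ok": True}


method = "post"                          # endpoint.method
route = "/v1/inference/chat-completion"  # hypothetical endpoint.route

# getattr(app, "post") is app.post; calling it builds the decorator, and applying the
# decorator to the handler registers the route -- equivalent to @app.post(route).
getattr(app, method)(route, response_model=None)(chat_completion_handler)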
- getattr(app, method.lower())(route.path, response_model=None)( + getattr(app, endpoint.method)(endpoint.route, response_model=None)( create_dynamic_typed_route( impl_method, - method.lower(), - route.path, + endpoint.method, + endpoint.route, ) ) diff --git a/llama_stack/distribution/start_stack.sh b/llama_stack/distribution/start_stack.sh index 996935a5e..bf49e1619 100755 --- a/llama_stack/distribution/start_stack.sh +++ b/llama_stack/distribution/start_stack.sh @@ -54,7 +54,7 @@ other_args="" # Process remaining arguments while [[ $# -gt 0 ]]; do case "$1" in - --config) + --config|--yaml-config) if [[ -n "$2" ]]; then yaml_config="$2" shift 2 @@ -121,7 +121,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then set -x if [ -n "$yaml_config" ]; then - yaml_config_arg="--config $yaml_config" + yaml_config_arg="--yaml-config $yaml_config" else yaml_config_arg="" fi @@ -181,9 +181,9 @@ elif [[ "$env_type" == "container" ]]; then # Add yaml config if provided, otherwise use default if [ -n "$yaml_config" ]; then - cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml" + cmd="$cmd -v $yaml_config:/app/run.yaml --yaml-config /app/run.yaml" else - cmd="$cmd --config /app/run.yaml" + cmd="$cmd --yaml-config /app/run.yaml" fi # Add any other args diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 0e84854c2..a6b400136 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -36,7 +36,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v9" +KEY_VERSION = "v8" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/distribution/ui/Containerfile b/llama_stack/distribution/ui/Containerfile index 5d2dc933b..0126d1867 100644 --- a/llama_stack/distribution/ui/Containerfile +++ b/llama_stack/distribution/ui/Containerfile @@ -5,8 +5,7 @@ FROM python:3.12-slim WORKDIR /app COPY . 
/app/ RUN /usr/local/bin/python -m pip install --upgrade pip && \ - /usr/local/bin/pip3 install -r requirements.txt && \ - /usr/local/bin/pip3 install -r llama_stack/distribution/ui/requirements.txt + /usr/local/bin/pip3 install -r requirements.txt EXPOSE 8501 -ENTRYPOINT ["streamlit", "run", "llama_stack/distribution/ui/app.py", "--server.port=8501", "--server.address=0.0.0.0"] +ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"] diff --git a/llama_stack/distribution/ui/README.md b/llama_stack/distribution/ui/README.md index 0e96690ec..51c2d2bc2 100644 --- a/llama_stack/distribution/ui/README.md +++ b/llama_stack/distribution/ui/README.md @@ -48,6 +48,3 @@ uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py | TOGETHER_API_KEY | API key for Together provider | (empty string) | | SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) | | OPENAI_API_KEY | API key for OpenAI provider | (empty string) | -| KEYCLOAK_URL | URL for keycloak authentication | (empty string) | -| KEYCLOAK_REALM | Keycloak realm | default | -| KEYCLOAK_CLIENT_ID | Client ID for keycloak auth | (empty string) | \ No newline at end of file diff --git a/llama_stack/distribution/ui/app.py b/llama_stack/distribution/ui/app.py index a9a28b445..441f65d20 100644 --- a/llama_stack/distribution/ui/app.py +++ b/llama_stack/distribution/ui/app.py @@ -50,42 +50,6 @@ def main(): ) pg.run() -def main2(): - from dataclasses import asdict - st.subheader(f"Welcome {keycloak.user_info['preferred_username']}!") - st.write(f"Here is your user information:") - st.write(asdict(keycloak)) - -def get_access_token() -> str|None: - return st.session_state.get('access_token') if __name__ == "__main__": - - from streamlit_keycloak import login - import os - - keycloak_url = os.environ.get("KEYCLOAK_URL") - keycloak_realm = os.environ.get("KEYCLOAK_REALM", "default") - keycloak_client_id = os.environ.get("KEYCLOAK_CLIENT_ID") - - if keycloak_url and keycloak_client_id: - keycloak = login( - url=keycloak_url, - realm=keycloak_realm, - client_id=keycloak_client_id, - custom_labels={ - "labelButton": "Sign in to kvant", - "labelLogin": "Please sign in to your kvant account.", - "errorNoPopup": "Unable to open the authentication popup. Allow popups and refresh the page to proceed.", - "errorPopupClosed": "Authentication popup was closed manually.", - "errorFatal": "Unable to connect to Keycloak using the current configuration." 
- }, - auto_refresh=True, - ) - - if keycloak.authenticated: - st.session_state['access_token'] = keycloak.access_token - main() - # TBD - add other authentications - else: - main() + main() diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/distribution/ui/modules/api.py index a426e59ba..11455ed46 100644 --- a/llama_stack/distribution/ui/modules/api.py +++ b/llama_stack/distribution/ui/modules/api.py @@ -7,13 +7,11 @@ import os from llama_stack_client import LlamaStackClient -from llama_stack.distribution.ui.app import get_access_token class LlamaStackApi: def __init__(self): self.client = LlamaStackClient( - api_key=get_access_token(), base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), provider_data={ "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""), @@ -30,3 +28,5 @@ class LlamaStackApi: scoring_params = {fn_id: None for fn_id in scoring_function_ids} return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params) + +llama_stack_api = LlamaStackApi() diff --git a/llama_stack/distribution/ui/page/distribution/datasets.py b/llama_stack/distribution/ui/page/distribution/datasets.py index 89f645ca8..6842b29a7 100644 --- a/llama_stack/distribution/ui/page/distribution/datasets.py +++ b/llama_stack/distribution/ui/page/distribution/datasets.py @@ -6,13 +6,13 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api def datasets(): st.header("Datasets") - datasets_info = {d.identifier: d.to_dict() for d in LlamaStackApi().client.datasets.list()} + datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()} if len(datasets_info) > 0: selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys())) st.json(datasets_info[selected_dataset], expanded=True) diff --git a/llama_stack/distribution/ui/page/distribution/eval_tasks.py b/llama_stack/distribution/ui/page/distribution/eval_tasks.py index 2b70f9202..492be4700 100644 --- a/llama_stack/distribution/ui/page/distribution/eval_tasks.py +++ b/llama_stack/distribution/ui/page/distribution/eval_tasks.py @@ -6,14 +6,14 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api def benchmarks(): # Benchmarks Section st.header("Benchmarks") - benchmarks_info = {d.identifier: d.to_dict() for d in LlamaStackApi().client.benchmarks.list()} + benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()} if len(benchmarks_info) > 0: selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect") diff --git a/llama_stack/distribution/ui/page/distribution/models.py b/llama_stack/distribution/ui/page/distribution/models.py index 3b96f179f..f29459098 100644 --- a/llama_stack/distribution/ui/page/distribution/models.py +++ b/llama_stack/distribution/ui/page/distribution/models.py @@ -6,13 +6,13 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api def models(): # Models Section st.header("Models") - models_info = {m.identifier: m.to_dict() for m in LlamaStackApi().client.models.list()} + models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()} selected_model = st.selectbox("Select a model", list(models_info.keys())) 
st.json(models_info[selected_model]) diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/distribution/ui/page/distribution/providers.py index 116237b13..c660cb986 100644 --- a/llama_stack/distribution/ui/page/distribution/providers.py +++ b/llama_stack/distribution/ui/page/distribution/providers.py @@ -6,12 +6,12 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api def providers(): st.header("🔍 API Providers") - apis_providers_lst = LlamaStackApi().client.providers.list() + apis_providers_lst = llama_stack_api.client.providers.list() api_to_providers = {} for api_provider in apis_providers_lst: if api_provider.api in api_to_providers: diff --git a/llama_stack/distribution/ui/page/distribution/scoring_functions.py b/llama_stack/distribution/ui/page/distribution/scoring_functions.py index 3c3428f44..193146356 100644 --- a/llama_stack/distribution/ui/page/distribution/scoring_functions.py +++ b/llama_stack/distribution/ui/page/distribution/scoring_functions.py @@ -6,13 +6,13 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api def scoring_functions(): st.header("Scoring Functions") - scoring_functions_info = {s.identifier: s.to_dict() for s in LlamaStackApi().client.scoring_functions.list()} + scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()} selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys())) st.json(scoring_functions_info[selected_scoring_function], expanded=True) diff --git a/llama_stack/distribution/ui/page/distribution/shields.py b/llama_stack/distribution/ui/page/distribution/shields.py index 84b583980..67d66d64f 100644 --- a/llama_stack/distribution/ui/page/distribution/shields.py +++ b/llama_stack/distribution/ui/page/distribution/shields.py @@ -6,14 +6,14 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api def shields(): # Shields Section st.header("Shields") - shields_info = {s.identifier: s.to_dict() for s in LlamaStackApi().client.shields.list()} + shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()} selected_shield = st.selectbox("Select a shield", list(shields_info.keys())) st.json(shields_info[selected_shield]) diff --git a/llama_stack/distribution/ui/page/distribution/vector_dbs.py b/llama_stack/distribution/ui/page/distribution/vector_dbs.py index e7eb7b13b..49a4f25bb 100644 --- a/llama_stack/distribution/ui/page/distribution/vector_dbs.py +++ b/llama_stack/distribution/ui/page/distribution/vector_dbs.py @@ -6,12 +6,12 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api def vector_dbs(): st.header("Vector Databases") - vector_dbs_info = {v.identifier: v.to_dict() for v in LlamaStackApi().client.vector_dbs.list()} + vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()} if len(vector_dbs_info) > 0: selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys())) diff --git a/llama_stack/distribution/ui/page/evaluations/app_eval.py b/llama_stack/distribution/ui/page/evaluations/app_eval.py index 
13da6071e..d7bc6388c 100644 --- a/llama_stack/distribution/ui/page/evaluations/app_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/app_eval.py @@ -9,7 +9,7 @@ import json import pandas as pd import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api from llama_stack.distribution.ui.modules.utils import process_dataset @@ -39,7 +39,7 @@ def application_evaluation_page(): # Select Scoring Functions to Run Evaluation On st.subheader("Select Scoring Functions") - scoring_functions = LlamaStackApi().client.scoring_functions.list() + scoring_functions = llama_stack_api.client.scoring_functions.list() scoring_functions = {sf.identifier: sf for sf in scoring_functions} scoring_functions_names = list(scoring_functions.keys()) selected_scoring_functions = st.multiselect( @@ -48,7 +48,7 @@ def application_evaluation_page(): help="Choose one or more scoring functions.", ) - available_models = LlamaStackApi().client.models.list() + available_models = llama_stack_api.client.models.list() available_models = [m.identifier for m in available_models] scoring_params = {} @@ -108,7 +108,7 @@ def application_evaluation_page(): progress_bar.progress(progress, text=progress_text) # Run evaluation for current row - score_res = LlamaStackApi().run_scoring( + score_res = llama_stack_api.run_scoring( r, scoring_function_ids=selected_scoring_functions, scoring_params=scoring_params, diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index 133c3b151..97f875e17 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -9,13 +9,13 @@ import json import pandas as pd import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api def select_benchmark_1(): # Select Benchmarks st.subheader("1. 
Choose An Eval Task") - benchmarks = LlamaStackApi().client.benchmarks.list() + benchmarks = llama_stack_api.client.benchmarks.list() benchmarks = {et.identifier: et for et in benchmarks} benchmarks_names = list(benchmarks.keys()) selected_benchmark = st.selectbox( @@ -47,7 +47,7 @@ def define_eval_candidate_2(): # Define Eval Candidate candidate_type = st.radio("Candidate Type", ["model", "agent"]) - available_models = LlamaStackApi().client.models.list() + available_models = llama_stack_api.client.models.list() available_models = [model.identifier for model in available_models] selected_model = st.selectbox( "Choose a model", @@ -167,7 +167,7 @@ def run_evaluation_3(): eval_candidate = st.session_state["eval_candidate"] dataset_id = benchmarks[selected_benchmark].dataset_id - rows = LlamaStackApi().client.datasets.iterrows( + rows = llama_stack_api.client.datasets.iterrows( dataset_id=dataset_id, ) total_rows = len(rows.data) @@ -208,7 +208,7 @@ def run_evaluation_3(): progress = i / len(rows) progress_bar.progress(progress, text=progress_text) # Run evaluation for current row - eval_res = LlamaStackApi().client.eval.evaluate_rows( + eval_res = llama_stack_api.client.eval.evaluate_rows( benchmark_id=selected_benchmark, input_rows=[r], scoring_functions=benchmarks[selected_benchmark].scoring_functions, diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/distribution/ui/page/playground/chat.py index 053ae42de..fcaf08795 100644 --- a/llama_stack/distribution/ui/page/playground/chat.py +++ b/llama_stack/distribution/ui/page/playground/chat.py @@ -6,12 +6,12 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api # Sidebar configurations with st.sidebar: st.header("Configuration") - available_models = LlamaStackApi().client.models.list() + available_models = llama_stack_api.client.models.list() available_models = [model.identifier for model in available_models if model.model_type == "llm"] selected_model = st.selectbox( "Choose a model", @@ -103,7 +103,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"): else: strategy = {"type": "greedy"} - response = LlamaStackApi().client.inference.chat_completion( + response = llama_stack_api.client.inference.chat_completion( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}, diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index 94e27a255..696d89bc2 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -10,7 +10,7 @@ import streamlit as st from llama_stack_client import Agent, AgentEventLogger, RAGDocument from llama_stack.apis.common.content_types import ToolCallDelta -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api from llama_stack.distribution.ui.modules.utils import data_url_from_file @@ -57,14 +57,14 @@ def rag_chat_page(): for i, uploaded_file in enumerate(uploaded_files) ] - providers = LlamaStackApi().client.providers.list() + providers = llama_stack_api.client.providers.list() vector_io_provider = None for x in providers: if x.api == "vector_io": vector_io_provider = x.provider_id - LlamaStackApi().client.vector_dbs.register( + llama_stack_api.client.vector_dbs.register( vector_db_id=vector_db_name, # Use the user-provided name 
embedding_dimension=384, embedding_model="all-MiniLM-L6-v2", @@ -72,7 +72,7 @@ def rag_chat_page(): ) # insert documents using the custom vector db name - LlamaStackApi().client.tool_runtime.rag_tool.insert( + llama_stack_api.client.tool_runtime.rag_tool.insert( vector_db_id=vector_db_name, # Use the user-provided name documents=documents, chunk_size_in_tokens=512, @@ -93,7 +93,7 @@ def rag_chat_page(): ) # select memory banks - vector_dbs = LlamaStackApi().client.vector_dbs.list() + vector_dbs = llama_stack_api.client.vector_dbs.list() vector_dbs = [vector_db.identifier for vector_db in vector_dbs] selected_vector_dbs = st.multiselect( label="Select Document Collections to use in RAG queries", @@ -103,7 +103,7 @@ def rag_chat_page(): ) st.subheader("Inference Parameters", divider=True) - available_models = LlamaStackApi().client.models.list() + available_models = llama_stack_api.client.models.list() available_models = [model.identifier for model in available_models if model.model_type == "llm"] selected_model = st.selectbox( label="Choose a model", @@ -167,7 +167,7 @@ def rag_chat_page(): @st.cache_resource def create_agent(): return Agent( - LlamaStackApi().client, + llama_stack_api.client, model=selected_model, instructions=system_prompt, sampling_params={ @@ -232,7 +232,7 @@ def rag_chat_page(): st.session_state.messages.append({"role": "system", "content": system_prompt}) # Query the vector DB - rag_response = LlamaStackApi().client.tool_runtime.rag_tool.query( + rag_response = llama_stack_api.client.tool_runtime.rag_tool.query( content=prompt, vector_db_ids=list(selected_vector_dbs) ) prompt_context = rag_response.content @@ -251,7 +251,7 @@ def rag_chat_page(): # Run inference directly st.session_state.messages.append({"role": "user", "content": extended_prompt}) - response = LlamaStackApi().client.inference.chat_completion( + response = llama_stack_api.client.inference.chat_completion( messages=st.session_state.messages, model_id=selected_model, sampling_params={ diff --git a/llama_stack/distribution/ui/page/playground/tools.py b/llama_stack/distribution/ui/page/playground/tools.py index 570bfb366..149d8cce9 100644 --- a/llama_stack/distribution/ui/page/playground/tools.py +++ b/llama_stack/distribution/ui/page/playground/tools.py @@ -13,7 +13,7 @@ from llama_stack_client import Agent from llama_stack_client.lib.agents.react.agent import ReActAgent from llama_stack_client.lib.agents.react.tool_parser import ReActOutput -from llama_stack.distribution.ui.modules.api import LlamaStackApi +from llama_stack.distribution.ui.modules.api import llama_stack_api class AgentType(enum.Enum): @@ -24,7 +24,7 @@ class AgentType(enum.Enum): def tool_chat_page(): st.title("🛠 Tools") - client = LlamaStackApi().client + client = llama_stack_api.client models = client.models.list() model_list = [model.identifier for model in models if model.api_model_type == "llm"] @@ -55,7 +55,7 @@ def tool_chat_page(): ) if "builtin::rag" in toolgroup_selection: - vector_dbs = LlamaStackApi().client.vector_dbs.list() or [] + vector_dbs = llama_stack_api.client.vector_dbs.list() or [] if not vector_dbs: st.info("No vector databases available for selection.") vector_dbs = [vector_db.identifier for vector_db in vector_dbs] diff --git a/llama_stack/distribution/ui/requirements.txt b/llama_stack/distribution/ui/requirements.txt index 862f969d6..53a1e7bf3 100644 --- a/llama_stack/distribution/ui/requirements.txt +++ b/llama_stack/distribution/ui/requirements.txt @@ -1,5 +1,5 @@ -llama-stack-client>=0.2.9 
+llama-stack>=0.2.1 +llama-stack-client>=0.2.1 pandas streamlit streamlit-option-menu -streamlit-keycloak diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py index 7c2e00524..4acce4f5b 100644 --- a/llama_stack/distribution/utils/exec.py +++ b/llama_stack/distribution/utils/exec.py @@ -8,7 +8,6 @@ import logging import os import signal import subprocess -import sys from termcolor import cprint @@ -34,7 +33,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list: cprint( "No current conda environment detected, please specify a conda environment name with --image-name", color="red", - file=sys.stderr, ) return @@ -51,13 +49,12 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list: return envpath return None - cprint(f"Using conda environment: {env_name}", color="green", file=sys.stderr) + print(f"Using conda environment: {env_name}") conda_prefix = get_conda_prefix(env_name) if not conda_prefix: cprint( f"Conda environment {env_name} does not exist.", color="red", - file=sys.stderr, ) return @@ -66,7 +63,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list: cprint( f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name", color="red", - file=sys.stderr, ) return else: @@ -77,10 +73,9 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list: cprint( "No current virtual environment detected, please specify a virtual environment name with --image-name", color="red", - file=sys.stderr, ) return - cprint(f"Using virtual environment: {env_name}", file=sys.stderr) + print(f"Using virtual environment: {env_name}") script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh" run_args = [ diff --git a/llama_stack/log.py b/llama_stack/log.py index f4184710a..98858d208 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -6,7 +6,6 @@ import logging import os -import sys from logging.config import dictConfig from rich.console import Console @@ -235,7 +234,7 @@ def get_logger( env_config = os.environ.get("LLAMA_STACK_LOGGING", "") if env_config: - cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", color="yellow", file=sys.stderr) + cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow") _category_levels.update(parse_environment_config(env_config)) log_file = os.environ.get("LLAMA_STACK_LOG_FILE") diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py index fe7be5ea9..c6d618818 100644 --- a/llama_stack/models/llama/llama3/generation.py +++ b/llama_stack/models/llama/llama3/generation.py @@ -174,7 +174,6 @@ class Llama3: cprint( "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", "red", - file=sys.stderr, ) prompt_tokens = [inp.tokens for inp in llm_inputs] @@ -185,11 +184,7 @@ class Llama3: max_prompt_len = max(len(t) for t in prompt_tokens) if max_prompt_len >= params.max_seq_len: - cprint( - f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", - color="red", - file=sys.stderr, - ) + cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red") return total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) diff --git a/llama_stack/models/llama/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py index 6132d25d4..476761209 100644 --- a/llama_stack/models/llama/llama4/generation.py 
+++ b/llama_stack/models/llama/llama4/generation.py @@ -133,9 +133,9 @@ class Llama4: print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1" if print_model_input: - cprint("Input to model:\n", color="yellow", file=sys.stderr) + cprint("Input to model:\n", "yellow") for inp in llm_inputs: - cprint(self.tokenizer.decode(inp.tokens), color="grey", file=sys.stderr) + cprint(self.tokenizer.decode(inp.tokens), "grey") prompt_tokens = [inp.tokens for inp in llm_inputs] bsz = len(llm_inputs) @@ -145,7 +145,7 @@ class Llama4: max_prompt_len = max(len(t) for t in prompt_tokens) if max_prompt_len >= params.max_seq_len: - cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", color="red", file=sys.stderr) + cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red") return total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 60b05545b..3e9806f23 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api from llama_stack.apis.models import Model from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.shields import Shield -from llama_stack.apis.tools import ToolGroup +from llama_stack.apis.tools import Tool from llama_stack.apis.vector_dbs import VectorDB from llama_stack.schema_utils import json_schema_type @@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol): async def register_benchmark(self, benchmark: Benchmark) -> None: ... -class ToolGroupsProtocolPrivate(Protocol): - async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ... +class ToolsProtocolPrivate(Protocol): + async def register_tool(self, tool: Tool) -> None: ... - async def unregister_toolgroup(self, toolgroup_id: str) -> None: ... + async def unregister_tool(self, tool_id: str) -> None: ... 
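A minimal sketch of a provider-side object satisfying the restored ToolsProtocolPrivate protocol above; the class name and the in-memory storage are illustrative, and real providers typically do more than index the registered tool.

from llama_stack.apis.tools import Tool


class InMemoryToolIndex:
    """Illustrative implementation of the ToolsProtocolPrivate duck type."""

    def __init__(self) -> None:
        self._tools: dict[str, Tool] = {}

    async def register_tool(self, tool: Tool) -> None:
        # Remember the tool so later invocations can resolve it by identifier.
        self._tools[tool.identifier] = tool

    async def unregister_tool(self, tool_id: str) -> None:
        # Unknown ids are ignored to keep the sketch simple.
        self._tools.pop(tool_id, None)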
@json_schema_type diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index bcbfcbe31..86780fd61 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -20,12 +20,9 @@ from llama_stack.apis.agents import ( AgentTurnCreateRequest, AgentTurnResumeRequest, Document, - ListOpenAIResponseInputItem, - ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, - Order, Session, Turn, ) @@ -42,7 +39,6 @@ from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records -from llama_stack.providers.utils.responses.responses_store import ResponsesStore from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig @@ -70,17 +66,15 @@ class MetaReferenceAgentsImpl(Agents): self.tool_groups_api = tool_groups_api self.in_memory_store = InmemoryKVStoreImpl() - self.openai_responses_impl: OpenAIResponsesImpl | None = None + self.openai_responses_impl = None async def initialize(self) -> None: self.persistence_store = await kvstore_impl(self.config.persistence_store) - self.responses_store = ResponsesStore(self.config.responses_store) - await self.responses_store.initialize() self.openai_responses_impl = OpenAIResponsesImpl( + self.persistence_store, inference_api=self.inference_api, tool_groups_api=self.tool_groups_api, tool_runtime_api=self.tool_runtime_api, - responses_store=self.responses_store, ) async def create_agent( @@ -311,15 +305,14 @@ class MetaReferenceAgentsImpl(Agents): # OpenAI responses async def get_openai_response( self, - response_id: str, + id: str, ) -> OpenAIResponseObject: - return await self.openai_responses_impl.get_openai_response(response_id) + return await self.openai_responses_impl.get_openai_response(id) async def create_openai_response( self, input: str | list[OpenAIResponseInput], model: str, - instructions: str | None = None, previous_response_id: str | None = None, store: bool | None = True, stream: bool | None = False, @@ -327,27 +320,5 @@ class MetaReferenceAgentsImpl(Agents): tools: list[OpenAIResponseInputTool] | None = None, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( - input, model, instructions, previous_response_id, store, stream, temperature, tools - ) - - async def list_openai_responses( - self, - after: str | None = None, - limit: int | None = 50, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseObject: - return await self.openai_responses_impl.list_openai_responses(after, limit, model, order) - - async def list_openai_response_input_items( - self, - response_id: str, - after: str | None = None, - before: str | None = None, - include: list[str] | None = None, - limit: int | None = 20, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseInputItem: - return await self.openai_responses_impl.list_openai_response_input_items( - response_id, after, before, include, limit, order + input, model, previous_response_id, store, stream, temperature, tools ) diff --git a/llama_stack/providers/inline/agents/meta_reference/config.py b/llama_stack/providers/inline/agents/meta_reference/config.py index 1c392f29c..c860e6df1 100644 --- 
a/llama_stack/providers/inline/agents/meta_reference/config.py +++ b/llama_stack/providers/inline/agents/meta_reference/config.py @@ -10,12 +10,10 @@ from pydantic import BaseModel from llama_stack.providers.utils.kvstore import KVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig class MetaReferenceAgentsImplConfig(BaseModel): persistence_store: KVStoreConfig - responses_store: SqlStoreConfig @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: @@ -23,9 +21,5 @@ class MetaReferenceAgentsImplConfig(BaseModel): "persistence_store": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, db_name="agents_store.db", - ), - "responses_store": SqliteSqlStoreConfig.sample_run_config( - __distro_dir__=__distro_dir__, - db_name="responses_store.db", - ), + ) } diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 19d7ea56f..6d9d06109 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import time import uuid from collections.abc import AsyncIterator from typing import Any, cast @@ -13,29 +12,24 @@ from typing import Any, cast from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel -from llama_stack.apis.agents import Order from llama_stack.apis.agents.openai_responses import ( - AllowedToolsFilter, - ListOpenAIResponseInputItem, - ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputFunctionToolCallOutput, + OpenAIResponseInputItemList, OpenAIResponseInputMessageContent, OpenAIResponseInputMessageContentImage, OpenAIResponseInputMessageContentText, OpenAIResponseInputTool, - OpenAIResponseInputToolMCP, + OpenAIResponseInputToolFunction, OpenAIResponseMessage, OpenAIResponseObject, OpenAIResponseObjectStream, OpenAIResponseObjectStreamResponseCompleted, OpenAIResponseObjectStreamResponseCreated, - OpenAIResponseObjectStreamResponseOutputTextDelta, OpenAIResponseOutput, OpenAIResponseOutputMessageContent, OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageFunctionToolCall, - OpenAIResponseOutputMessageMCPListTools, OpenAIResponseOutputMessageWebSearchToolCall, ) from llama_stack.apis.inference.inference import ( @@ -55,12 +49,11 @@ from llama_stack.apis.inference.inference import ( OpenAIToolMessageParam, OpenAIUserMessageParam, ) -from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime +from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool -from llama_stack.providers.utils.responses.responses_store import ResponsesStore -from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools +from llama_stack.providers.utils.kvstore import KVStore logger = get_logger(name=__name__, category="openai_responses") @@ -169,43 +162,41 @@ async def _get_message_type_by_role(role: str): class OpenAIResponsePreviousResponseWithInputItems(BaseModel): - input_items: ListOpenAIResponseInputItem + input_items: 
OpenAIResponseInputItemList response: OpenAIResponseObject -class ChatCompletionContext(BaseModel): - model: str - messages: list[OpenAIMessageParam] - tools: list[ChatCompletionToolParam] | None = None - mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] - stream: bool - temperature: float | None - - class OpenAIResponsesImpl: def __init__( self, + persistence_store: KVStore, inference_api: Inference, tool_groups_api: ToolGroups, tool_runtime_api: ToolRuntime, - responses_store: ResponsesStore, ): + self.persistence_store = persistence_store self.inference_api = inference_api self.tool_groups_api = tool_groups_api self.tool_runtime_api = tool_runtime_api - self.responses_store = responses_store + + async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems: + key = f"{OPENAI_RESPONSES_PREFIX}{id}" + response_json = await self.persistence_store.get(key=key) + if response_json is None: + raise ValueError(f"OpenAI response with id '{id}' not found") + return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json) async def _prepend_previous_response( self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None ): if previous_response_id: - previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) + previous_response_with_input = await self._get_previous_response_with_input(previous_response_id) # previous response input items - new_input_items = previous_response_with_input.input + new_input_items = previous_response_with_input.input_items.data # previous response output items - new_input_items.extend(previous_response_with_input.output) + new_input_items.extend(previous_response_with_input.response.output) # new input items from the current request if isinstance(input, str): @@ -217,60 +208,99 @@ class OpenAIResponsesImpl: return input - async def _prepend_instructions(self, messages, instructions): - if instructions: - messages.insert(0, OpenAISystemMessageParam(content=instructions)) - async def get_openai_response( self, - response_id: str, + id: str, ) -> OpenAIResponseObject: - response_with_input = await self.responses_store.get_response_object(response_id) - return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) + response_with_input = await self._get_previous_response_with_input(id) + return response_with_input.response - async def list_openai_responses( + async def create_openai_response( self, - after: str | None = None, - limit: int | None = 50, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseObject: - return await self.responses_store.list_responses(after, limit, model, order) + input: str | list[OpenAIResponseInput], + model: str, + previous_response_id: str | None = None, + store: bool | None = True, + stream: bool | None = False, + temperature: float | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + ): + stream = False if stream is None else stream - async def list_openai_response_input_items( - self, - response_id: str, - after: str | None = None, - before: str | None = None, - include: list[str] | None = None, - limit: int | None = 20, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseInputItem: - """List input items for a given OpenAI response. 
+ input = await self._prepend_previous_response(input, previous_response_id) + messages = await _convert_response_input_to_chat_messages(input) + chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None + chat_response = await self.inference_api.openai_chat_completion( + model=model, + messages=messages, + tools=chat_tools, + stream=stream, + temperature=temperature, + ) - :param response_id: The ID of the response to retrieve input items for. - :param after: An item ID to list items after, used for pagination. - :param before: An item ID to list items before, used for pagination. - :param include: Additional fields to include in the response. - :param limit: A limit on the number of objects to be returned. - :param order: The order to return the input items in. - :returns: An ListOpenAIResponseInputItem. - """ - return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) + if stream: + # TODO: refactor this into a separate method that handles streaming + chat_response_id = "" + chat_response_content = [] + chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} + # TODO: these chunk_ fields are hacky and only take the last chunk into account + chunk_created = 0 + chunk_model = "" + chunk_finish_reason = "" + async for chunk in chat_response: + chat_response_id = chunk.id + chunk_created = chunk.created + chunk_model = chunk.model + for chunk_choice in chunk.choices: + # TODO: this only works for text content + chat_response_content.append(chunk_choice.delta.content or "") + if chunk_choice.finish_reason: + chunk_finish_reason = chunk_choice.finish_reason + + # Aggregate tool call arguments across chunks, using their index as the aggregation key + if chunk_choice.delta.tool_calls: + for tool_call in chunk_choice.delta.tool_calls: + response_tool_call = chat_response_tool_calls.get(tool_call.index, None) + if response_tool_call: + response_tool_call.function.arguments += tool_call.function.arguments + else: + tool_call_dict: dict[str, Any] = tool_call.model_dump() + # Ensure we don't have any empty type field in the tool call dict. + # The OpenAI client used by providers often returns a type=None here. 
+ tool_call_dict.pop("type", None) + response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) + chat_response_tool_calls[tool_call.index] = response_tool_call + + # Convert the dict of tool calls by index to a list of tool calls to pass back in our response + if chat_response_tool_calls: + tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] + else: + tool_calls = None + assistant_message = OpenAIAssistantMessageParam( + content="".join(chat_response_content), + tool_calls=tool_calls, + ) + chat_response = OpenAIChatCompletion( + id=chat_response_id, + choices=[ + OpenAIChoice( + message=assistant_message, + finish_reason=chunk_finish_reason, + index=0, + ) + ], + created=chunk_created, + model=chunk_model, + ) + else: + # dump and reload to map to our pydantic types + chat_response = OpenAIChatCompletion(**chat_response.model_dump()) - async def _process_response_choices( - self, - chat_response: OpenAIChatCompletion, - ctx: ChatCompletionContext, - tools: list[OpenAIResponseInputTool] | None, - ) -> list[OpenAIResponseOutput]: - """Handle tool execution and response message creation.""" output_messages: list[OpenAIResponseOutput] = [] - # Execute tool calls if any for choice in chat_response.choices: if choice.message.tool_calls and tools: # Assume if the first tool is a function, all tools are functions - if tools[0].type == "function": + if isinstance(tools[0], OpenAIResponseInputToolFunction): for tool_call in choice.message.tool_calls: output_messages.append( OpenAIResponseOutputMessageFunctionToolCall( @@ -282,132 +312,11 @@ class OpenAIResponsesImpl: ) ) else: - tool_messages = await self._execute_tool_and_return_final_output(choice, ctx) - output_messages.extend(tool_messages) + output_messages.extend( + await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature) + ) else: output_messages.append(await _convert_chat_choice_to_response_message(choice)) - - return output_messages - - async def _store_response( - self, - response: OpenAIResponseObject, - input: str | list[OpenAIResponseInput], - ) -> None: - new_input_id = f"msg_{uuid.uuid4()}" - if isinstance(input, str): - # synthesize a message from the input string - input_content = OpenAIResponseInputMessageContentText(text=input) - input_content_item = OpenAIResponseMessage( - role="user", - content=[input_content], - id=new_input_id, - ) - input_items_data = [input_content_item] - else: - # we already have a list of messages - input_items_data = [] - for input_item in input: - if isinstance(input_item, OpenAIResponseMessage): - # These may or may not already have an id, so dump to dict, check for id, and add if missing - input_item_dict = input_item.model_dump() - if "id" not in input_item_dict: - input_item_dict["id"] = new_input_id - input_items_data.append(OpenAIResponseMessage(**input_item_dict)) - else: - input_items_data.append(input_item) - - await self.responses_store.store_response_object( - response_object=response, - input=input_items_data, - ) - - async def create_openai_response( - self, - input: str | list[OpenAIResponseInput], - model: str, - instructions: str | None = None, - previous_response_id: str | None = None, - store: bool | None = True, - stream: bool | None = False, - temperature: float | None = None, - tools: list[OpenAIResponseInputTool] | None = None, - ): - stream = False if stream is None else stream - - output_messages: list[OpenAIResponseOutput] = [] - - # Input preprocessing - input = await 
self._prepend_previous_response(input, previous_response_id) - messages = await _convert_response_input_to_chat_messages(input) - await self._prepend_instructions(messages, instructions) - - # Tool setup - chat_tools, mcp_tool_to_server, mcp_list_message = ( - await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None) - ) - if mcp_list_message: - output_messages.append(mcp_list_message) - - ctx = ChatCompletionContext( - model=model, - messages=messages, - tools=chat_tools, - mcp_tool_to_server=mcp_tool_to_server, - stream=stream, - temperature=temperature, - ) - - inference_result = await self.inference_api.openai_chat_completion( - model=model, - messages=messages, - tools=chat_tools, - stream=stream, - temperature=temperature, - ) - - if stream: - return self._create_streaming_response( - inference_result=inference_result, - ctx=ctx, - output_messages=output_messages, - input=input, - model=model, - store=store, - tools=tools, - ) - else: - return await self._create_non_streaming_response( - inference_result=inference_result, - ctx=ctx, - output_messages=output_messages, - input=input, - model=model, - store=store, - tools=tools, - ) - - async def _create_non_streaming_response( - self, - inference_result: Any, - ctx: ChatCompletionContext, - output_messages: list[OpenAIResponseOutput], - input: str | list[OpenAIResponseInput], - model: str, - store: bool | None, - tools: list[OpenAIResponseInputTool] | None, - ) -> OpenAIResponseObject: - chat_response = OpenAIChatCompletion(**inference_result.model_dump()) - - # Process response choices (tool execution and message creation) - output_messages.extend( - await self._process_response_choices( - chat_response=chat_response, - ctx=ctx, - tools=tools, - ) - ) - response = OpenAIResponseObject( created_at=chat_response.created, id=f"resp-{uuid.uuid4()}", @@ -418,173 +327,57 @@ class OpenAIResponsesImpl: ) logger.debug(f"OpenAI Responses response: {response}") - # Store response if requested if store: - await self._store_response( + # Store in kvstore + + new_input_id = f"msg_{uuid.uuid4()}" + if isinstance(input, str): + # synthesize a message from the input string + input_content = OpenAIResponseInputMessageContentText(text=input) + input_content_item = OpenAIResponseMessage( + role="user", + content=[input_content], + id=new_input_id, + ) + input_items_data = [input_content_item] + else: + # we already have a list of messages + input_items_data = [] + for input_item in input: + if isinstance(input_item, OpenAIResponseMessage): + # These may or may not already have an id, so dump to dict, check for id, and add if missing + input_item_dict = input_item.model_dump() + if "id" not in input_item_dict: + input_item_dict["id"] = new_input_id + input_items_data.append(OpenAIResponseMessage(**input_item_dict)) + else: + input_items_data.append(input_item) + + input_items = OpenAIResponseInputItemList(data=input_items_data) + prev_response = OpenAIResponsePreviousResponseWithInputItems( + input_items=input_items, response=response, - input=input, ) + key = f"{OPENAI_RESPONSES_PREFIX}{response.id}" + await self.persistence_store.set( + key=key, + value=prev_response.model_dump_json(), + ) + + if stream: + + async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]: + # TODO: response created should actually get emitted much earlier in the process + yield OpenAIResponseObjectStreamResponseCreated(response=response) + yield OpenAIResponseObjectStreamResponseCompleted(response=response) + + return 
async_response() return response - async def _create_streaming_response( - self, - inference_result: Any, - ctx: ChatCompletionContext, - output_messages: list[OpenAIResponseOutput], - input: str | list[OpenAIResponseInput], - model: str, - store: bool | None, - tools: list[OpenAIResponseInputTool] | None, - ) -> AsyncIterator[OpenAIResponseObjectStream]: - # Create initial response and emit response.created immediately - response_id = f"resp-{uuid.uuid4()}" - created_at = int(time.time()) - - initial_response = OpenAIResponseObject( - created_at=created_at, - id=response_id, - model=model, - object="response", - status="in_progress", - output=output_messages.copy(), - ) - - # Emit response.created immediately - yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) - - # For streaming, inference_result is an async iterator of chunks - # Stream chunks and emit delta events as they arrive - chat_response_id = "" - chat_response_content = [] - chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} - chunk_created = 0 - chunk_model = "" - chunk_finish_reason = "" - sequence_number = 0 - - # Create a placeholder message item for delta events - message_item_id = f"msg_{uuid.uuid4()}" - - async for chunk in inference_result: - chat_response_id = chunk.id - chunk_created = chunk.created - chunk_model = chunk.model - for chunk_choice in chunk.choices: - # Emit incremental text content as delta events - if chunk_choice.delta.content: - sequence_number += 1 - yield OpenAIResponseObjectStreamResponseOutputTextDelta( - content_index=0, - delta=chunk_choice.delta.content, - item_id=message_item_id, - output_index=0, - sequence_number=sequence_number, - ) - - # Collect content for final response - chat_response_content.append(chunk_choice.delta.content or "") - if chunk_choice.finish_reason: - chunk_finish_reason = chunk_choice.finish_reason - - # Aggregate tool call arguments across chunks, using their index as the aggregation key - if chunk_choice.delta.tool_calls: - for tool_call in chunk_choice.delta.tool_calls: - response_tool_call = chat_response_tool_calls.get(tool_call.index, None) - if response_tool_call: - # Don't attempt to concatenate arguments if we don't have any new arguments - if tool_call.function.arguments: - # Guard against an initial None argument before we concatenate - response_tool_call.function.arguments = ( - response_tool_call.function.arguments or "" - ) + tool_call.function.arguments - else: - tool_call_dict: dict[str, Any] = tool_call.model_dump() - tool_call_dict.pop("type", None) - response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) - chat_response_tool_calls[tool_call.index] = response_tool_call - - # Convert collected chunks to complete response - if chat_response_tool_calls: - tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] - else: - tool_calls = None - assistant_message = OpenAIAssistantMessageParam( - content="".join(chat_response_content), - tool_calls=tool_calls, - ) - chat_response_obj = OpenAIChatCompletion( - id=chat_response_id, - choices=[ - OpenAIChoice( - message=assistant_message, - finish_reason=chunk_finish_reason, - index=0, - ) - ], - created=chunk_created, - model=chunk_model, - ) - - # Process response choices (tool execution and message creation) - output_messages.extend( - await self._process_response_choices( - chat_response=chat_response_obj, - ctx=ctx, - tools=tools, - ) - ) - - # Create final response - final_response = OpenAIResponseObject( - 
created_at=created_at, - id=response_id, - model=model, - object="response", - status="completed", - output=output_messages, - ) - - if store: - await self._store_response( - response=final_response, - input=input, - ) - - # Emit response.completed - yield OpenAIResponseObjectStreamResponseCompleted(response=final_response) - async def _convert_response_tools_to_chat_tools( self, tools: list[OpenAIResponseInputTool] - ) -> tuple[ - list[ChatCompletionToolParam], - dict[str, OpenAIResponseInputToolMCP], - OpenAIResponseOutput | None, - ]: - from llama_stack.apis.agents.openai_responses import ( - MCPListToolsTool, - ) - from llama_stack.apis.tools.tools import Tool - - mcp_tool_to_server = {} - - def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: - tool_def = ToolDefinition( - tool_name=tool_name, - description=tool.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool.parameters - }, - ) - return convert_tooldef_to_openai_tool(tool_def) - - mcp_list_message = None + ) -> list[ChatCompletionToolParam]: chat_tools: list[ChatCompletionToolParam] = [] for input_tool in tools: # TODO: Handle other tool types @@ -593,95 +386,91 @@ class OpenAIResponsesImpl: elif input_tool.type == "web_search": tool_name = "web_search" tool = await self.tool_groups_api.get_tool(tool_name) - if not tool: - raise ValueError(f"Tool {tool_name} not found") - chat_tools.append(make_openai_tool(tool_name, tool)) - elif input_tool.type == "mcp": - always_allowed = None - never_allowed = None - if input_tool.allowed_tools: - if isinstance(input_tool.allowed_tools, list): - always_allowed = input_tool.allowed_tools - elif isinstance(input_tool.allowed_tools, AllowedToolsFilter): - always_allowed = input_tool.allowed_tools.always - never_allowed = input_tool.allowed_tools.never - - tool_defs = await list_mcp_tools( - endpoint=input_tool.server_url, - headers=input_tool.headers or {}, - ) - - mcp_list_message = OpenAIResponseOutputMessageMCPListTools( - id=f"mcp_list_{uuid.uuid4()}", - status="completed", - server_label=input_tool.server_label, - tools=[], - ) - for t in tool_defs.data: - if never_allowed and t.name in never_allowed: - continue - if not always_allowed or t.name in always_allowed: - chat_tools.append(make_openai_tool(t.name, t)) - if t.name in mcp_tool_to_server: - raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}") - mcp_tool_to_server[t.name] = input_tool - mcp_list_message.tools.append( - MCPListToolsTool( - name=t.name, - description=t.description, - input_schema={ - "type": "object", - "properties": { - p.name: { - "type": p.parameter_type, - "description": p.description, - } - for p in t.parameters - }, - "required": [p.name for p in t.parameters if p.required], - }, - ) + tool_def = ToolDefinition( + tool_name=tool_name, + description=tool.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, ) + for param in tool.parameters + }, + ) + chat_tool = convert_tooldef_to_openai_tool(tool_def) + chat_tools.append(chat_tool) else: raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") - return chat_tools, mcp_tool_to_server, mcp_list_message + return chat_tools async def _execute_tool_and_return_final_output( self, + 
model_id: str, + stream: bool, choice: OpenAIChoice, - ctx: ChatCompletionContext, + messages: list[OpenAIMessageParam], + temperature: float, ) -> list[OpenAIResponseOutput]: output_messages: list[OpenAIResponseOutput] = [] + # If the choice is not an assistant message, we don't need to execute any tools if not isinstance(choice.message, OpenAIAssistantMessageParam): return output_messages + # If the assistant message doesn't have any tool calls, we don't need to execute any tools if not choice.message.tool_calls: return output_messages - next_turn_messages = ctx.messages.copy() + # Copy the messages list to avoid mutating the original list + messages = messages.copy() # Add the assistant message with tool_calls response to the messages list - next_turn_messages.append(choice.message) + messages.append(choice.message) for tool_call in choice.message.tool_calls: + tool_call_id = tool_call.id + function = tool_call.function + + # If for some reason the tool call doesn't have a function or id, we can't execute it + if not function or not tool_call_id: + continue + # TODO: telemetry spans for tool calls - tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx) - if tool_call_log: - output_messages.append(tool_call_log) - if further_input: - next_turn_messages.append(further_input) + result = await self._execute_tool_call(function) + + # Handle tool call failure + if not result: + output_messages.append( + OpenAIResponseOutputMessageWebSearchToolCall( + id=tool_call_id, + status="failed", + ) + ) + continue + + output_messages.append( + OpenAIResponseOutputMessageWebSearchToolCall( + id=tool_call_id, + status="completed", + ), + ) + + result_content = "" + # TODO: handle other result content types and lists + if isinstance(result.content, str): + result_content = result.content + messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id)) tool_results_chat_response = await self.inference_api.openai_chat_completion( - model=ctx.model, - messages=next_turn_messages, - stream=ctx.stream, - temperature=ctx.temperature, + model=model_id, + messages=messages, + stream=stream, + temperature=temperature, ) - # type cast to appease mypy: this is needed because we don't handle streaming properly :) + # type cast to appease mypy tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response) - - # Huge TODO: these are NOT the final outputs, we must keep the loop going tool_final_outputs = [ await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices ] @@ -691,86 +480,15 @@ class OpenAIResponsesImpl: async def _execute_tool_call( self, - tool_call: OpenAIChatCompletionToolCall, - ctx: ChatCompletionContext, - ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]: - from llama_stack.providers.utils.inference.prompt_adapter import ( - interleaved_content_as_str, + function: OpenAIChatCompletionToolCallFunction, + ) -> ToolInvocationResult | None: + if not function.name: + return None + function_args = json.loads(function.arguments) if function.arguments else {} + logger.info(f"executing tool call: {function.name} with args: {function_args}") + result = await self.tool_runtime_api.invoke_tool( + tool_name=function.name, + kwargs=function_args, ) - - tool_call_id = tool_call.id - function = tool_call.function - - if not function or not tool_call_id or not function.name: - return None, None - - error_exc = None - result = None - try: - if function.name in ctx.mcp_tool_to_server: - 
mcp_tool = ctx.mcp_tool_to_server[function.name] - result = await invoke_mcp_tool( - endpoint=mcp_tool.server_url, - headers=mcp_tool.headers or {}, - tool_name=function.name, - kwargs=json.loads(function.arguments) if function.arguments else {}, - ) - else: - result = await self.tool_runtime_api.invoke_tool( - tool_name=function.name, - kwargs=json.loads(function.arguments) if function.arguments else {}, - ) - except Exception as e: - error_exc = e - - if function.name in ctx.mcp_tool_to_server: - from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall - - message = OpenAIResponseOutputMessageMCPCall( - id=tool_call_id, - arguments=function.arguments, - name=function.name, - server_label=ctx.mcp_tool_to_server[function.name].server_label, - ) - if error_exc: - message.error = str(error_exc) - elif (result.error_code and result.error_code > 0) or result.error_message: - message.error = f"Error (code {result.error_code}): {result.error_message}" - elif result.content: - message.output = interleaved_content_as_str(result.content) - else: - if function.name == "web_search": - message = OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="completed", - ) - if error_exc or (result.error_code and result.error_code > 0) or result.error_message: - message.status = "failed" - else: - raise ValueError(f"Unknown tool {function.name} called") - - input_message = None - if result and result.content: - if isinstance(result.content, str): - content = result.content - elif isinstance(result.content, list): - from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem - - content = [] - for item in result.content: - if isinstance(item, TextContentItem): - part = OpenAIChatCompletionContentPartTextParam(text=item.text) - elif isinstance(item, ImageContentItem): - if item.image.data: - url = f"data:image;base64,{item.image.data}" - else: - url = item.image.url - part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) - else: - raise ValueError(f"Unknown result content type: {type(item)}") - content.append(part) - else: - raise ValueError(f"Unknown result content type: {type(result.content)}") - input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) - - return message, input_message + logger.debug(f"tool call {function.name} completed with result: {result}") + return result diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index e238e1b78..8dd594869 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -6,7 +6,6 @@ import asyncio import os -import sys from collections.abc import AsyncGenerator from pydantic import BaseModel @@ -29,7 +28,7 @@ from llama_stack.apis.inference import ( CompletionRequest, CompletionResponse, CompletionResponseStreamChunk, - InferenceProvider, + Inference, InterleavedContent, LogProbConfig, Message, @@ -87,7 +86,7 @@ class MetaReferenceInferenceImpl( OpenAICompletionToLlamaStackMixin, OpenAIChatCompletionToLlamaStackMixin, SentenceTransformerEmbeddingMixin, - InferenceProvider, + Inference, ModelsProtocolPrivate, ): def __init__(self, config: MetaReferenceInferenceConfig) -> None: @@ -456,9 +455,9 @@ class MetaReferenceInferenceImpl( first = token_results[0] if not first.finished and not first.ignore_token: if os.environ.get("LLAMA_MODELS_DEBUG", "0") 
in ("1", "2"): - cprint(first.text, color="cyan", end="", file=sys.stderr) + cprint(first.text, "cyan", end="") if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": - cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr) + cprint(f"<{first.token}>", "magenta", end="") for result in token_results: idx = result.batch_idx @@ -520,9 +519,9 @@ class MetaReferenceInferenceImpl( for token_results in self.generator.chat_completion([request]): token_result = token_results[0] if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": - cprint(token_result.text, color="cyan", end="", file=sys.stderr) + cprint(token_result.text, "cyan", end="") if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": - cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr) + cprint(f"<{token_result.token}>", "magenta", end="") if token_result.token == tokenizer.eot_id: stop_reason = StopReason.end_of_turn diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 890c526f5..7b36b0997 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator from llama_stack.apis.inference import ( CompletionResponse, - InferenceProvider, + Inference, InterleavedContent, LogProbConfig, Message, @@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl( OpenAIChatCompletionToLlamaStackMixin, OpenAICompletionToLlamaStackMixin, SentenceTransformerEmbeddingMixin, - InferenceProvider, + Inference, ModelsProtocolPrivate, ): def __init__(self, config: SentenceTransformersInferenceConfig) -> None: diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index bf54462b5..438cb14a0 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -40,7 +40,6 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -411,16 +410,6 @@ class VLLMInferenceImpl( ) -> EmbeddingsResponse: raise NotImplementedError() - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def chat_completion( self, model_id: str, diff --git a/llama_stack/providers/inline/post_training/common/utils.py b/llama_stack/providers/inline/post_training/common/utils.py deleted file mode 100644 index 7840b21e8..000000000 --- a/llama_stack/providers/inline/post_training/common/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import gc - - -def evacuate_model_from_device(model, device: str): - """Safely clear a model from memory and free device resources. - This function handles the proper cleanup of a model by: - 1. Moving the model to CPU if it's on a non-CPU device - 2. Deleting the model object to free memory - 3. Running garbage collection - 4. 
Clearing CUDA cache if the model was on a CUDA device - Args: - model: The PyTorch model to clear - device: The device type the model is currently on ('cuda', 'mps', 'cpu') - Note: - - For CUDA devices, this will clear the CUDA cache after moving the model to CPU - - For MPS devices, only moves the model to CPU (no cache clearing available) - - For CPU devices, only deletes the model object and runs garbage collection - """ - if device != "cpu": - model.to("cpu") - - del model - gc.collect() - - if device == "cuda": - # we need to import such that this is only imported when the method is called - import torch - - torch.cuda.empty_cache() diff --git a/llama_stack/providers/inline/post_training/huggingface/__init__.py b/llama_stack/providers/inline/post_training/huggingface/__init__.py deleted file mode 100644 index cc1a671c1..000000000 --- a/llama_stack/providers/inline/post_training/huggingface/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from llama_stack.distribution.datatypes import Api - -from .config import HuggingFacePostTrainingConfig - -# post_training api and the huggingface provider is still experimental and under heavy development - - -async def get_provider_impl( - config: HuggingFacePostTrainingConfig, - deps: dict[Api, Any], -): - from .post_training import HuggingFacePostTrainingImpl - - impl = HuggingFacePostTrainingImpl( - config, - deps[Api.datasetio], - deps[Api.datasets], - ) - return impl diff --git a/llama_stack/providers/inline/post_training/huggingface/config.py b/llama_stack/providers/inline/post_training/huggingface/config.py deleted file mode 100644 index 06c6d8073..000000000 --- a/llama_stack/providers/inline/post_training/huggingface/config.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any, Literal - -from pydantic import BaseModel - - -class HuggingFacePostTrainingConfig(BaseModel): - # Device to run training on (cuda, cpu, mps) - device: str = "cuda" - - # Distributed training backend if using multiple devices - # fsdp: Fully Sharded Data Parallel - # deepspeed: DeepSpeed ZeRO optimization - distributed_backend: Literal["fsdp", "deepspeed"] | None = None - - # Format for saving model checkpoints - # full_state: Save complete model state - # huggingface: Save in HuggingFace format (recommended for compatibility) - checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface" - - # Template for formatting chat inputs and outputs - # Used to structure the conversation format for training - chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}" - - # Model-specific configuration parameters - # trust_remote_code: Allow execution of custom model code - # attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance - model_specific_config: dict = { - "trust_remote_code": True, - "attn_implementation": "sdpa", - } - - # Maximum sequence length for training - # Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon) - # Longer sequences may cause memory issues on MPS devices - max_seq_length: int = 2048 - - # Enable gradient checkpointing to reduce memory usage - # Trades computation for memory by recomputing activations - gradient_checkpointing: bool = False - - # Maximum number of checkpoints to keep - # Older checkpoints are deleted when this limit is reached - save_total_limit: int = 3 - - # Number of training steps between logging updates - logging_steps: int = 10 - - # Ratio of training steps used for learning rate warmup - # Helps stabilize early training - warmup_ratio: float = 0.1 - - # L2 regularization coefficient - # Helps prevent overfitting - weight_decay: float = 0.01 - - # Number of worker processes for data loading - # Higher values can improve data loading speed but increase memory usage - dataloader_num_workers: int = 4 - - # Whether to pin memory in data loader - # Can improve data transfer speed to GPU but uses more memory - dataloader_pin_memory: bool = True - - @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: - return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"} diff --git a/llama_stack/providers/inline/post_training/huggingface/post_training.py b/llama_stack/providers/inline/post_training/huggingface/post_training.py deleted file mode 100644 index 0b2760792..000000000 --- a/llama_stack/providers/inline/post_training/huggingface/post_training.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
-from enum import Enum -from typing import Any - -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import Datasets -from llama_stack.apis.post_training import ( - AlgorithmConfig, - Checkpoint, - DPOAlignmentConfig, - JobStatus, - ListPostTrainingJobsResponse, - PostTrainingJob, - PostTrainingJobArtifactsResponse, - PostTrainingJobStatusResponse, - TrainingConfig, -) -from llama_stack.providers.inline.post_training.huggingface.config import ( - HuggingFacePostTrainingConfig, -) -from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import ( - HFFinetuningSingleDevice, -) -from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler -from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus -from llama_stack.schema_utils import webmethod - - -class TrainingArtifactType(Enum): - CHECKPOINT = "checkpoint" - RESOURCES_STATS = "resources_stats" - - -_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune" - - -class HuggingFacePostTrainingImpl: - def __init__( - self, - config: HuggingFacePostTrainingConfig, - datasetio_api: DatasetIO, - datasets: Datasets, - ) -> None: - self.config = config - self.datasetio_api = datasetio_api - self.datasets_api = datasets - self._scheduler = Scheduler() - - async def shutdown(self) -> None: - await self._scheduler.shutdown() - - @staticmethod - def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact: - return JobArtifact( - type=TrainingArtifactType.CHECKPOINT.value, - name=checkpoint.identifier, - uri=checkpoint.path, - metadata=dict(checkpoint), - ) - - @staticmethod - def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact: - return JobArtifact( - type=TrainingArtifactType.RESOURCES_STATS.value, - name=TrainingArtifactType.RESOURCES_STATS.value, - metadata=resources_stats, - ) - - async def supervised_fine_tune( - self, - job_uuid: str, - training_config: TrainingConfig, - hyperparam_search_config: dict[str, Any], - logger_config: dict[str, Any], - model: str, - checkpoint_dir: str | None = None, - algorithm_config: AlgorithmConfig | None = None, - ) -> PostTrainingJob: - async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb): - on_log_message_cb("Starting HF finetuning") - - recipe = HFFinetuningSingleDevice( - job_uuid=job_uuid, - datasetio_api=self.datasetio_api, - datasets_api=self.datasets_api, - ) - - resources_allocated, checkpoints = await recipe.train( - model=model, - output_dir=checkpoint_dir, - job_uuid=job_uuid, - lora_config=algorithm_config, - config=training_config, - provider_config=self.config, - ) - - on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated)) - if checkpoints: - for checkpoint in checkpoints: - artifact = self._checkpoint_to_artifact(checkpoint) - on_artifact_collected_cb(artifact) - - on_status_change_cb(SchedulerJobStatus.completed) - on_log_message_cb("HF finetuning completed") - - job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler) - return PostTrainingJob(job_uuid=job_uuid) - - async def preference_optimize( - self, - job_uuid: str, - finetuned_model: str, - algorithm_config: DPOAlignmentConfig, - training_config: TrainingConfig, - hyperparam_search_config: dict[str, Any], - logger_config: dict[str, Any], - ) -> PostTrainingJob: - raise NotImplementedError("DPO alignment is not implemented yet") - - async def get_training_jobs(self) -> ListPostTrainingJobsResponse: - return 
ListPostTrainingJobsResponse( - data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()] - ) - - @staticmethod - def _get_artifacts_metadata_by_type(job, artifact_type): - return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type] - - @classmethod - def _get_checkpoints(cls, job): - return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value) - - @classmethod - def _get_resources_allocated(cls, job): - data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value) - return data[0] if data else None - - @webmethod(route="/post-training/job/status") - async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None: - job = self._scheduler.get_job(job_uuid) - - match job.status: - # TODO: Add support for other statuses to API - case SchedulerJobStatus.new | SchedulerJobStatus.scheduled: - status = JobStatus.scheduled - case SchedulerJobStatus.running: - status = JobStatus.in_progress - case SchedulerJobStatus.completed: - status = JobStatus.completed - case SchedulerJobStatus.failed: - status = JobStatus.failed - case _: - raise NotImplementedError() - - return PostTrainingJobStatusResponse( - job_uuid=job_uuid, - status=status, - scheduled_at=job.scheduled_at, - started_at=job.started_at, - completed_at=job.completed_at, - checkpoints=self._get_checkpoints(job), - resources_allocated=self._get_resources_allocated(job), - ) - - @webmethod(route="/post-training/job/cancel") - async def cancel_training_job(self, job_uuid: str) -> None: - self._scheduler.cancel(job_uuid) - - @webmethod(route="/post-training/job/artifacts") - async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None: - job = self._scheduler.get_job(job_uuid) - return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job)) diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py deleted file mode 100644 index b6d13b029..000000000 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +++ /dev/null @@ -1,683 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import gc -import json -import logging -import multiprocessing -import os -import signal -import sys -from datetime import datetime, timezone -from pathlib import Path -from typing import Any - -import psutil - -from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device - -# Set tokenizer parallelism environment variable -os.environ["TOKENIZERS_PARALLELISM"] = "false" - -# Force PyTorch to use OpenBLAS instead of MKL -os.environ["MKL_THREADING_LAYER"] = "GNU" -os.environ["MKL_SERVICE_FORCE_INTEL"] = "0" -os.environ["MKL_NUM_THREADS"] = "1" - -import torch -from datasets import Dataset -from peft import LoraConfig -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, -) -from trl import SFTConfig, SFTTrainer - -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import Datasets -from llama_stack.apis.post_training import ( - Checkpoint, - DataConfig, - LoraFinetuningConfig, - TrainingConfig, -) - -from ..config import HuggingFacePostTrainingConfig - -logger = logging.getLogger(__name__) - - -def get_gb(to_convert: int) -> str: - """Converts memory stats to GB and formats to 2 decimal places. - Args: - to_convert: Memory value in bytes - Returns: - str: Memory value in GB formatted to 2 decimal places - """ - return f"{(to_convert / (1024**3)):.2f}" - - -def get_memory_stats(device: torch.device) -> dict[str, Any]: - """Get memory statistics for the given device.""" - stats = { - "system_memory": { - "total": get_gb(psutil.virtual_memory().total), - "available": get_gb(psutil.virtual_memory().available), - "used": get_gb(psutil.virtual_memory().used), - "percent": psutil.virtual_memory().percent, - } - } - - if device.type == "cuda": - stats["device_memory"] = { - "allocated": get_gb(torch.cuda.memory_allocated(device)), - "reserved": get_gb(torch.cuda.memory_reserved(device)), - "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)), - } - elif device.type == "mps": - # MPS doesn't provide direct memory stats, but we can track system memory - stats["device_memory"] = { - "note": "MPS memory stats not directly available", - "system_memory_used": get_gb(psutil.virtual_memory().used), - } - elif device.type == "cpu": - # For CPU, we track process memory usage - process = psutil.Process() - stats["device_memory"] = { - "process_rss": get_gb(process.memory_info().rss), - "process_vms": get_gb(process.memory_info().vms), - "process_percent": process.memory_percent(), - } - - return stats - - -def setup_torch_device(device_str: str) -> torch.device: - """Initialize and validate a PyTorch device. - This function handles device initialization and validation for different device types: - - CUDA: Validates CUDA availability and handles device selection - - MPS: Validates MPS availability for Apple Silicon - - CPU: Basic validation - - HPU: Raises error as it's not supported - Args: - device_str: String specifying the device ('cuda', 'cpu', 'mps') - Returns: - torch.device: The initialized and validated device - Raises: - RuntimeError: If device initialization fails or device is not supported - """ - try: - device = torch.device(device_str) - except RuntimeError as e: - raise RuntimeError(f"Error getting Torch Device {str(e)}") from e - - # Validate device capabilities - if device.type == "cuda": - if not torch.cuda.is_available(): - raise RuntimeError( - f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device." 
- ) - if device.index is None: - device = torch.device(device.type, torch.cuda.current_device()) - elif device.type == "mps": - if not torch.backends.mps.is_available(): - raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.") - elif device.type == "hpu": - raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.") - - return device - - -class HFFinetuningSingleDevice: - def __init__( - self, - job_uuid: str, - datasetio_api: DatasetIO, - datasets_api: Datasets, - ): - self.datasetio_api = datasetio_api - self.datasets_api = datasets_api - self.job_uuid = job_uuid - - def validate_dataset_format(self, rows: list[dict]) -> bool: - """Validate that the dataset has the required fields.""" - required_fields = ["input_query", "expected_answer", "chat_completion_input"] - return all(field in row for row in rows for field in required_fields) - - def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]: - """Process a row in instruct format.""" - if "chat_completion_input" in row and "expected_answer" in row: - try: - messages = json.loads(row["chat_completion_input"]) - if not isinstance(messages, list) or len(messages) != 1: - logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}") - return None, None - if "content" not in messages[0]: - logger.warning(f"Message missing content: {messages[0]}") - return None, None - return messages[0]["content"], row["expected_answer"] - except json.JSONDecodeError: - logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}") - return None, None - return None, None - - def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]: - """Process a row in dialog format.""" - if "dialog" in row: - try: - dialog = json.loads(row["dialog"]) - if not isinstance(dialog, list) or len(dialog) < 2: - logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}") - return None, None - if dialog[0].get("role") != "user": - logger.warning(f"First message must be from user: {dialog[0]}") - return None, None - if not any(msg.get("role") == "assistant" for msg in dialog): - logger.warning("Dialog must have at least one assistant message") - return None, None - - # Convert to human/gpt format - role_map = {"user": "human", "assistant": "gpt"} - conversations = [] - for msg in dialog: - if "role" not in msg or "content" not in msg: - logger.warning(f"Message missing role or content: {msg}") - continue - conversations.append({"from": role_map[msg["role"]], "value": msg["content"]}) - - # Format as a single conversation - return conversations[0]["value"], conversations[1]["value"] - except json.JSONDecodeError: - logger.warning(f"Failed to parse dialog: {row['dialog']}") - return None, None - return None, None - - def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]: - """Process a row using fallback formats.""" - if "input" in row and "output" in row: - return row["input"], row["output"] - elif "prompt" in row and "completion" in row: - return row["prompt"], row["completion"] - elif "question" in row and "answer" in row: - return row["question"], row["answer"] - return None, None - - def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str: - """Format input and output text based on model requirements.""" - if hasattr(provider_config, "chat_template"): - return provider_config.chat_template.format(input=input_text, 
output=output_text) - return f"{input_text}\n{output_text}" - - def _create_dataset( - self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig - ) -> Dataset: - """Create and preprocess the dataset.""" - formatted_rows = [] - for row in rows: - input_text = None - output_text = None - - # Process based on format - assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized" - if config.data_config.data_format.value == "instruct": - input_text, output_text = self._process_instruct_format(row) - elif config.data_config.data_format.value == "dialog": - input_text, output_text = self._process_dialog_format(row) - else: - input_text, output_text = self._process_fallback_format(row) - - if input_text and output_text: - formatted_text = self._format_text(input_text, output_text, provider_config) - formatted_rows.append({"text": formatted_text}) - - if not formatted_rows: - assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized" - raise ValueError( - f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}" - ) - - return Dataset.from_list(formatted_rows) - - def _preprocess_dataset( - self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig - ) -> Dataset: - """Preprocess the dataset with tokenizer.""" - - def tokenize_function(examples): - return tokenizer( - examples["text"], - padding=True, - truncation=True, - max_length=provider_config.max_seq_length, - return_tensors=None, - ) - - return ds.map( - tokenize_function, - batched=True, - remove_columns=ds.column_names, - ) - - async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]: - """Load dataset from llama stack dataset provider""" - try: - all_rows = await self.datasetio_api.iterrows( - dataset_id=dataset_id, - limit=-1, - ) - if not isinstance(all_rows.data, list): - raise RuntimeError("Expected dataset data to be a list") - return all_rows.data - except Exception as e: - raise RuntimeError(f"Failed to load dataset: {str(e)}") from e - - def _run_training_sync( - self, - model: str, - provider_config: dict[str, Any], - peft_config: LoraConfig | None, - config: dict[str, Any], - output_dir_path: Path | None, - ) -> None: - """Synchronous wrapper for running training process. - This method serves as a bridge between the multiprocessing Process and the async training function. - It creates a new event loop to run the async training process. - Args: - model: The model identifier to load - dataset_id: ID of the dataset to use for training - provider_config: Configuration specific to the HuggingFace provider - peft_config: Optional LoRA configuration - config: General training configuration - output_dir_path: Optional path to save the model - """ - import asyncio - - logger.info("Starting training process with async wrapper") - asyncio.run( - self._run_training( - model=model, - provider_config=provider_config, - peft_config=peft_config, - config=config, - output_dir_path=output_dir_path, - ) - ) - - async def load_dataset( - self, - model: str, - config: TrainingConfig, - provider_config: HuggingFacePostTrainingConfig, - ) -> tuple[Dataset, Dataset, AutoTokenizer]: - """Load and prepare the dataset for training. 
- Args: - model: The model identifier to load - config: Training configuration - provider_config: Provider-specific configuration - Returns: - tuple: (train_dataset, eval_dataset, tokenizer) - """ - # Validate data config - if not config.data_config: - raise ValueError("DataConfig is required for training") - - # Load dataset - logger.info(f"Loading dataset: {config.data_config.dataset_id}") - rows = await self._setup_data(config.data_config.dataset_id) - if not self.validate_dataset_format(rows): - raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") - logger.info(f"Loaded {len(rows)} rows from dataset") - - # Initialize tokenizer - logger.info(f"Initializing tokenizer for model: {model}") - try: - tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config) - - # Set pad token to eos token if not present - # This is common for models that don't have a dedicated pad token - if not tokenizer.pad_token: - tokenizer.pad_token = tokenizer.eos_token - - # Set padding side to right for causal language modeling - # This ensures that padding tokens don't interfere with the model's ability - # to predict the next token in the sequence - tokenizer.padding_side = "right" - - # Set truncation side to right to keep the beginning of the sequence - # This is important for maintaining context and instruction format - tokenizer.truncation_side = "right" - - # Set model max length to match provider config - # This ensures consistent sequence lengths across the training process - tokenizer.model_max_length = provider_config.max_seq_length - - logger.info("Tokenizer initialized successfully") - except Exception as e: - raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e - - # Create and preprocess dataset - logger.info("Creating and preprocessing dataset") - try: - ds = self._create_dataset(rows, config, provider_config) - ds = self._preprocess_dataset(ds, tokenizer, provider_config) - logger.info(f"Dataset created with {len(ds)} examples") - except Exception as e: - raise ValueError(f"Failed to create dataset: {str(e)}") from e - - # Split dataset - logger.info("Splitting dataset into train and validation sets") - train_val_split = ds.train_test_split(test_size=0.1, seed=42) - train_dataset = train_val_split["train"] - eval_dataset = train_val_split["test"] - logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples") - - return train_dataset, eval_dataset, tokenizer - - def load_model( - self, - model: str, - device: torch.device, - provider_config: HuggingFacePostTrainingConfig, - ) -> AutoModelForCausalLM: - """Load and initialize the model for training. 
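
The deleted recipe above prepares the tokenizer for causal-LM fine-tuning (reuse EOS as the pad token, right-side padding and truncation, sequence length capped to the provider's `max_seq_length`) before splitting the processed data 90/10 with `train_test_split(test_size=0.1, seed=42)`. A minimal standalone sketch of that preparation, assuming only that `transformers` is installed; the tiny model name below is a placeholder, not something referenced by this diff:

```python
# Minimal sketch of the tokenizer preparation used by the recipe above.
# "sshleifer/tiny-gpt2" is only a placeholder model for illustration.
from transformers import AutoTokenizer


def prepare_tokenizer(model_name: str, max_seq_length: int):
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Many causal-LM checkpoints ship without a dedicated pad token; reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Right-side padding and truncation keep the instruction prefix intact and
    # keep pad tokens away from the next-token-prediction targets.
    tokenizer.padding_side = "right"
    tokenizer.truncation_side = "right"

    # Keep sequence lengths consistent with the provider configuration.
    tokenizer.model_max_length = max_seq_length
    return tokenizer


if __name__ == "__main__":
    tok = prepare_tokenizer("sshleifer/tiny-gpt2", max_seq_length=2048)
    print(tok("hello world", padding=True, truncation=True))
```
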
- Args: - model: The model identifier to load - device: The device to load the model onto - provider_config: Provider-specific configuration - Returns: - The loaded and initialized model - Raises: - RuntimeError: If model loading fails - """ - logger.info("Loading the base model") - try: - model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config) - model_obj = AutoModelForCausalLM.from_pretrained( - model, - torch_dtype="auto" if device.type != "cpu" else "float32", - quantization_config=None, - config=model_config, - **provider_config.model_specific_config, - ) - # Always move model to specified device - model_obj = model_obj.to(device) - logger.info(f"Model loaded and moved to device: {model_obj.device}") - return model_obj - except Exception as e: - raise RuntimeError(f"Failed to load model: {str(e)}") from e - - def setup_training_args( - self, - config: TrainingConfig, - provider_config: HuggingFacePostTrainingConfig, - device: torch.device, - output_dir_path: Path | None, - steps_per_epoch: int, - ) -> SFTConfig: - """Setup training arguments. - Args: - config: Training configuration - provider_config: Provider-specific configuration - device: The device to train on - output_dir_path: Optional path to save the model - steps_per_epoch: Number of steps per epoch - Returns: - Configured SFTConfig object - """ - logger.info("Configuring training arguments") - lr = 2e-5 - if config.optimizer_config: - lr = config.optimizer_config.lr - logger.info(f"Using custom learning rate: {lr}") - - # Validate data config - if not config.data_config: - raise ValueError("DataConfig is required for training") - data_config = config.data_config - - # Calculate steps - total_steps = steps_per_epoch * config.n_epochs - max_steps = min(config.max_steps_per_epoch, total_steps) - eval_steps = max(1, steps_per_epoch // 10) # Evaluate 10 times per epoch - save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch - logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch - - logger.info("Training configuration:") - logger.info(f"- Steps per epoch: {steps_per_epoch}") - logger.info(f"- Total steps: {total_steps}") - logger.info(f"- Max steps: {max_steps}") - logger.info(f"- Eval steps: {eval_steps}") - logger.info(f"- Save steps: {save_steps}") - logger.info(f"- Logging steps: {logging_steps}") - - # Configure save strategy - save_strategy = "no" - if output_dir_path: - save_strategy = "steps" - logger.info(f"Will save checkpoints to {output_dir_path}") - - return SFTConfig( - max_steps=max_steps, - output_dir=str(output_dir_path) if output_dir_path is not None else None, - num_train_epochs=config.n_epochs, - per_device_train_batch_size=data_config.batch_size, - fp16=device.type == "cuda", - bf16=False, # Causes CPU issues. 
- eval_strategy="steps", - use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False, - save_strategy=save_strategy, - report_to="none", - max_seq_length=provider_config.max_seq_length, - gradient_accumulation_steps=config.gradient_accumulation_steps, - gradient_checkpointing=provider_config.gradient_checkpointing, - learning_rate=lr, - warmup_ratio=provider_config.warmup_ratio, - weight_decay=provider_config.weight_decay, - remove_unused_columns=False, - dataloader_pin_memory=provider_config.dataloader_pin_memory, - dataloader_num_workers=provider_config.dataloader_num_workers, - dataset_text_field="text", - packing=False, - load_best_model_at_end=True if output_dir_path else False, - metric_for_best_model="eval_loss", - greater_is_better=False, - eval_steps=eval_steps, - save_steps=save_steps, - logging_steps=logging_steps, - ) - - def save_model( - self, - model_obj: AutoModelForCausalLM, - trainer: SFTTrainer, - peft_config: LoraConfig | None, - output_dir_path: Path, - ) -> None: - """Save the trained model. - Args: - model_obj: The model to save - trainer: The trainer instance - peft_config: Optional LoRA configuration - output_dir_path: Path to save the model - """ - logger.info("Saving final model") - model_obj.config.use_cache = True - - if peft_config: - logger.info("Merging LoRA weights with base model") - model_obj = trainer.model.merge_and_unload() - else: - model_obj = trainer.model - - save_path = output_dir_path / "merged_model" - logger.info(f"Saving model to {save_path}") - model_obj.save_pretrained(save_path) - - async def _run_training( - self, - model: str, - provider_config: dict[str, Any], - peft_config: LoraConfig | None, - config: dict[str, Any], - output_dir_path: Path | None, - ) -> None: - """Run the training process with signal handling.""" - - def signal_handler(signum, frame): - """Handle termination signals gracefully.""" - logger.info(f"Received signal {signum}, initiating graceful shutdown") - sys.exit(0) - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - - # Convert config dicts back to objects - logger.info("Initializing configuration objects") - provider_config_obj = HuggingFacePostTrainingConfig(**provider_config) - config_obj = TrainingConfig(**config) - - # Initialize and validate device - device = setup_torch_device(provider_config_obj.device) - logger.info(f"Using device '{device}'") - - # Load dataset and tokenizer - train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj) - - # Calculate steps per epoch - if not config_obj.data_config: - raise ValueError("DataConfig is required for training") - steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size - - # Setup training arguments - training_args = self.setup_training_args( - config_obj, - provider_config_obj, - device, - output_dir_path, - steps_per_epoch, - ) - - # Load model - model_obj = self.load_model(model, device, provider_config_obj) - - # Initialize trainer - logger.info("Initializing SFTTrainer") - trainer = SFTTrainer( - model=model_obj, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - peft_config=peft_config, - args=training_args, - ) - - try: - # Train - logger.info("Starting training") - trainer.train() - logger.info("Training completed successfully") - - # Save final model if output directory is provided - if output_dir_path: - self.save_model(model_obj, trainer, peft_config, output_dir_path) - - finally: - # Clean up 
resources - logger.info("Cleaning up resources") - if hasattr(trainer, "model"): - evacuate_model_from_device(trainer.model, device.type) - del trainer - gc.collect() - logger.info("Cleanup completed") - - async def train( - self, - model: str, - output_dir: str | None, - job_uuid: str, - lora_config: LoraFinetuningConfig, - config: TrainingConfig, - provider_config: HuggingFacePostTrainingConfig, - ) -> tuple[dict[str, Any], list[Checkpoint] | None]: - """Train a model using HuggingFace's SFTTrainer""" - # Initialize and validate device - device = setup_torch_device(provider_config.device) - logger.info(f"Using device '{device}'") - - output_dir_path = None - if output_dir: - output_dir_path = Path(output_dir) - - # Track memory stats - memory_stats = { - "initial": get_memory_stats(device), - "after_training": None, - "final": None, - } - - # Configure LoRA - peft_config = None - if lora_config: - peft_config = LoraConfig( - lora_alpha=lora_config.alpha, - lora_dropout=0.1, - r=lora_config.rank, - bias="none", - task_type="CAUSAL_LM", - target_modules=lora_config.lora_attn_modules, - ) - - # Validate data config - if not config.data_config: - raise ValueError("DataConfig is required for training") - - # Train in a separate process - logger.info("Starting training in separate process") - try: - # Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility - if device.type in ["cuda", "mps"]: - multiprocessing.set_start_method("spawn", force=True) - - process = multiprocessing.Process( - target=self._run_training_sync, - kwargs={ - "model": model, - "provider_config": provider_config.model_dump(), - "peft_config": peft_config, - "config": config.model_dump(), - "output_dir_path": output_dir_path, - }, - ) - process.start() - - # Monitor the process - while process.is_alive(): - process.join(timeout=1) # Check every second - if not process.is_alive(): - break - - # Get the return code - if process.exitcode != 0: - raise RuntimeError(f"Training failed with exit code {process.exitcode}") - - memory_stats["after_training"] = get_memory_stats(device) - - checkpoints = None - if output_dir_path: - # Create checkpoint - checkpoint = Checkpoint( - identifier=f"{model}-sft-{config.n_epochs}", - created_at=datetime.now(timezone.utc), - epoch=config.n_epochs, - post_training_job_id=job_uuid, - path=str(output_dir_path / "merged_model"), - ) - checkpoints = [checkpoint] - - return memory_stats, checkpoints - finally: - memory_stats["final"] = get_memory_stats(device) - gc.collect() diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index f56dd2499..b5a495935 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
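
The deleted `train` method above launches the actual fine-tuning in a separate process: it forces the `spawn` start method for CUDA/MPS, polls the child once per second, and raises if the exit code is non-zero, while `_run_training_sync` bridges into the async training routine with `asyncio.run`. A compressed sketch of that launch pattern, with a dummy worker standing in for the real training entry point:

```python
# Compressed sketch of the launch pattern above; the worker is a stand-in
# for the real training entry points (_run_training_sync / _run_training).
import asyncio
import multiprocessing


async def _async_train(model: str, output_dir: str | None) -> None:
    await asyncio.sleep(0)  # placeholder for the real async training loop
    print(f"would train {model}, saving to {output_dir}")


def _worker(model: str, output_dir: str | None) -> None:
    # multiprocessing targets must be plain callables; asyncio.run() gives the
    # coroutine its own event loop inside the child process.
    asyncio.run(_async_train(model=model, output_dir=output_dir))


def train_in_subprocess(model: str, output_dir: str | None) -> None:
    # "spawn" avoids forking a parent that may already hold CUDA/MPS state.
    multiprocessing.set_start_method("spawn", force=True)

    process = multiprocessing.Process(
        target=_worker, kwargs={"model": model, "output_dir": output_dir}
    )
    process.start()

    # Poll once per second so the caller stays responsive while training runs.
    while process.is_alive():
        process.join(timeout=1)

    if process.exitcode != 0:
        raise RuntimeError(f"Training failed with exit code {process.exitcode}")


if __name__ == "__main__":
    train_in_subprocess("example-model", None)
```
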
+import gc import logging import os import time @@ -46,7 +47,6 @@ from llama_stack.apis.post_training import ( from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.models.llama.sku_list import resolve_model -from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( TorchtuneCheckpointer, @@ -554,7 +554,11 @@ class LoraFinetuningSingleDevice: checkpoints.append(checkpoint) # clean up the memory after training finishes - evacuate_model_from_device(self._model, self._device.type) + if self._device.type != "cpu": + self._model.to("cpu") + torch.cuda.empty_cache() + del self._model + gc.collect() return (memory_stats, checkpoints) diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index ff87889ea..56ce8285f 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -75,9 +75,7 @@ class PromptGuardShield: self.temperature = temperature self.threshold = threshold - self.device = "cpu" - if torch.cuda.is_available(): - self.device = "cuda" + self.device = "cuda" # load model and tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_dir) diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 0f6cf8619..67362dd36 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -16,7 +16,6 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes -from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from llama_stack.apis.telemetry import ( Event, @@ -45,7 +44,6 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor ) from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore -from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS from .config import TelemetryConfig, TelemetrySink @@ -148,7 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): if span: timestamp_ns = int(event.timestamp.timestamp() * 1e9) span.add_event( - name=event.type.value, + name=event.type, attributes={ "message": event.message, "severity": event.severity.value, @@ -208,15 +206,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): event.attributes = {} event.attributes["__ttl__"] = ttl_seconds - # Extract these W3C trace context attributes so they are not written to - # underlying storage, as we just need them to propagate the trace context. - traceparent = event.attributes.pop("traceparent", None) - tracestate = event.attributes.pop("tracestate", None) - if traceparent: - # If we have a traceparent header value, we're not the root span. 
- for root_attribute in ROOT_SPAN_MARKERS: - event.attributes.pop(root_attribute, None) - if isinstance(event.payload, SpanStartPayload): # Check if span already exists to prevent duplicates if span_id in _GLOBAL_STORAGE["active_spans"]: @@ -227,12 +216,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) context = trace.set_span_in_context(parent_span) - elif traceparent: - carrier = { - "traceparent": traceparent, - "tracestate": tracestate, - } - context = TraceContextTextMapPropagator().extract(carrier=carrier) + else: + event.attributes["__root_span__"] = "true" span = tracer.start_span( name=event.payload.name, diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 4776d47d0..39f752297 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -25,14 +25,14 @@ from llama_stack.apis.tools import ( RAGQueryConfig, RAGQueryResult, RAGToolRuntime, + Tool, ToolDef, - ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO -from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate +from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, @@ -49,7 +49,7 @@ def make_random_string(length: int = 8): return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) -class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime): +class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): def __init__( self, config: RagToolRuntimeConfig, @@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti async def shutdown(self): pass - async def register_toolgroup(self, toolgroup: ToolGroup) -> None: + async def register_tool(self, tool: Tool) -> None: pass - async def unregister_toolgroup(self, toolgroup_id: str) -> None: + async def unregister_tool(self, tool_id: str) -> None: return async def insert( @@ -122,7 +122,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti query=query, params={ "max_chunks": query_config.max_chunks, - "mode": query_config.mode, }, ) for vector_db_id in vector_db_ids @@ -146,8 +145,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti ] for i, chunk in enumerate(chunks): metadata = chunk.metadata - tokens += metadata.get("token_count", 0) - tokens += metadata.get("metadata_token_count", 0) + tokens += metadata["token_count"] + tokens += metadata["metadata_token_count"] if tokens > query_config.max_tokens_in_context: log.error( diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 47256d88d..d3dc7e694 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -99,13 +99,9 @@ class FaissIndex(EmbeddingIndex): # Save updated index await self._save_index() - async def query_vector( - self, - embedding: NDArray, - k: int, - score_threshold: float, - ) -> QueryChunksResponse: + async def query(self, embedding: NDArray, k: 
int, score_threshold: float) -> QueryChunksResponse: distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k) + chunks = [] scores = [] for d, i in zip(distances[0], indices[0], strict=False): @@ -116,14 +112,6 @@ class FaissIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) - async def query_keyword( - self, - query_string: str, - k: int, - score_threshold: float, - ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in FAISS") - class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None: diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index fc1a8ddb0..ab4384021 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -24,11 +24,6 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect logger = logging.getLogger(__name__) -# Specifying search mode is dependent on the VectorIO provider. -VECTOR_SEARCH = "vector" -KEYWORD_SEARCH = "keyword" -SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH} - def serialize_vector(vector: list[float]) -> bytes: """Serialize a list of floats into a compact binary representation.""" @@ -50,7 +45,6 @@ class SQLiteVecIndex(EmbeddingIndex): Two tables are used: - A metadata table (chunks_{bank_id}) that holds the chunk JSON. - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector. - - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search. """ def __init__(self, dimension: int, db_path: str, bank_id: str): @@ -59,7 +53,6 @@ class SQLiteVecIndex(EmbeddingIndex): self.bank_id = bank_id self.metadata_table = f"chunks_{bank_id}".replace("-", "_") self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") - self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_") @classmethod async def create(cls, dimension: int, db_path: str, bank_id: str): @@ -85,14 +78,6 @@ class SQLiteVecIndex(EmbeddingIndex): USING vec0(embedding FLOAT[{self.dimension}], id TEXT); """) connection.commit() - # FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one - # based on query. Implementation of the change on client side will allow passing the search_mode option - # during initialization to make it easier to create the table that is required. - cur.execute(f""" - CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table} - USING fts5(id, content); - """) - connection.commit() finally: cur.close() connection.close() @@ -106,7 +91,6 @@ class SQLiteVecIndex(EmbeddingIndex): try: cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};") cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") - cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};") connection.commit() finally: cur.close() @@ -120,7 +104,6 @@ class SQLiteVecIndex(EmbeddingIndex): For each chunk, we insert its JSON into the metadata table and then insert its embedding (serialized to raw bytes) into the virtual table using the assigned rowid. If any insert fails, the transaction is rolled back to maintain consistency. - Also inserts chunk content into FTS table for keyword search support. 
""" assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks" @@ -129,16 +112,18 @@ class SQLiteVecIndex(EmbeddingIndex): cur = connection.cursor() try: + # Start transaction a single transcation for all batches cur.execute("BEGIN TRANSACTION") for i in range(0, len(chunks), batch_size): batch_chunks = chunks[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size] - - # Insert metadata + # Prepare metadata inserts metadata_data = [ (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json()) for chunk in batch_chunks + if isinstance(chunk.content, str) ] + # Insert metadata (ON CONFLICT to avoid duplicates) cur.executemany( f""" INSERT INTO {self.metadata_table} (id, chunk) @@ -147,43 +132,21 @@ class SQLiteVecIndex(EmbeddingIndex): """, metadata_data, ) - - # Insert vector embeddings + # Prepare embeddings inserts embedding_data = [ ( - ( - generate_chunk_id(chunk.metadata["document_id"], chunk.content), - serialize_vector(emb.tolist()), - ) + generate_chunk_id(chunk.metadata["document_id"], chunk.content), + serialize_vector(emb.tolist()), ) for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True) + if isinstance(chunk.content, str) ] - cur.executemany( - f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", - embedding_data, - ) - - # Insert FTS content - fts_data = [ - (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content) - for chunk in batch_chunks - ] - # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT) - cur.executemany( - f"DELETE FROM {self.fts_table} WHERE id = ?;", - [(row[0],) for row in fts_data], - ) - - # INSERT new entries - cur.executemany( - f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);", - fts_data, - ) - + # Insert embeddings in batch + cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data) connection.commit() except sqlite3.Error as e: - connection.rollback() + connection.rollback() # Rollback on failure logger.error(f"Error inserting into {self.vector_table}: {e}") raise @@ -191,25 +154,22 @@ class SQLiteVecIndex(EmbeddingIndex): cur.close() connection.close() - # Run batch insertion in a background thread + # Process all batches in a single thread await asyncio.to_thread(_execute_all_batch_inserts) - async def query_vector( - self, - embedding: NDArray, - k: int, - score_threshold: float, - ) -> QueryChunksResponse: + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: """ - Performs vector-based search using a virtual table for vector similarity. + Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query + against the virtual table. The SQL joins the metadata table to recover the chunk JSON. 
""" + emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) + emb_blob = serialize_vector(emb_list) def _execute_query(): connection = _create_sqlite_connection(self.db_path) cur = connection.cursor() + try: - emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) - emb_blob = serialize_vector(emb_list) query_sql = f""" SELECT m.id, m.chunk, v.distance FROM {self.vector_table} AS v @@ -224,66 +184,17 @@ class SQLiteVecIndex(EmbeddingIndex): connection.close() rows = await asyncio.to_thread(_execute_query) + chunks, scores = [], [] - for row in rows: - _id, chunk_json, distance = row + for _id, chunk_json, distance in rows: + try: + chunk = Chunk.model_validate_json(chunk_json) + except Exception as e: + logger.error(f"Error parsing chunk JSON for id {_id}: {e}") + continue + chunks.append(chunk) + # Mimic the Faiss scoring: score = 1/distance (avoid division by zero) score = 1.0 / distance if distance != 0 else float("inf") - if score < score_threshold: - continue - try: - chunk = Chunk.model_validate_json(chunk_json) - except Exception as e: - logger.error(f"Error parsing chunk JSON for id {_id}: {e}") - continue - chunks.append(chunk) - scores.append(score) - return QueryChunksResponse(chunks=chunks, scores=scores) - - async def query_keyword( - self, - query_string: str, - k: int, - score_threshold: float, - ) -> QueryChunksResponse: - """ - Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search. - """ - if query_string is None: - raise ValueError("query_string is required for keyword search.") - - def _execute_query(): - connection = _create_sqlite_connection(self.db_path) - cur = connection.cursor() - try: - query_sql = f""" - SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score - FROM {self.fts_table} AS f - JOIN {self.metadata_table} AS m ON m.id = f.id - WHERE f.content MATCH ? - ORDER BY score ASC - LIMIT ?; - """ - cur.execute(query_sql, (query_string, k)) - return cur.fetchall() - finally: - cur.close() - connection.close() - - rows = await asyncio.to_thread(_execute_query) - chunks, scores = [], [] - for row in rows: - _id, chunk_json, score = row - # BM25 scores returned by sqlite-vec are NEGATED (i.e., more relevant = more negative). - # This design is intentional to simplify sorting by ascending score. 
- # Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html - if score > -score_threshold: - continue - try: - chunk = Chunk.model_validate_json(chunk_json) - except Exception as e: - logger.error(f"Error parsing chunk JSON for id {_id}: {e}") - continue - chunks.append(chunk) scores.append(score) return QueryChunksResponse(chunks=chunks, scores=scores) diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index d752b8819..35567c07d 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -21,17 +21,6 @@ def available_providers() -> list[ProviderSpec]: Api.datasets, ], ), - InlineProviderSpec( - api=Api.post_training, - provider_type="inline::huggingface", - pip_packages=["torch", "trl", "transformers", "peft", "datasets"], - module="llama_stack.providers.inline.post_training.huggingface", - config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", - api_dependencies=[ - Api.datasetio, - Api.datasets, - ], - ), remote_provider_spec( api=Api.post_training, adapter=AdapterSpec( diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index e0a04be48..c209da092 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -63,14 +63,4 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", ), ), - remote_provider_spec( - api=Api.safety, - adapter=AdapterSpec( - adapter_type="sambanova", - pip_packages=["litellm"], - module="llama_stack.providers.remote.safety.sambanova", - config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", - provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", - ), - ), ] diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 277914df2..b9194810e 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -80,9 +80,8 @@ def available_providers() -> list[ProviderSpec]: adapter=AdapterSpec( adapter_type="model-context-protocol", module="llama_stack.providers.remote.tool_runtime.model_context_protocol", - config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", + config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", pip_packages=["mcp"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", ), ), ] diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 952d86f1a..0404a578f 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -22,7 +22,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -198,13 +197,3 @@ class BedrockInferenceAdapter( response_body = json.loads(response.get("body").read()) embeddings.append(response_body.get("embedding")) return EmbeddingsResponse(embeddings=embeddings) - - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | 
None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 952118e24..685375346 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -21,7 +21,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -195,13 +194,3 @@ class CerebrasInferenceAdapter( task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() - - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/cerebras_openai_compat/__init__.py b/llama_stack/providers/remote/inference/cerebras_openai_compat/__init__.py index 523a8dfe7..a5f07edd2 100644 --- a/llama_stack/providers/remote/inference/cerebras_openai_compat/__init__.py +++ b/llama_stack/providers/remote/inference/cerebras_openai_compat/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.inference import InferenceProvider +from llama_stack.apis.inference import Inference from .config import CerebrasCompatConfig -async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider: +async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference: # import dynamically so the import is used only when it is needed from .cerebras import CerebrasCompatInferenceAdapter diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 1dc18b97f..5c36eac3e 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -20,7 +20,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -153,13 +152,3 @@ class DatabricksInferenceAdapter( task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() - - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index fe21685dd..b6d3984c6 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -37,7 +37,6 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, - OpenAIEmbeddingsResponse, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -287,16 +286,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) - 
async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/fireworks_openai_compat/__init__.py b/llama_stack/providers/remote/inference/fireworks_openai_compat/__init__.py index 15a666cb6..f78f218b5 100644 --- a/llama_stack/providers/remote/inference/fireworks_openai_compat/__init__.py +++ b/llama_stack/providers/remote/inference/fireworks_openai_compat/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.inference import InferenceProvider +from llama_stack.apis.inference import Inference from .config import FireworksCompatConfig -async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider: +async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference: # import dynamically so the import is used only when it is needed from .fireworks import FireworksCompatInferenceAdapter diff --git a/llama_stack/providers/remote/inference/groq_openai_compat/__init__.py b/llama_stack/providers/remote/inference/groq_openai_compat/__init__.py index 794cdebd7..8161df20d 100644 --- a/llama_stack/providers/remote/inference/groq_openai_compat/__init__.py +++ b/llama_stack/providers/remote/inference/groq_openai_compat/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.inference import InferenceProvider +from llama_stack.apis.inference import Inference from .config import GroqCompatConfig -async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider: +async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference: # import dynamically so the import is used only when it is needed from .groq import GroqCompatInferenceAdapter diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/__init__.py b/llama_stack/providers/remote/inference/llama_openai_compat/__init__.py index be48d1067..a6fb37cad 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/__init__.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.apis.inference import InferenceProvider +from llama_stack.apis.inference import Inference from .config import LlamaCompatConfig -async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider: +async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference: # import dynamically so the import is used only when it is needed from .llama import LlamaCompatInferenceAdapter diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 4c68322e0..333486fe4 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -29,7 +29,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -239,16 +238,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): # return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def chat_completion( self, model_id: str, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 8863e0edc..72cf0d129 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -28,11 +28,10 @@ from llama_stack.apis.inference import ( EmbeddingsResponse, EmbeddingTaskType, GrammarResponseFormat, - InferenceProvider, + Inference, JsonSchemaResponseFormat, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -83,7 +82,7 @@ logger = get_logger(name=__name__, category="inference") class OllamaInferenceAdapter( - InferenceProvider, + Inference, ModelsProtocolPrivate, ): def __init__(self, url: str) -> None: @@ -371,16 +370,6 @@ class OllamaInferenceAdapter( return model - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 6f3a686a8..9a1ec7ee0 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -14,9 +14,6 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, - OpenAIEmbeddingData, - OpenAIEmbeddingsResponse, - OpenAIEmbeddingUsage, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -41,7 +38,6 @@ logger = logging.getLogger(__name__) # | batch_chat_completion | LiteLLMOpenAIMixin | # | openai_completion | AsyncOpenAI | # | openai_chat_completion | AsyncOpenAI | -# | openai_embeddings | AsyncOpenAI | # class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): def __init__(self, config: OpenAIConfig) -> None: @@ -96,11 +92,8 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): if prompt_logprobs is not None: logging.warning("prompt_logprobs is not supported by the OpenAI API. 
Ignoring.") - model_id = (await self.model_store.get_model(model)).provider_resource_id - if model_id.startswith("openai/"): - model_id = model_id[len("openai/") :] params = await prepare_openai_completion_params( - model=model_id, + model=(await self.model_store.get_model(model)).provider_resource_id, prompt=prompt, best_of=best_of, echo=echo, @@ -146,11 +139,8 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): top_p: float | None = None, user: str | None = None, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - model_id = (await self.model_store.get_model(model)).provider_resource_id - if model_id.startswith("openai/"): - model_id = model_id[len("openai/") :] params = await prepare_openai_completion_params( - model=model_id, + model=(await self.model_store.get_model(model)).provider_resource_id, messages=messages, frequency_penalty=frequency_penalty, function_call=function_call, @@ -175,51 +165,3 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): user=user, ) return await self._openai_client.chat.completions.create(**params) - - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - model_id = (await self.model_store.get_model(model)).provider_resource_id - if model_id.startswith("openai/"): - model_id = model_id[len("openai/") :] - - # Prepare parameters for OpenAI embeddings API - params = { - "model": model_id, - "input": input, - } - - if encoding_format is not None: - params["encoding_format"] = encoding_format - if dimensions is not None: - params["dimensions"] = dimensions - if user is not None: - params["user"] = user - - # Call OpenAI embeddings API - response = await self._openai_client.embeddings.create(**params) - - data = [] - for i, embedding_data in enumerate(response.data): - data.append( - OpenAIEmbeddingData( - embedding=embedding_data.embedding, - index=i, - ) - ) - - usage = OpenAIEmbeddingUsage( - prompt_tokens=response.usage.prompt_tokens, - total_tokens=response.usage.total_tokens, - ) - - return OpenAIEmbeddingsResponse( - data=data, - model=response.model, - usage=usage, - ) diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 6cf4680e2..78ee52641 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -19,7 +19,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -211,16 +210,6 @@ class PassthroughInferenceAdapter(Inference): task_type=task_type, ) - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index f8c98893e..2706aa15e 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -8,7 +8,6 @@ from collections.abc import AsyncGenerator from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 -from 
llama_stack.apis.inference.inference import OpenAIEmbeddingsResponse # from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper @@ -135,13 +134,3 @@ class RunpodInferenceAdapter( task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() - - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 20f863665..d182aa1dc 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -218,7 +218,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): "json_schema": { "name": name, "schema": fmt, - "strict": False, + "strict": True, }, } if request.tools: diff --git a/llama_stack/providers/remote/inference/sambanova_openai_compat/__init__.py b/llama_stack/providers/remote/inference/sambanova_openai_compat/__init__.py index 60afe91ca..e31a3364c 100644 --- a/llama_stack/providers/remote/inference/sambanova_openai_compat/__init__.py +++ b/llama_stack/providers/remote/inference/sambanova_openai_compat/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.inference import InferenceProvider +from llama_stack.apis.inference import Inference from .config import SambaNovaCompatConfig -async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider: +async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> Inference: # import dynamically so the import is used only when it is needed from .sambanova import SambaNovaCompatInferenceAdapter diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 292d74ef8..8f6666462 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -23,7 +23,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, ResponseFormatType, SamplingParams, @@ -292,16 +291,6 @@ class _HfAdapter( ) -> EmbeddingsResponse: raise NotImplementedError() - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - class TGIAdapter(_HfAdapter): async def initialize(self, config: TGIImplConfig) -> None: diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 7305a638d..562e6e0ff 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -23,7 +23,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, ResponseFormatType, SamplingParams, @@ -268,16 +267,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi embeddings = [item.embedding for item in r.data] return 
EmbeddingsResponse(embeddings=embeddings) - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/together_openai_compat/__init__.py b/llama_stack/providers/remote/inference/together_openai_compat/__init__.py index 8213fc5f4..6fdf05b7e 100644 --- a/llama_stack/providers/remote/inference/together_openai_compat/__init__.py +++ b/llama_stack/providers/remote/inference/together_openai_compat/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.inference import InferenceProvider +from llama_stack.apis.inference import Inference from .config import TogetherCompatConfig -async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider: +async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> Inference: # import dynamically so the import is used only when it is needed from .together import TogetherCompatInferenceAdapter diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index 99abddf51..8530594b6 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -4,9 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pathlib import Path -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type @@ -25,27 +24,11 @@ class VLLMInferenceAdapterConfig(BaseModel): default="fake", description="The API token", ) - tls_verify: bool | str = Field( + tls_verify: bool = Field( default=True, - description="Whether to verify TLS certificates. 
Can be a boolean or a path to a CA certificate file.", + description="Whether to verify TLS certificates", ) - @field_validator("tls_verify") - @classmethod - def validate_tls_verify(cls, v): - if isinstance(v, str): - # Check if it's a boolean string - if v.lower() in ("true", "false"): - return v.lower() == "true" - # Otherwise, treat it as a cert path - cert_path = Path(v).expanduser().resolve() - if not cert_path.exists(): - raise ValueError(f"TLS certificate file does not exist: {v}") - if not cert_path.is_file(): - raise ValueError(f"TLS certificate path is not a file: {v}") - return v - return v - @classmethod def sample_run_config( cls, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 9f38d9abf..d00218dd5 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -38,7 +38,6 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -314,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): return AsyncOpenAI( base_url=self.config.url, api_key=self.config.api_token, - http_client=httpx.AsyncClient(verify=self.config.tls_verify), + http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False), ) async def completion( @@ -508,16 +507,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index 59f5f5562..c1299e11f 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -21,7 +21,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -261,16 +260,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): ) -> EmbeddingsResponse: raise NotImplementedError("embedding is not supported for watsonx") - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py index d839ffd6f..409818cb3 100644 --- a/llama_stack/providers/remote/post_training/nvidia/post_training.py +++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py @@ -224,7 +224,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): Parameters: training_config: TrainingConfig - Configuration for training - model: str - NeMo Customizer configuration name + model: str - Model identifier algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration checkpoint_dir: Optional[str] - Directory containing model 
checkpoints, ignored atm job_uuid: str - Unique identifier for the job, ignored atm @@ -299,6 +299,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): User is informed about unsupported parameters via warnings. """ + # Map model to nvidia model name + # See `_MODEL_ENTRIES` for supported models + nvidia_model = self.get_provider_model_id(model) # Check for unsupported method parameters unsupported_method_params = [] @@ -344,7 +347,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): # Prepare base job configuration job_config = { - "config": model, + "config": nvidia_model, "dataset": { "name": training_config["data_config"]["dataset_id"], "namespace": self.config.dataset_namespace, diff --git a/llama_stack/providers/remote/safety/sambanova/__init__.py b/llama_stack/providers/remote/safety/sambanova/__init__.py deleted file mode 100644 index bb9d15374..000000000 --- a/llama_stack/providers/remote/safety/sambanova/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -from typing import Any - -from .config import SambaNovaSafetyConfig - - -async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any: - from .sambanova import SambaNovaSafetyAdapter - - impl = SambaNovaSafetyAdapter(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/remote/safety/sambanova/config.py b/llama_stack/providers/remote/safety/sambanova/config.py deleted file mode 100644 index 383cea244..000000000 --- a/llama_stack/providers/remote/safety/sambanova/config.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from pydantic import BaseModel, Field, SecretStr - -from llama_stack.schema_utils import json_schema_type - - -class SambaNovaProviderDataValidator(BaseModel): - sambanova_api_key: str | None = Field( - default=None, - description="Sambanova Cloud API key", - ) - - -@json_schema_type -class SambaNovaSafetyConfig(BaseModel): - url: str = Field( - default="https://api.sambanova.ai/v1", - description="The URL for the SambaNova AI server", - ) - api_key: SecretStr | None = Field( - default=None, - description="The SambaNova cloud API Key", - ) - - @classmethod - def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]: - return { - "url": "https://api.sambanova.ai/v1", - "api_key": api_key, - } diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py deleted file mode 100644 index 84c8267ae..000000000 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import json -import logging -from typing import Any - -import litellm -import requests - -from llama_stack.apis.inference import Message -from llama_stack.apis.safety import ( - RunShieldResponse, - Safety, - SafetyViolation, - ViolationLevel, -) -from llama_stack.apis.shields import Shield -from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ShieldsProtocolPrivate -from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new - -from .config import SambaNovaSafetyConfig - -logger = logging.getLogger(__name__) - -CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" - - -class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData): - def __init__(self, config: SambaNovaSafetyConfig) -> None: - self.config = config - - async def initialize(self) -> None: - pass - - async def shutdown(self) -> None: - pass - - def _get_api_key(self) -> str: - config_api_key = self.config.api_key if self.config.api_key else None - if config_api_key: - return config_api_key.get_secret_value() - else: - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.sambanova_api_key: - raise ValueError( - 'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": }' - ) - return provider_data.sambanova_api_key - - async def register_shield(self, shield: Shield) -> None: - list_models_url = self.config.url + "/models" - try: - response = requests.get(list_models_url) - response.raise_for_status() - except requests.exceptions.RequestException as e: - raise RuntimeError(f"Request to {list_models_url} failed") from e - available_models = [model.get("id") for model in response.json().get("data", {})] - if ( - len(available_models) == 0 - or "guard" not in shield.provider_resource_id.lower() - or shield.provider_resource_id.split("sambanova/")[-1] not in available_models - ): - raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova") - - async def run_shield( - self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None - ) -> RunShieldResponse: - shield = await self.shield_store.get_shield(shield_id) - if not shield: - raise ValueError(f"Shield {shield_id} not found") - - shield_params = shield.params - logger.debug(f"run_shield::{shield_params}::messages={messages}") - content_messages = [await convert_message_to_openai_dict_new(m) for m in messages] - logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") - - response = litellm.completion( - model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key() - ) - shield_message = response.choices[0].message.content - - if "unsafe" in shield_message.lower(): - user_message = CANNED_RESPONSE_TEXT - violation_type = shield_message.split("\n")[-1] - metadata = {"violation_type": violation_type} - - return RunShieldResponse( - violation=SafetyViolation( - user_message=user_message, - violation_level=ViolationLevel.ERROR, - metadata=metadata, - ) - ) - - return RunShieldResponse() diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index 7e82cb6d4..18bec463f 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -12,19 +12,19 @@ import httpx 
from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, + Tool, ToolDef, - ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate +from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import BingSearchToolConfig -class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: BingSearchToolConfig): self.config = config self.url = "https://api.bing.microsoft.com/v7.0/search" @@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq async def initialize(self): pass - async def register_toolgroup(self, toolgroup: ToolGroup) -> None: + async def register_tool(self, tool: Tool) -> None: pass - async def unregister_toolgroup(self, toolgroup_id: str) -> None: + async def unregister_tool(self, tool_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index b96b9e59c..355cb98b6 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -11,30 +11,30 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, + Tool, ToolDef, - ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.models.llama.datatypes import BuiltinTool -from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate +from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import BraveSearchToolConfig -class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: BraveSearchToolConfig): self.config = config async def initialize(self): pass - async def register_toolgroup(self, toolgroup: ToolGroup) -> None: + async def register_tool(self, tool: Tool) -> None: pass - async def unregister_toolgroup(self, toolgroup_id: str) -> None: + async def unregister_tool(self, tool_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py index 051a880a7..fb1f558e5 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py @@ -4,12 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from .config import MCPProviderConfig +from pydantic import BaseModel + +from .config import ModelContextProtocolConfig -async def get_adapter_impl(config: MCPProviderConfig, _deps): +class ModelContextProtocolToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_adapter_impl(config: ModelContextProtocolConfig, _deps): from .model_context_protocol import ModelContextProtocolToolRuntimeImpl - impl = ModelContextProtocolToolRuntimeImpl(config, _deps) + impl = ModelContextProtocolToolRuntimeImpl(config) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py index b8c5e77fd..d509074fc 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -9,12 +9,7 @@ from typing import Any from pydantic import BaseModel -class MCPProviderDataValidator(BaseModel): - # mcp_endpoint => dict of headers to send - mcp_headers: dict[str, dict[str, str]] | None = None - - -class MCPProviderConfig(BaseModel): +class ModelContextProtocolConfig(BaseModel): @classmethod def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index a9b252dfe..142730e89 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -7,45 +7,61 @@ from typing import Any from urllib.parse import urlparse +from mcp import ClientSession +from mcp.client.sse import sse_client + from llama_stack.apis.common.content_types import URL -from llama_stack.apis.datatypes import Api from llama_stack.apis.tools import ( ListToolDefsResponse, - ToolGroup, + ToolDef, ToolInvocationResult, + ToolParameter, ToolRuntime, ) -from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.log import get_logger -from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate -from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools +from llama_stack.providers.datatypes import ToolsProtocolPrivate -from .config import MCPProviderConfig - -logger = get_logger(__name__, category="tools") +from .config import ModelContextProtocolConfig -class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): - def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): +class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__(self, config: ModelContextProtocolConfig): self.config = config async def initialize(self): pass - async def register_toolgroup(self, toolgroup: ToolGroup) -> None: - pass - - async def unregister_toolgroup(self, toolgroup_id: str) -> None: - return - async def list_runtime_tools( self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolDefsResponse: - # this endpoint should be retrieved by getting the tool group right? 
if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") - headers = await self.get_headers_from_request(mcp_endpoint.uri) - return await list_mcp_tools(mcp_endpoint.uri, headers) + + tools = [] + async with sse_client(mcp_endpoint.uri) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + tools_result = await session.list_tools() + for tool in tools_result.tools: + parameters = [] + for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): + parameters.append( + ToolParameter( + name=param_name, + parameter_type=param_schema.get("type", "string"), + description=param_schema.get("description", ""), + ) + ) + tools.append( + ToolDef( + name=tool.name, + description=tool.description, + parameters=parameters, + metadata={ + "endpoint": mcp_endpoint.uri, + }, + ) + ) + return ListToolDefsResponse(data=tools) async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) @@ -55,19 +71,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") - headers = await self.get_headers_from_request(endpoint) - return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs) + async with sse_client(endpoint) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + result = await session.call_tool(tool.identifier, kwargs) - async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: - def canonicalize_uri(uri: str) -> str: - return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}" - - headers = {} - - provider_data = self.get_request_provider_data() - if provider_data and provider_data.mcp_headers: - for uri, values in provider_data.mcp_headers.items(): - if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): - continue - headers.update(values) - return headers + return ToolInvocationResult( + content="\n".join([result.model_dump_json() for result in result.content]), + error_code=1 if result.isError else 0, + ) diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 1fe91fd7f..9d6fcd951 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -12,29 +12,29 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, + Tool, ToolDef, - ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate +from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import TavilySearchToolConfig -class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: TavilySearchToolConfig): self.config = config async def initialize(self): pass - async def register_toolgroup(self, toolgroup: ToolGroup) -> None: + async def register_tool(self, tool: Tool) -> None: pass - async def unregister_toolgroup(self, toolgroup_id: str) -> None: + async 
def unregister_tool(self, tool_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 6e1d0f61d..a3724e4b4 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -12,19 +12,19 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, + Tool, ToolDef, - ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate +from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import WolframAlphaToolConfig -class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: WolframAlphaToolConfig): self.config = config self.url = "https://api.wolframalpha.com/v2/query" @@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR async def initialize(self): pass - async def register_toolgroup(self, toolgroup: ToolGroup) -> None: + async def register_tool(self, tool: Tool) -> None: pass - async def unregister_toolgroup(self, toolgroup_id: str) -> None: + async def unregister_tool(self, tool_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index a59a38573..a919963ab 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -84,14 +84,6 @@ class ChromaIndex(EmbeddingIndex): async def delete(self): await maybe_await(self.client.delete_collection(self.collection.name)) - async def query_keyword( - self, - query_string: str, - k: int, - score_threshold: float, - ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in Chroma") - class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 6628292db..c98417b56 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -73,7 +73,7 @@ class MilvusIndex(EmbeddingIndex): logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") raise e - async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: search_res = await asyncio.to_thread( self.client.search, collection_name=self.collection_name, @@ -86,14 +86,6 @@ class MilvusIndex(EmbeddingIndex): scores = [res["distance"] for res in search_res[0]] return QueryChunksResponse(chunks=chunks, scores=scores) - async def query_keyword( - self, - query_string: str, - k: int, - score_threshold: float, - ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in Milvus") - class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def 
__init__( diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index ea918c552..94546c6cf 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: execute_values(cur, query, values, template="(%s, %s, %s::vector)") - async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute( f""" @@ -120,14 +120,6 @@ class PGVectorIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) - async def query_keyword( - self, - query_string: str, - k: int, - score_threshold: float, - ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in PGVector") - async def delete(self): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index ff0690083..514a6c70d 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -68,7 +68,7 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) - async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = ( await self.client.query_points( collection_name=self.collection_name, @@ -95,14 +95,6 @@ class QdrantIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) - async def query_keyword( - self, - query_string: str, - k: int, - score_threshold: float, - ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in Qdrant") - async def delete(self): await self.client.delete_collection(collection_name=self.collection_name) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index e6fe8ccd3..308d2eb3d 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -55,7 +55,7 @@ class WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) - async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: collection = self.client.collections.get(self.collection_name) results = collection.query.near_vector( @@ -84,14 +84,6 @@ class WeaviateIndex(EmbeddingIndex): collection = self.client.collections.get(self.collection_name) collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids)) - async def query_keyword( - self, - query_string: str, - k: int, - score_threshold: float, - ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in Weaviate") - class WeaviateVectorIOAdapter( 
VectorIO, diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 97cf87360..7c8144c62 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -4,9 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import base64 import logging -import struct from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -17,9 +15,6 @@ from llama_stack.apis.inference import ( EmbeddingTaskType, InterleavedContentItem, ModelStore, - OpenAIEmbeddingData, - OpenAIEmbeddingsResponse, - OpenAIEmbeddingUsage, TextTruncation, ) from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str @@ -48,50 +43,6 @@ class SentenceTransformerEmbeddingMixin: ) return EmbeddingsResponse(embeddings=embeddings) - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - # Convert input to list format if it's a single string - input_list = [input] if isinstance(input, str) else input - if not input_list: - raise ValueError("Empty list not supported") - - # Get the model and generate embeddings - model_obj = await self.model_store.get_model(model) - embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id) - embeddings = embedding_model.encode(input_list, show_progress_bar=False) - - # Convert embeddings to the requested format - data = [] - for i, embedding in enumerate(embeddings): - if encoding_format == "base64": - # Convert float array to base64 string - float_bytes = struct.pack(f"{len(embedding)}f", *embedding) - embedding_value = base64.b64encode(float_bytes).decode("ascii") - else: - # Default to float format - embedding_value = embedding.tolist() - - data.append( - OpenAIEmbeddingData( - embedding=embedding_value, - index=i, - ) - ) - - # Not returning actual token usage - usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1) - return OpenAIEmbeddingsResponse( - data=data, - model=model_obj.provider_resource_id, - usage=usage, - ) - def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": global EMBEDDING_MODELS diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py deleted file mode 100644 index 7b6bc2e3d..000000000 --- a/llama_stack/providers/utils/inference/inference_store.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
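
[Editor's note, not part of the patch] As an aside on the vector-store hunks above (Chroma, Milvus, PGVector, Qdrant, Weaviate): each index's `query(embedding, k, score_threshold)` returns parallel `chunks` and `scores` lists. Below is a minimal, repo-independent sketch of that contract using cosine similarity; it is purely illustrative, since each adapter actually delegates ranking to its backing database.

```python
import numpy as np


def query(store_vectors: np.ndarray, chunks: list[str], embedding: np.ndarray,
          k: int, score_threshold: float) -> tuple[list[str], list[float]]:
    # Cosine similarity of the query embedding against every stored vector.
    norms = np.linalg.norm(store_vectors, axis=1) * np.linalg.norm(embedding)
    scores = (store_vectors @ embedding) / np.where(norms == 0, 1.0, norms)
    top = np.argsort(scores)[::-1][:k]          # best-k indices, highest score first
    keep = [i for i in top if scores[i] >= score_threshold]
    return [chunks[i] for i in keep], [float(scores[i]) for i in keep]


vectors = np.array([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]])
print(query(vectors, ["a", "b", "c"], np.array([1.0, 0.1]), k=2, score_threshold=0.5))
```
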
-from llama_stack.apis.inference import ( - ListOpenAIChatCompletionResponse, - OpenAIChatCompletion, - OpenAICompletionWithInputMessages, - OpenAIMessageParam, - Order, -) -from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR - -from ..sqlstore.api import ColumnDefinition, ColumnType -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl - - -class InferenceStore: - def __init__(self, sql_store_config: SqlStoreConfig): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( - db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), - ) - self.sql_store_config = sql_store_config - self.sql_store = None - - async def initialize(self): - """Create the necessary tables if they don't exist.""" - self.sql_store = sqlstore_impl(self.sql_store_config) - await self.sql_store.create_table( - "chat_completions", - { - "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), - "created": ColumnType.INTEGER, - "model": ColumnType.STRING, - "choices": ColumnType.JSON, - "input_messages": ColumnType.JSON, - }, - ) - - async def store_chat_completion( - self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam] - ) -> None: - if not self.sql_store: - raise ValueError("Inference store is not initialized") - - data = chat_completion.model_dump() - - await self.sql_store.insert( - "chat_completions", - { - "id": data["id"], - "created": data["created"], - "model": data["model"], - "choices": data["choices"], - "input_messages": [message.model_dump() for message in input_messages], - }, - ) - - async def list_chat_completions( - self, - after: str | None = None, - limit: int | None = 50, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIChatCompletionResponse: - """ - List chat completions from the database. - - :param after: The ID of the last chat completion to return. - :param limit: The maximum number of chat completions to return. - :param model: The model to filter by. - :param order: The order to sort the chat completions by. 
- """ - if not self.sql_store: - raise ValueError("Inference store is not initialized") - - # TODO: support after - if after: - raise NotImplementedError("After is not supported for SQLite") - if not order: - order = Order.desc - - rows = await self.sql_store.fetch_all( - "chat_completions", - where={"model": model} if model else None, - order_by=[("created", order.value)], - limit=limit, - ) - - data = [ - OpenAICompletionWithInputMessages( - id=row["id"], - created=row["created"], - model=row["model"], - choices=row["choices"], - input_messages=row["input_messages"], - ) - for row in rows - ] - return ListOpenAIChatCompletionResponse( - data=data, - # TODO: implement has_more - has_more=False, - first_id=data[0].id if data else "", - last_id=data[-1].id if data else "", - ) - - async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: - if not self.sql_store: - raise ValueError("Inference store is not initialized") - - row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id}) - if not row: - raise ValueError(f"Chat completion with id {completion_id} not found") from None - return OpenAICompletionWithInputMessages( - id=row["id"], - created=row["created"], - model=row["model"], - choices=row["choices"], - input_messages=row["input_messages"], - ) diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index dab10bc55..0a5c5e4f4 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -4,8 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import base64 -import struct from collections.abc import AsyncGenerator, AsyncIterator from typing import Any @@ -21,7 +19,7 @@ from llama_stack.apis.inference import ( ChatCompletionResponseStreamChunk, EmbeddingsResponse, EmbeddingTaskType, - InferenceProvider, + Inference, JsonSchemaResponseFormat, LogProbConfig, Message, @@ -37,9 +35,6 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, - OpenAIEmbeddingData, - OpenAIEmbeddingsResponse, - OpenAIEmbeddingUsage, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -64,7 +59,7 @@ logger = get_logger(name=__name__, category="inference") class LiteLLMOpenAIMixin( ModelRegistryHelper, - InferenceProvider, + Inference, NeedsRequestProviderData, ): # TODO: avoid exposing the litellm specific model names to the user. 
@@ -269,52 +264,6 @@ class LiteLLMOpenAIMixin( embeddings = [data["embedding"] for data in response["data"]] return EmbeddingsResponse(embeddings=embeddings) - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - model_obj = await self.model_store.get_model(model) - - # Convert input to list if it's a string - input_list = [input] if isinstance(input, str) else input - - # Call litellm embedding function - # litellm.drop_params = True - response = litellm.embedding( - model=self.get_litellm_model_name(model_obj.provider_resource_id), - input=input_list, - api_key=self.get_api_key(), - api_base=self.api_base, - dimensions=dimensions, - ) - - # Convert response to OpenAI format - data = [] - for i, embedding_data in enumerate(response["data"]): - # we encode to base64 if the encoding format is base64 in the request - if encoding_format == "base64": - byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"]) - embedding = base64.b64encode(byte_data).decode("utf-8") - else: - embedding = embedding_data["embedding"] - - data.append(OpenAIEmbeddingData(embedding=embedding, index=i)) - - usage = OpenAIEmbeddingUsage( - prompt_tokens=response["usage"]["prompt_tokens"], - total_tokens=response["usage"]["total_tokens"], - ) - - return OpenAIEmbeddingsResponse( - data=data, - model=model_obj.provider_resource_id, - usage=usage, - ) - async def openai_completion( self, model: str, diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 049f06fdb..cc0000528 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -1402,8 +1402,9 @@ class OpenAIChatCompletionToLlamaStackMixin: outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], ): id = f"chatcmpl-{uuid.uuid4()}" - for i, outstanding_response in enumerate(outstanding_responses): + for outstanding_response in outstanding_responses: response = await outstanding_response + i = 0 async for chunk in response: event = chunk.event finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason) @@ -1458,6 +1459,7 @@ class OpenAIChatCompletionToLlamaStackMixin: model=model, object="chat.completion.chunk", ) + i = i + 1 async def _process_non_stream_response( self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] diff --git a/llama_stack/providers/utils/inference/stream_utils.py b/llama_stack/providers/utils/inference/stream_utils.py deleted file mode 100644 index a2edbb9c8..000000000 --- a/llama_stack/providers/utils/inference/stream_utils.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from collections.abc import AsyncIterator -from datetime import datetime, timezone -from typing import Any - -from llama_stack.apis.inference import ( - OpenAIAssistantMessageParam, - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAIChatCompletionToolCall, - OpenAIChatCompletionToolCallFunction, - OpenAIChoice, - OpenAIChoiceLogprobs, - OpenAIMessageParam, -) -from llama_stack.providers.utils.inference.inference_store import InferenceStore - - -async def stream_and_store_openai_completion( - provider_stream: AsyncIterator[OpenAIChatCompletionChunk], - model: str, - store: InferenceStore, - input_messages: list[OpenAIMessageParam], -) -> AsyncIterator[OpenAIChatCompletionChunk]: - """ - Wraps a provider's stream, yields chunks, and stores the full completion at the end. - """ - id = None - created = None - choices_data: dict[int, dict[str, Any]] = {} - - try: - async for chunk in provider_stream: - if id is None and chunk.id: - id = chunk.id - if created is None and chunk.created: - created = chunk.created - - if chunk.choices: - for choice_delta in chunk.choices: - idx = choice_delta.index - if idx not in choices_data: - choices_data[idx] = { - "content_parts": [], - "tool_calls_builder": {}, - "finish_reason": None, - "logprobs_content_parts": [], - } - current_choice_data = choices_data[idx] - - if choice_delta.delta: - delta = choice_delta.delta - if delta.content: - current_choice_data["content_parts"].append(delta.content) - if delta.tool_calls: - for tool_call_delta in delta.tool_calls: - tc_idx = tool_call_delta.index - if tc_idx not in current_choice_data["tool_calls_builder"]: - # Initialize with correct structure for _ToolCallBuilderData - current_choice_data["tool_calls_builder"][tc_idx] = { - "id": None, - "type": "function", - "function_name_parts": [], - "function_arguments_parts": [], - } - builder = current_choice_data["tool_calls_builder"][tc_idx] - if tool_call_delta.id: - builder["id"] = tool_call_delta.id - if tool_call_delta.type: - builder["type"] = tool_call_delta.type - if tool_call_delta.function: - if tool_call_delta.function.name: - builder["function_name_parts"].append(tool_call_delta.function.name) - if tool_call_delta.function.arguments: - builder["function_arguments_parts"].append(tool_call_delta.function.arguments) - if choice_delta.finish_reason: - current_choice_data["finish_reason"] = choice_delta.finish_reason - if choice_delta.logprobs and choice_delta.logprobs.content: - # Ensure that we are extending with the correct type - current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) - yield chunk - finally: - if id: - assembled_choices: list[OpenAIChoice] = [] - for choice_idx, choice_data in choices_data.items(): - content_str = "".join(choice_data["content_parts"]) - assembled_tool_calls: list[OpenAIChatCompletionToolCall] = [] - if choice_data["tool_calls_builder"]: - for tc_build_data in choice_data["tool_calls_builder"].values(): - if tc_build_data["id"]: - func_name = "".join(tc_build_data["function_name_parts"]) - func_args = "".join(tc_build_data["function_arguments_parts"]) - assembled_tool_calls.append( - OpenAIChatCompletionToolCall( - id=tc_build_data["id"], - type=tc_build_data["type"], # No or "function" needed, already set - function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args), - ) - ) - message = OpenAIAssistantMessageParam( - role="assistant", - content=content_str if content_str else None, - tool_calls=assembled_tool_calls if assembled_tool_calls else None, - ) - 
logprobs_content = choice_data["logprobs_content_parts"] - final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None - - assembled_choices.append( - OpenAIChoice( - finish_reason=choice_data["finish_reason"], - index=choice_idx, - message=message, - logprobs=final_logprobs, - ) - ) - - final_response = OpenAIChatCompletion( - id=id, - choices=assembled_choices, - created=created or int(datetime.now(timezone.utc).timestamp()), - model=model, - object="chat.completion", - ) - await store.store_chat_completion(final_response, input_messages) diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index bbb0c5c0a..e9aac6e8c 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -65,7 +65,7 @@ class SqliteKVStoreConfig(CommonConfig): class PostgresKVStoreConfig(CommonConfig): type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value host: str = "localhost" - port: str = "5432" + port: int = 5432 db: str = "llamastack" user: str password: str | None = None diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 4cd15860b..e0e9d0679 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -171,33 +171,13 @@ def make_overlapped_chunks( return chunks -def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int): - """Helper method to validate embedding format and dimensions""" - if not isinstance(embedding, (list | np.ndarray)): - raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}") - - if isinstance(embedding, np.ndarray): - if not np.issubdtype(embedding.dtype, np.number): - raise ValueError(f"Embedding at index {index} contains non-numeric values") - else: - if not all(isinstance(e, (float | int | np.number)) for e in embedding): - raise ValueError(f"Embedding at index {index} contains non-numeric values") - - if len(embedding) != expected_dimension: - raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}") - - class EmbeddingIndex(ABC): @abstractmethod async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): raise NotImplementedError() @abstractmethod - async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - raise NotImplementedError() - - @abstractmethod - async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse: + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: raise NotImplementedError() @abstractmethod @@ -215,22 +195,11 @@ class VectorDBWithIndex: self, chunks: list[Chunk], ) -> None: - chunks_to_embed = [] - for i, c in enumerate(chunks): - if c.embedding is None: - chunks_to_embed.append(c) - else: - _validate_embedding(c.embedding, i, self.vector_db.embedding_dimension) + embeddings_response = await self.inference_api.embeddings( + self.vector_db.embedding_model, [x.content for x in chunks] + ) + embeddings = np.array(embeddings_response.embeddings) - if chunks_to_embed: - resp = await self.inference_api.embeddings( - self.vector_db.embedding_model, - [c.content for c in chunks_to_embed], - ) - for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False): - c.embedding = embedding - - embeddings = 
np.array([c.embedding for c in chunks], dtype=np.float32) await self.index.add_chunks(chunks, embeddings) async def query_chunks( @@ -241,12 +210,9 @@ class VectorDBWithIndex: if params is None: params = {} k = params.get("max_chunks", 3) - mode = params.get("mode") score_threshold = params.get("score_threshold", 0.0) - query_string = interleaved_content_as_str(query) - if mode == "keyword": - return await self.index.query_keyword(query_string, k, score_threshold) - else: - embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) - query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) - return await self.index.query_vector(query_vector, k, score_threshold) + + query_str = interleaved_content_as_str(query) + embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_str]) + query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) + return await self.index.query(query_vector, k, score_threshold) diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py deleted file mode 100644 index 15354e3e2..000000000 --- a/llama_stack/providers/utils/responses/responses_store.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from llama_stack.apis.agents import ( - Order, -) -from llama_stack.apis.agents.openai_responses import ( - ListOpenAIResponseInputItem, - ListOpenAIResponseObject, - OpenAIResponseInput, - OpenAIResponseObject, - OpenAIResponseObjectWithInput, -) -from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR - -from ..sqlstore.api import ColumnDefinition, ColumnType -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl - - -class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( - db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), - ) - self.sql_store = sqlstore_impl(sql_store_config) - - async def initialize(self): - """Create the necessary tables if they don't exist.""" - await self.sql_store.create_table( - "openai_responses", - { - "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), - "created_at": ColumnType.INTEGER, - "response_object": ColumnType.JSON, - "model": ColumnType.STRING, - }, - ) - - async def store_response_object( - self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] - ) -> None: - data = response_object.model_dump() - data["input"] = [input_item.model_dump() for input_item in input] - - await self.sql_store.insert( - "openai_responses", - { - "id": data["id"], - "created_at": data["created_at"], - "model": data["model"], - "response_object": data, - }, - ) - - async def list_responses( - self, - after: str | None = None, - limit: int | None = 50, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseObject: - """ - List responses from the database. - - :param after: The ID of the last response to return. - :param limit: The maximum number of responses to return. - :param model: The model to filter by. - :param order: The order to sort the responses by. 
- """ - # TODO: support after - if after: - raise NotImplementedError("After is not supported for SQLite") - if not order: - order = Order.desc - - rows = await self.sql_store.fetch_all( - "openai_responses", - where={"model": model} if model else None, - order_by=[("created_at", order.value)], - limit=limit, - ) - - data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows] - return ListOpenAIResponseObject( - data=data, - # TODO: implement has_more - has_more=False, - first_id=data[0].id if data else "", - last_id=data[-1].id if data else "", - ) - - async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput: - row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) - if not row: - raise ValueError(f"Response with id {response_id} not found") from None - return OpenAIResponseObjectWithInput(**row["response_object"]) - - async def list_response_input_items( - self, - response_id: str, - after: str | None = None, - before: str | None = None, - include: list[str] | None = None, - limit: int | None = 20, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseInputItem: - """ - List input items for a given response. - - :param response_id: The ID of the response to retrieve input items for. - :param after: An item ID to list items after, used for pagination. - :param before: An item ID to list items before, used for pagination. - :param include: Additional fields to include in the response. - :param limit: A limit on the number of objects to be returned. - :param order: The order to return the input items in. - """ - # TODO: support after/before pagination - if after or before: - raise NotImplementedError("After/before pagination is not supported yet") - if include: - raise NotImplementedError("Include is not supported yet") - - response_with_input = await self.get_response_object(response_id) - input_items = response_with_input.input - - if order == Order.desc: - input_items = list(reversed(input_items)) - - if limit is not None and len(input_items) > limit: - input_items = input_items[:limit] - - return ListOpenAIResponseInputItem(data=input_items) diff --git a/llama_stack/providers/utils/sqlstore/api.py b/llama_stack/providers/utils/sqlstore/api.py deleted file mode 100644 index ace40e4c4..000000000 --- a/llama_stack/providers/utils/sqlstore/api.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from collections.abc import Mapping -from enum import Enum -from typing import Any, Literal, Protocol - -from pydantic import BaseModel - - -class ColumnType(Enum): - INTEGER = "INTEGER" - STRING = "STRING" - TEXT = "TEXT" - FLOAT = "FLOAT" - BOOLEAN = "BOOLEAN" - JSON = "JSON" - DATETIME = "DATETIME" - - -class ColumnDefinition(BaseModel): - type: ColumnType - primary_key: bool = False - nullable: bool = True - default: Any = None - - -class SqlStore(Protocol): - """ - A protocol for a SQL store. - """ - - async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: - """ - Create a table. - """ - pass - - async def insert(self, table: str, data: Mapping[str, Any]) -> None: - """ - Insert a row into a table. 
- """ - pass - - async def fetch_all( - self, - table: str, - where: Mapping[str, Any] | None = None, - limit: int | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> list[dict[str, Any]]: - """ - Fetch all rows from a table. - """ - pass - - async def fetch_one( - self, - table: str, - where: Mapping[str, Any] | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> dict[str, Any] | None: - """ - Fetch one row from a table. - """ - pass - - async def update( - self, - table: str, - data: Mapping[str, Any], - where: Mapping[str, Any], - ) -> None: - """ - Update a row in a table. - """ - pass - - async def delete( - self, - table: str, - where: Mapping[str, Any], - ) -> None: - """ - Delete a row from a table. - """ - pass diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py deleted file mode 100644 index 825220679..000000000 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from collections.abc import Mapping -from typing import Any, Literal - -from sqlalchemy import ( - JSON, - Boolean, - Column, - DateTime, - Float, - Integer, - MetaData, - String, - Table, - Text, - select, -) -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine - -from .api import ColumnDefinition, ColumnType, SqlStore -from .sqlstore import SqlAlchemySqlStoreConfig - -TYPE_MAPPING: dict[ColumnType, Any] = { - ColumnType.INTEGER: Integer, - ColumnType.STRING: String, - ColumnType.FLOAT: Float, - ColumnType.BOOLEAN: Boolean, - ColumnType.DATETIME: DateTime, - ColumnType.TEXT: Text, - ColumnType.JSON: JSON, -} - - -class SqlAlchemySqlStoreImpl(SqlStore): - def __init__(self, config: SqlAlchemySqlStoreConfig): - self.config = config - self.async_session = async_sessionmaker(create_async_engine(config.engine_str)) - self.metadata = MetaData() - - async def create_table( - self, - table: str, - schema: Mapping[str, ColumnType | ColumnDefinition], - ) -> None: - if not schema: - raise ValueError(f"No columns defined for table '{table}'.") - - sqlalchemy_columns: list[Column] = [] - - for col_name, col_props in schema.items(): - col_type = None - is_primary_key = False - is_nullable = True # Default to nullable - - if isinstance(col_props, ColumnType): - col_type = col_props - elif isinstance(col_props, ColumnDefinition): - col_type = col_props.type - is_primary_key = col_props.primary_key - is_nullable = col_props.nullable - - sqlalchemy_type = TYPE_MAPPING.get(col_type) - if not sqlalchemy_type: - raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.") - - sqlalchemy_columns.append( - Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable) - ) - - # Check if table already exists in metadata, otherwise define it - if table not in self.metadata.tables: - sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns) - else: - sqlalchemy_table = self.metadata.tables[table] - - # Create the table in the database if it doesn't exist - # checkfirst=True ensures it doesn't try to recreate if it's already there - engine = create_async_engine(self.config.engine_str) - async with engine.begin() as conn: - await conn.run_sync(self.metadata.create_all, 
tables=[sqlalchemy_table], checkfirst=True) - - async def insert(self, table: str, data: Mapping[str, Any]) -> None: - async with self.async_session() as session: - await session.execute(self.metadata.tables[table].insert(), data) - await session.commit() - - async def fetch_all( - self, - table: str, - where: Mapping[str, Any] | None = None, - limit: int | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> list[dict[str, Any]]: - async with self.async_session() as session: - query = select(self.metadata.tables[table]) - if where: - for key, value in where.items(): - query = query.where(self.metadata.tables[table].c[key] == value) - if limit: - query = query.limit(limit) - if order_by: - if not isinstance(order_by, list): - raise ValueError( - f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" - ) - for order in order_by: - if not isinstance(order, tuple): - raise ValueError( - f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" - ) - name, order_type = order - if order_type == "asc": - query = query.order_by(self.metadata.tables[table].c[name].asc()) - elif order_type == "desc": - query = query.order_by(self.metadata.tables[table].c[name].desc()) - else: - raise ValueError(f"Invalid order '{order_type}' for column '{name}'") - result = await session.execute(query) - if result.rowcount == 0: - return [] - return [dict(row._mapping) for row in result] - - async def fetch_one( - self, - table: str, - where: Mapping[str, Any] | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> dict[str, Any] | None: - rows = await self.fetch_all(table, where, limit=1, order_by=order_by) - if not rows: - return None - return rows[0] - - async def update( - self, - table: str, - data: Mapping[str, Any], - where: Mapping[str, Any], - ) -> None: - if not where: - raise ValueError("where is required for update") - - async with self.async_session() as session: - stmt = self.metadata.tables[table].update() - for key, value in where.items(): - stmt = stmt.where(self.metadata.tables[table].c[key] == value) - await session.execute(stmt, data) - await session.commit() - - async def delete(self, table: str, where: Mapping[str, Any]) -> None: - if not where: - raise ValueError("where is required for delete") - - async with self.async_session() as session: - stmt = self.metadata.tables[table].delete() - for key, value in where.items(): - stmt = stmt.where(self.metadata.tables[table].c[key] == value) - await session.execute(stmt) - await session.commit() diff --git a/llama_stack/providers/utils/sqlstore/sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlstore.py deleted file mode 100644 index 3091e8f96..000000000 --- a/llama_stack/providers/utils/sqlstore/sqlstore.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -from abc import abstractmethod -from enum import Enum -from pathlib import Path -from typing import Annotated, Literal - -from pydantic import BaseModel, Field - -from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR - -from .api import SqlStore - - -class SqlStoreType(Enum): - sqlite = "sqlite" - postgres = "postgres" - - -class SqlAlchemySqlStoreConfig(BaseModel): - @property - @abstractmethod - def engine_str(self) -> str: ... 
- - # TODO: move this when we have a better way to specify dependencies with internal APIs - @property - def pip_packages(self) -> list[str]: - return ["sqlalchemy[asyncio]"] - - -class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig): - type: Literal["sqlite"] = SqlStoreType.sqlite.value - db_path: str = Field( - default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), - description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db", - ) - - @property - def engine_str(self) -> str: - return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix() - - @classmethod - def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"): - return cls( - type="sqlite", - db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name, - ) - - @property - def pip_packages(self) -> list[str]: - return super().pip_packages + ["aiosqlite"] - - -class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig): - type: Literal["postgres"] = SqlStoreType.postgres.value - host: str = "localhost" - port: str = "5432" - db: str = "llamastack" - user: str - password: str | None = None - - @property - def engine_str(self) -> str: - return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}" - - @property - def pip_packages(self) -> list[str]: - return super().pip_packages + ["asyncpg"] - - -SqlStoreConfig = Annotated[ - SqliteSqlStoreConfig | PostgresSqlStoreConfig, - Field(discriminator="type", default=SqlStoreType.sqlite.value), -] - - -def sqlstore_impl(config: SqlStoreConfig) -> SqlStore: - if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]: - from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl - - impl = SqlAlchemySqlStoreImpl(config) - else: - raise ValueError(f"Unknown sqlstore type {config.type}") - - return impl diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 4edfa6516..0f4fdd0d8 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -34,8 +34,6 @@ logger = get_logger(__name__, category="core") INVALID_SPAN_ID = 0x0000000000000000 INVALID_TRACE_ID = 0x00000000000000000000000000000000 -ROOT_SPAN_MARKERS = ["__root__", "__root_span__"] - def trace_id_to_str(trace_id: int) -> str: """Convenience trace ID formatting method @@ -180,8 +178,7 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont trace_id = generate_trace_id() context = TraceContext(BACKGROUND_LOGGER, trace_id) - attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {}) - context.push_span(name, attributes) + context.push_span(name, {"__root__": True, **(attributes or {})}) CURRENT_TRACE_CONTEXT.set(context) return context diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py deleted file mode 100644 index f024693a0..000000000 --- a/llama_stack/providers/utils/tools/mcp.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
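
[Editor's note, not part of the patch] For reference on the deleted sqlstore configs above: `engine_str` simply builds a SQLAlchemy async URL, `sqlite+aiosqlite:///<path>` for SQLite and `postgresql+asyncpg://user:password@host:port/db` for Postgres. A small sketch of that construction with placeholder credentials:

```python
from pathlib import Path


def sqlite_engine_str(db_path: str) -> str:
    # e.g. sqlite+aiosqlite:////home/user/.llama/distributions/ollama/sqlstore.db
    return "sqlite+aiosqlite:///" + Path(db_path).expanduser().as_posix()


def postgres_engine_str(user: str, password: str, host: str, port: int, db: str) -> str:
    return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db}"


print(sqlite_engine_str("~/.llama/distributions/ollama/sqlstore.db"))
print(postgres_engine_str("llamastack", "secret", "localhost", 5432, "llamastack"))
# Either string can then be handed to sqlalchemy.ext.asyncio.create_async_engine(...),
# as the deleted SqlAlchemySqlStoreImpl did with config.engine_str.
```
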
- -from contextlib import asynccontextmanager -from typing import Any - -try: - # for python < 3.11 - import exceptiongroup - - BaseExceptionGroup = exceptiongroup.BaseExceptionGroup -except ImportError: - pass - -import httpx -from mcp import ClientSession -from mcp import types as mcp_types -from mcp.client.sse import sse_client - -from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem -from llama_stack.apis.tools import ( - ListToolDefsResponse, - ToolDef, - ToolInvocationResult, - ToolParameter, -) -from llama_stack.distribution.datatypes import AuthenticationRequiredError -from llama_stack.log import get_logger - -logger = get_logger(__name__, category="tools") - - -@asynccontextmanager -async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): - try: - async with sse_client(endpoint, headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session - except BaseException as e: - if isinstance(e, BaseExceptionGroup): - for exc in e.exceptions: - if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: - raise AuthenticationRequiredError(exc) from exc - elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401: - raise AuthenticationRequiredError(e) from e - - raise - - -async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: - tools = [] - async with sse_client_wrapper(endpoint, headers) as session: - tools_result = await session.list_tools() - for tool in tools_result.tools: - parameters = [] - for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): - parameters.append( - ToolParameter( - name=param_name, - parameter_type=param_schema.get("type", "string"), - description=param_schema.get("description", ""), - ) - ) - tools.append( - ToolDef( - name=tool.name, - description=tool.description, - parameters=parameters, - metadata={ - "endpoint": endpoint, - }, - ) - ) - return ListToolDefsResponse(data=tools) - - -async def invoke_mcp_tool( - endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] -) -> ToolInvocationResult: - async with sse_client_wrapper(endpoint, headers) as session: - result = await session.call_tool(tool_name, kwargs) - - content: list[InterleavedContentItem] = [] - for item in result.content: - if isinstance(item, mcp_types.TextContent): - content.append(TextContentItem(text=item.text)) - elif isinstance(item, mcp_types.ImageContent): - content.append(ImageContentItem(image=item.data)) - elif isinstance(item, mcp_types.EmbeddedResource): - logger.warning(f"EmbeddedResource is not supported: {item}") - else: - raise ValueError(f"Unknown content type: {type(item)}") - return ToolInvocationResult( - content=content, - error_code=1 if result.isError else 0, - ) diff --git a/llama_stack/templates/bedrock/build.yaml b/llama_stack/templates/bedrock/build.yaml index 97a06f77a..46d5b9c69 100644 --- a/llama_stack/templates/bedrock/build.yaml +++ b/llama_stack/templates/bedrock/build.yaml @@ -29,6 +29,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/bedrock/run.yaml b/llama_stack/templates/bedrock/run.yaml index a58068a60..30599a6c0 100644 --- a/llama_stack/templates/bedrock/run.yaml +++ b/llama_stack/templates/bedrock/run.yaml @@ -35,9 +35,6 @@ providers: type: sqlite namespace: null db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -99,9 +96,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/inference_store.db models: - metadata: {} model_id: meta.llama3-1-8b-instruct-v1:0 diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml index f26f4ed9b..0498da1cd 100644 --- a/llama_stack/templates/cerebras/build.yaml +++ b/llama_stack/templates/cerebras/build.yaml @@ -29,6 +29,3 @@ distribution_spec: - remote::tavily-search - inline::rag-runtime image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/cerebras/doc_template.md b/llama_stack/templates/cerebras/doc_template.md index 5cae2b2da..76f8c34ad 100644 --- a/llama_stack/templates/cerebras/doc_template.md +++ b/llama_stack/templates/cerebras/doc_template.md @@ -46,7 +46,7 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./run.yaml:/root/my-run.yaml \ llamastack/distribution-{{ name }} \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY ``` diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml index c080536b7..0731b1df9 100644 --- a/llama_stack/templates/cerebras/run.yaml +++ b/llama_stack/templates/cerebras/run.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference @@ -102,9 +99,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/inference_store.db models: - metadata: {} model_id: llama3.1-8b diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml index 9f4fbbdda..a4c5893c4 100644 --- a/llama_stack/templates/ci-tests/build.yaml +++ b/llama_stack/templates/ci-tests/build.yaml @@ -30,6 +30,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index 368187d3a..d9ee5b3cf 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -38,9 +38,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -102,9 +99,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db -inference_store: - type: sqlite - db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/inference_store.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/templates/dell/build.yaml index 513df16c1..f5beb6c2f 100644 --- a/llama_stack/templates/dell/build.yaml +++ b/llama_stack/templates/dell/build.yaml @@ -30,6 +30,3 @@ distribution_spec: - remote::tavily-search - inline::rag-runtime image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/dell/doc_template.md b/llama_stack/templates/dell/doc_template.md index 6bdd7f81c..26f07130b 100644 --- a/llama_stack/templates/dell/doc_template.md +++ b/llama_stack/templates/dell/doc_template.md @@ -143,7 +143,7 @@ docker run \ -v $HOME/.llama:/root/.llama \ -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-{{ name }} \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env DEH_URL=$DEH_URL \ diff --git a/llama_stack/templates/dell/run-with-safety.yaml b/llama_stack/templates/dell/run-with-safety.yaml index 5c6072245..24c515112 100644 --- a/llama_stack/templates/dell/run-with-safety.yaml +++ b/llama_stack/templates/dell/run-with-safety.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -102,9 +99,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/templates/dell/run.yaml index ffaa0bf2f..fdece894f 100644 --- a/llama_stack/templates/dell/run.yaml +++ b/llama_stack/templates/dell/run.yaml @@ -37,9 +37,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -98,9 +95,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index 47a35edc0..d1a17e48e 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -31,7 +31,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -68,7 +67,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -107,7 +105,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "sqlite-vec", "tqdm", "transformers", @@ -148,7 +145,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -188,7 
+184,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -226,7 +221,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -265,7 +259,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -304,46 +297,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "kvant": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -382,7 +335,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "sqlite-vec", "tqdm", "transformers", @@ -427,7 +379,6 @@ "scipy", "sentence-transformers", "sentencepiece", - "sqlalchemy[asyncio]", "torch", "torchao==0.8.0", "torchvision", @@ -463,7 +414,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "uvicorn" @@ -491,7 +441,6 @@ "opentelemetry-exporter-otlp-proto-http", "opentelemetry-sdk", "pandas", - "peft", "pillow", "psycopg2-binary", "pymongo", @@ -502,12 +451,9 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", - "torch", "tqdm", "transformers", "tree_sitter", - "trl", "uvicorn" ], "open-benchmark": [ @@ -541,7 +487,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "sqlite-vec", "together", "tqdm", @@ -580,7 +525,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -619,7 +563,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -653,7 +596,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "uvicorn", @@ -692,7 +634,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "sqlite-vec", "tqdm", "transformers", @@ -734,7 +675,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -773,7 +713,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "together", "tqdm", "transformers", @@ -813,7 +752,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "sqlite-vec", "tqdm", "transformers", @@ -853,7 +791,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -893,7 +830,6 @@ "scikit-learn", "scipy", "sentencepiece", - "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", diff --git a/llama_stack/templates/experimental-post-training/build.yaml b/llama_stack/templates/experimental-post-training/build.yaml index 55cd189c6..b4b5e2203 100644 --- a/llama_stack/templates/experimental-post-training/build.yaml +++ b/llama_stack/templates/experimental-post-training/build.yaml @@ -13,10 +13,9 @@ distribution_spec: - inline::basic - inline::braintrust post_training: - - inline::huggingface + - inline::torchtune datasetio: 
- inline::localfs - - remote::huggingface telemetry: - inline::meta-reference agents: diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml index 393cba41d..2ebdfe1aa 100644 --- a/llama_stack/templates/experimental-post-training/run.yaml +++ b/llama_stack/templates/experimental-post-training/run.yaml @@ -49,24 +49,16 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db - - provider_id: huggingface - provider_type: remote::huggingface - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/huggingface}/huggingface_datasetio.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference config: {} post_training: - - provider_id: huggingface - provider_type: inline::huggingface - config: + - provider_id: torchtune-post-training + provider_type: inline::torchtune + config: { checkpoint_format: huggingface - distributed_backend: null - device: cpu + } agents: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index be19181c0..7c74157ee 100644 --- a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -31,6 +31,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/fireworks/run-with-safety.yaml b/llama_stack/templates/fireworks/run-with-safety.yaml index 41500f6f6..0ab07613e 100644 --- a/llama_stack/templates/fireworks/run-with-safety.yaml +++ b/llama_stack/templates/fireworks/run-with-safety.yaml @@ -46,9 +46,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -114,9 +111,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/inference_store.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml index b1fa03306..81c293a46 100644 --- a/llama_stack/templates/fireworks/run.yaml +++ b/llama_stack/templates/fireworks/run.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -109,9 +106,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/inference_store.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/groq/build.yaml b/llama_stack/templates/groq/build.yaml index 
819df22f0..800c3e3ae 100644 --- a/llama_stack/templates/groq/build.yaml +++ b/llama_stack/templates/groq/build.yaml @@ -26,6 +26,3 @@ distribution_spec: - remote::tavily-search - inline::rag-runtime image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/groq/run.yaml b/llama_stack/templates/groq/run.yaml index db7ebffee..79c350c73 100644 --- a/llama_stack/templates/groq/run.yaml +++ b/llama_stack/templates/groq/run.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -102,9 +99,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/inference_store.db models: - metadata: {} model_id: groq/llama3-8b-8192 diff --git a/llama_stack/templates/hf-endpoint/build.yaml b/llama_stack/templates/hf-endpoint/build.yaml index 8ede83694..2a40c3909 100644 --- a/llama_stack/templates/hf-endpoint/build.yaml +++ b/llama_stack/templates/hf-endpoint/build.yaml @@ -29,6 +29,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-endpoint/run-with-safety.yaml b/llama_stack/templates/hf-endpoint/run-with-safety.yaml index 15cf2a47f..82bcaa3cf 100644 --- a/llama_stack/templates/hf-endpoint/run-with-safety.yaml +++ b/llama_stack/templates/hf-endpoint/run-with-safety.yaml @@ -46,9 +46,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -110,9 +107,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-endpoint/run.yaml b/llama_stack/templates/hf-endpoint/run.yaml index 428edf9a2..ec7c55032 100644 --- a/llama_stack/templates/hf-endpoint/run.yaml +++ b/llama_stack/templates/hf-endpoint/run.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -105,9 +102,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-serverless/build.yaml b/llama_stack/templates/hf-serverless/build.yaml index d0752db9a..f77f8773b 100644 --- 
a/llama_stack/templates/hf-serverless/build.yaml +++ b/llama_stack/templates/hf-serverless/build.yaml @@ -30,6 +30,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-serverless/run-with-safety.yaml b/llama_stack/templates/hf-serverless/run-with-safety.yaml index ab461c6c3..320976e2c 100644 --- a/llama_stack/templates/hf-serverless/run-with-safety.yaml +++ b/llama_stack/templates/hf-serverless/run-with-safety.yaml @@ -46,9 +46,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -110,9 +107,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-serverless/run.yaml b/llama_stack/templates/hf-serverless/run.yaml index d238506fb..2b22b20c6 100644 --- a/llama_stack/templates/hf-serverless/run.yaml +++ b/llama_stack/templates/hf-serverless/run.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -105,9 +102,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/kvant/__init__.py b/llama_stack/templates/kvant/__init__.py deleted file mode 100644 index 61706f7f6..000000000 --- a/llama_stack/templates/kvant/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from .kvant import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/kvant/build.yaml b/llama_stack/templates/kvant/build.yaml deleted file mode 100644 index 25afc1f4d..000000000 --- a/llama_stack/templates/kvant/build.yaml +++ /dev/null @@ -1,35 +0,0 @@ -version: '2' -distribution_spec: - description: distribution for kvant cloud - providers: - inference: - - remote::vllm - - inline::sentence-transformers - vector_io: - - inline::faiss - - remote::chromadb - - remote::pgvector - safety: - - inline::llama-guard - agents: - - inline::meta-reference - telemetry: - - inline::meta-reference - eval: - - inline::meta-reference - datasetio: - - remote::huggingface - - inline::localfs - scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust - tool_runtime: - - remote::brave-search - - remote::tavily-search - - remote::wolfram-alpha - - inline::rag-runtime - - remote::model-context-protocol -image_type: conda -additional_pip_packages: -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/kvant/kvant.py b/llama_stack/templates/kvant/kvant.py deleted file mode 100644 index 44cfc7016..000000000 --- a/llama_stack/templates/kvant/kvant.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from pathlib import Path - -from llama_stack.apis.models.models import ModelType -from llama_stack.distribution.datatypes import ( - ModelInput, - Provider, - ShieldInput, - ToolGroupInput, -) -from llama_stack.providers.inline.inference.sentence_transformers import ( - SentenceTransformersInferenceConfig, -) -from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig -from llama_stack.providers.remote.inference.passthrough.config import ( - PassthroughImplConfig, -) -from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry -from llama_stack.templates.template import DistributionTemplate, RunConfigSettings - - -def get_distribution_template() -> DistributionTemplate: - providers = { - "inference": ["remote::openai", "inline::sentence-transformers"], - "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], - "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "remote::wolfram-alpha", - "inline::rag-runtime", - "remote::model-context-protocol", - ], - } - - name = "kvant" - - inference_provider = Provider( - provider_id="openai", - provider_type="remote::openai", - config=PassthroughImplConfig.sample_run_config(), - ) - embedding_provider = Provider( - provider_id="sentence-transformers", - provider_type="inline::sentence-transformers", - config=SentenceTransformersInferenceConfig.sample_run_config(), - ) - vector_io_provider = Provider( - provider_id="faiss", - provider_type="inline::faiss", - config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), - ) - - default_models = [ - ModelInput( - metadata={}, - model_id="inference-llama4-maverick", - provider_id="openai", - provider_model_id="inference-llama4-maverick", - model_type=ModelType.llm, - ), - ] - - embedding_model = ModelInput( 
- model_id="all-MiniLM-L6-v2", - provider_id="sentence-transformers", - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 384, - }, - ) - default_tool_groups = [ - ToolGroupInput( - toolgroup_id="builtin::websearch", - provider_id="tavily-search", - ), - ToolGroupInput( - toolgroup_id="builtin::wolfram_alpha", - provider_id="wolfram-alpha", - ), - ToolGroupInput( - toolgroup_id="builtin::rag", - provider_id="rag-runtime", - ), - ] - - return DistributionTemplate( - name=name, - distro_type="self_hosted", - description="Use Passthrough hosted llama-stack endpoint for LLM inference", - container_image=None, - providers=providers, - available_models_by_provider={ - "openai": [ - ProviderModelEntry( - provider_model_id="inference-llama4-maverick", - model_type=ModelType.llm, - ), - ], - }, - run_configs={ - "run.yaml": RunConfigSettings( - provider_overrides={ - "inference": [inference_provider, embedding_provider], - "vector_io": [vector_io_provider], - }, - default_models=default_models + [embedding_model], - default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], - default_tool_groups=default_tool_groups, - ), - }, - run_config_env_vars={ - "LLAMA_STACK_PORT": ( - "8321", - "Port for the Llama Stack distribution server", - ), - "OPENAI_API_KEY": ( - "", - "kvant maas API Key", - ), - "OPENAI_BASE_URL": ( - "https://maas.kvant.cloud", - "kvant maas URL", - ), - }, - ) diff --git a/llama_stack/templates/kvant/run.yaml b/llama_stack/templates/kvant/run.yaml deleted file mode 100644 index 99fb6f7fa..000000000 --- a/llama_stack/templates/kvant/run.yaml +++ /dev/null @@ -1,170 +0,0 @@ -version: '2' -image_name: kvant -apis: -- agents -- datasetio -- eval -- inference -- safety -- scoring -- telemetry -- tool_runtime -- vector_io -providers: - inference: - - provider_id: kvant - provider_type: remote::vllm - config: - url: ${env.VLLM_URL:https://maas.ai-2.kvant.cloud/v1} - max_tokens: ${env.VLLM_MAX_TOKENS:400000} - api_token: ${env.VLLM_API_TOKEN:fake} - tls_verify: ${env.VLLM_TLS_VERIFY:true} - - provider_id: sentence-transformers - provider_type: inline::sentence-transformers - config: {} - vector_io: - - provider_id: faiss - provider_type: inline::faiss - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/kvant}/faiss_store.db - safety: - - provider_id: llama-guard - provider_type: inline::llama-guard - config: - excluded_categories: [] - agents: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - persistence_store: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/kvant}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/kvant}/responses_store.db - telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - service_name: ${env.OTEL_SERVICE_NAME:} - sinks: ${env.TELEMETRY_SINKS:console,sqlite} - sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/kvant}/trace_store.db - eval: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/kvant}/meta_reference_eval.db - datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/kvant}/huggingface_datasetio.db - - provider_id: localfs - 
provider_type: inline::localfs - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/kvant}/localfs_datasetio.db - scoring: - - provider_id: basic - provider_type: inline::basic - config: {} - - provider_id: llm-as-judge - provider_type: inline::llm-as-judge - config: {} - - provider_id: braintrust - provider_type: inline::braintrust - config: - openai_api_key: ${env.OPENAI_API_KEY:} - tool_runtime: - - provider_id: brave-search - provider_type: remote::brave-search - config: - api_key: ${env.BRAVE_SEARCH_API_KEY:} - max_results: 3 - - provider_id: tavily-search - provider_type: remote::tavily-search - config: - api_key: ${env.TAVILY_SEARCH_API_KEY:} - max_results: 3 - - provider_id: wolfram-alpha - provider_type: remote::wolfram-alpha - config: - api_key: ${env.WOLFRAM_ALPHA_API_KEY:} - - provider_id: rag-runtime - provider_type: inline::rag-runtime - config: {} - - provider_id: model-context-protocol - provider_type: remote::model-context-protocol - config: {} -metadata_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/kvant}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/kvant}/inference_store.db -models: -- metadata: {} - model_id: Llama-4-Maverick-17B-128E-Instruct-FP8 - provider_id: kvant - provider_model_id: inference-llama4-maverick - model_type: llm -- metadata: - embedding_dimension: 1024 - context_length: 8192 - model_id: inference-bge-m3 - provider_id: kvant - model_type: embedding -- metadata: - embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 - provider_id: sentence-transformers - model_type: embedding -shields: -- shield_id: meta-llama/Llama-Guard-3-8B -vector_dbs: [] -# - vector_db_id: test-bge -# embedding_model: inference-bge-m3 -# embedding_dimension: 1024 -# provider_id: faiss -# - vector_db_id: test-MiniLM-L6-v2 -# embedding_model: all-MiniLM-L6-v2 -# embedding_dimension: 384 -# provider_id: faiss -datasets: [] -scoring_fns: [] -benchmarks: [] -tool_groups: -- toolgroup_id: builtin::websearch - provider_id: tavily-search -- toolgroup_id: builtin::wolfram_alpha - provider_id: wolfram-alpha -- toolgroup_id: builtin::rag - provider_id: rag-runtime -server: - port: 8321 - auth: - provider_type: "oauth2_token" - config: - jwks: - introspection: - url: ${env.KEYCLOAK_INSTROSPECT:https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/token/introspect} - client_id: ${env.KEYCLOAK_CLIENT_ID:llama-stack} - client_secret: ${env.KEYCLOAK_CLIENT_SECRET} - claims_mapping: - sub: projects - scope: roles - #groups: teams - customer/id: teams - aud: namespaces diff --git a/llama_stack/templates/llama_api/build.yaml b/llama_stack/templates/llama_api/build.yaml index 857e5f014..f97ee4091 100644 --- a/llama_stack/templates/llama_api/build.yaml +++ b/llama_stack/templates/llama_api/build.yaml @@ -30,6 +30,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/llama_api/run.yaml b/llama_stack/templates/llama_api/run.yaml index a7f2b0769..a879482d7 100644 --- a/llama_stack/templates/llama_api/run.yaml +++ b/llama_stack/templates/llama_api/run.yaml @@ -50,9 +50,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/agents_store.db - responses_store: - type: sqlite - db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -114,9 +111,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/inference_store.db models: - metadata: {} model_id: Llama-3.3-70B-Instruct diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index 53ad411e3..a9d03490b 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -29,6 +29,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index 2b751a514..180d44e0f 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -56,9 +56,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -120,9 +117,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index a24c5fec5..d879667e0 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -46,9 +46,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -110,9 +107,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index 6bd8a0100..a05cf97ad 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -24,6 +24,3 @@ distribution_spec: tool_runtime: - inline::rag-runtime image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/nvidia/doc_template.md b/llama_stack/templates/nvidia/doc_template.md index 50c96802f..068dd7ac3 100644 --- a/llama_stack/templates/nvidia/doc_template.md +++ b/llama_stack/templates/nvidia/doc_template.md @@ -116,7 +116,7 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v 
./run.yaml:/root/my-run.yaml \ llamastack/distribution-{{ name }} \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env NVIDIA_API_KEY=$NVIDIA_API_KEY ``` diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index c431e12f2..3cdb8e3d2 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -46,9 +46,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -95,9 +92,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index 5b244081d..3337b7942 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -83,9 +80,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/inference_store.db models: - metadata: {} model_id: meta/llama3-8b-instruct diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 36a120897..88e61bf8a 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -23,8 +23,6 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust - post_training: - - inline::huggingface tool_runtime: - remote::brave-search - remote::tavily-search @@ -32,6 +30,3 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/ollama/doc_template.md b/llama_stack/templates/ollama/doc_template.md index aaa65bab2..f961ab7ed 100644 --- a/llama_stack/templates/ollama/doc_template.md +++ b/llama_stack/templates/ollama/doc_template.md @@ -86,7 +86,7 @@ docker run \ -v ~/.llama:/root/.llama \ -v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-{{ name }} \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env SAFETY_MODEL=$SAFETY_MODEL \ diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index 0b4f05128..d72d299ec 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -13,7 +13,6 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) -from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig from 
llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -29,7 +28,6 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], - "post_training": ["inline::huggingface"], "tool_runtime": [ "remote::brave-search", "remote::tavily-search", @@ -49,11 +47,7 @@ def get_distribution_template() -> DistributionTemplate: provider_type="inline::faiss", config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), ) - posttraining_provider = Provider( - provider_id="huggingface", - provider_type="inline::huggingface", - config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"), - ) + inference_model = ModelInput( model_id="${env.INFERENCE_MODEL}", provider_id="ollama", @@ -98,7 +92,6 @@ def get_distribution_template() -> DistributionTemplate: provider_overrides={ "inference": [inference_provider], "vector_io": [vector_io_provider_faiss], - "post_training": [posttraining_provider], }, default_models=[inference_model, embedding_model], default_tool_groups=default_tool_groups, @@ -107,7 +100,6 @@ def get_distribution_template() -> DistributionTemplate: provider_overrides={ "inference": [inference_provider], "vector_io": [vector_io_provider_faiss], - "post_training": [posttraining_provider], "safety": [ Provider( provider_id="llama-guard", diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index d63c5e366..651d58117 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -5,7 +5,6 @@ apis: - datasetio - eval - inference -- post_training - safety - scoring - telemetry @@ -40,9 +39,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -84,13 +80,6 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} - post_training: - - provider_id: huggingface - provider_type: inline::huggingface - config: - checkpoint_format: huggingface - distributed_backend: null - device: cpu tool_runtime: - provider_id: brave-search provider_type: remote::brave-search @@ -115,9 +104,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index d208cd7f0..1372486fe 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -5,7 +5,6 @@ apis: - datasetio - eval - inference -- post_training - safety - scoring - telemetry @@ -38,9 +37,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db - responses_store: - type: sqlite - db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -82,13 +78,6 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} - post_training: - - provider_id: huggingface - provider_type: inline::huggingface - config: - checkpoint_format: huggingface - distributed_backend: null - device: cpu tool_runtime: - provider_id: brave-search provider_type: remote::brave-search @@ -113,9 +102,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/templates/open-benchmark/build.yaml index 840f1e1db..b14e96435 100644 --- a/llama_stack/templates/open-benchmark/build.yaml +++ b/llama_stack/templates/open-benchmark/build.yaml @@ -33,6 +33,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 0e5edf728..30a27cbd8 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -64,9 +64,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -128,9 +125,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/inference_store.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/passthrough/build.yaml b/llama_stack/templates/passthrough/build.yaml index 46b99cb75..f8d099070 100644 --- a/llama_stack/templates/passthrough/build.yaml +++ b/llama_stack/templates/passthrough/build.yaml @@ -31,6 +31,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/passthrough/run-with-safety.yaml b/llama_stack/templates/passthrough/run-with-safety.yaml index bbf5d9a52..a91b9fc92 100644 --- a/llama_stack/templates/passthrough/run-with-safety.yaml +++ b/llama_stack/templates/passthrough/run-with-safety.yaml @@ -46,9 +46,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -114,9 +111,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/inference_store.db models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct diff --git 
a/llama_stack/templates/passthrough/run.yaml b/llama_stack/templates/passthrough/run.yaml index 146906d9b..d1dd3b885 100644 --- a/llama_stack/templates/passthrough/run.yaml +++ b/llama_stack/templates/passthrough/run.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -109,9 +106,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/inference_store.db models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct diff --git a/llama_stack/templates/postgres-demo/__init__.py b/llama_stack/templates/postgres-demo/__init__.py deleted file mode 100644 index 81473cb73..000000000 --- a/llama_stack/templates/postgres-demo/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .postgres_demo import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/postgres-demo/build.yaml b/llama_stack/templates/postgres-demo/build.yaml deleted file mode 100644 index 8f3648abe..000000000 --- a/llama_stack/templates/postgres-demo/build.yaml +++ /dev/null @@ -1,24 +0,0 @@ -version: '2' -distribution_spec: - description: Quick start template for running Llama Stack with several popular providers - providers: - inference: - - remote::fireworks - - remote::vllm - vector_io: - - remote::chromadb - safety: - - inline::llama-guard - agents: - - inline::meta-reference - telemetry: - - inline::meta-reference - tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol -image_type: conda -additional_pip_packages: -- asyncpg -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/postgres-demo/postgres_demo.py b/llama_stack/templates/postgres-demo/postgres_demo.py deleted file mode 100644 index d2e352320..000000000 --- a/llama_stack/templates/postgres-demo/postgres_demo.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- - -from llama_stack.distribution.datatypes import ( - ModelInput, - Provider, - ShieldInput, - ToolGroupInput, -) -from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig -from llama_stack.providers.remote.inference.fireworks.models import ( - MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES, -) -from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig -from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig -from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry -from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig -from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig -from llama_stack.templates.template import ( - DistributionTemplate, - RunConfigSettings, - get_model_registry, -) - - -def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]: - # in this template, we allow each API key to be optional - providers = [ - ( - "fireworks", - FIREWORKS_MODEL_ENTRIES, - FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"), - ), - ] - inference_providers = [] - available_models = {} - for provider_id, model_entries, config in providers: - inference_providers.append( - Provider( - provider_id=provider_id, - provider_type=f"remote::{provider_id}", - config=config, - ) - ) - available_models[provider_id] = model_entries - inference_providers.append( - Provider( - provider_id="vllm-inference", - provider_type="remote::vllm", - config=VLLMInferenceAdapterConfig.sample_run_config( - url="${env.VLLM_URL:http://localhost:8000/v1}", - ), - ) - ) - return inference_providers, available_models - - -def get_distribution_template() -> DistributionTemplate: - inference_providers, available_models = get_inference_providers() - providers = { - "inference": ([p.provider_type for p in inference_providers]), - "vector_io": ["remote::chromadb"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", - ], - } - name = "postgres-demo" - - vector_io_providers = [ - Provider( - provider_id="${env.ENABLE_CHROMADB+chromadb}", - provider_type="remote::chromadb", - config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"), - ), - ] - default_tool_groups = [ - ToolGroupInput( - toolgroup_id="builtin::websearch", - provider_id="tavily-search", - ), - ToolGroupInput( - toolgroup_id="builtin::rag", - provider_id="rag-runtime", - ), - ] - - default_models = get_model_registry(available_models) - default_models.append( - ModelInput( - model_id="${env.INFERENCE_MODEL}", - provider_id="vllm-inference", - ) - ) - postgres_config = { - "type": "postgres", - "host": "${env.POSTGRES_HOST:localhost}", - "port": "${env.POSTGRES_PORT:5432}", - "db": "${env.POSTGRES_DB:llamastack}", - "user": "${env.POSTGRES_USER:llamastack}", - "password": "${env.POSTGRES_PASSWORD:llamastack}", - } - - return DistributionTemplate( - name=name, - distro_type="self_hosted", - description="Quick start template for running Llama Stack with several popular providers", - container_image=None, - template_path=None, - providers=providers, - available_models_by_provider=available_models, - run_configs={ - "run.yaml": RunConfigSettings( - provider_overrides={ - "inference": inference_providers, - "vector_io": vector_io_providers, - 
"agents": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - config=dict( - persistence_store=postgres_config, - responses_store=postgres_config, - ), - ) - ], - "telemetry": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - config=dict( - service_name="${env.OTEL_SERVICE_NAME:}", - sinks="${env.TELEMETRY_SINKS:console}", - ), - ) - ], - }, - default_models=default_models, - default_tool_groups=default_tool_groups, - default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], - metadata_store=PostgresKVStoreConfig.model_validate(postgres_config), - inference_store=PostgresSqlStoreConfig.model_validate(postgres_config), - ), - }, - run_config_env_vars={ - "LLAMA_STACK_PORT": ( - "8321", - "Port for the Llama Stack distribution server", - ), - "FIREWORKS_API_KEY": ( - "", - "Fireworks API Key", - ), - }, - ) diff --git a/llama_stack/templates/postgres-demo/run.yaml b/llama_stack/templates/postgres-demo/run.yaml deleted file mode 100644 index 889b8eaa7..000000000 --- a/llama_stack/templates/postgres-demo/run.yaml +++ /dev/null @@ -1,224 +0,0 @@ -version: '2' -image_name: postgres-demo -apis: -- agents -- inference -- safety -- telemetry -- tool_runtime -- vector_io -providers: - inference: - - provider_id: fireworks - provider_type: remote::fireworks - config: - url: https://api.fireworks.ai/inference/v1 - api_key: ${env.FIREWORKS_API_KEY:} - - provider_id: vllm-inference - provider_type: remote::vllm - config: - url: ${env.VLLM_URL:http://localhost:8000/v1} - max_tokens: ${env.VLLM_MAX_TOKENS:4096} - api_token: ${env.VLLM_API_TOKEN:fake} - tls_verify: ${env.VLLM_TLS_VERIFY:true} - vector_io: - - provider_id: ${env.ENABLE_CHROMADB+chromadb} - provider_type: remote::chromadb - config: - url: ${env.CHROMADB_URL:} - safety: - - provider_id: llama-guard - provider_type: inline::llama-guard - config: - excluded_categories: [] - agents: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - persistence_store: - type: postgres - host: ${env.POSTGRES_HOST:localhost} - port: ${env.POSTGRES_PORT:5432} - db: ${env.POSTGRES_DB:llamastack} - user: ${env.POSTGRES_USER:llamastack} - password: ${env.POSTGRES_PASSWORD:llamastack} - responses_store: - type: postgres - host: ${env.POSTGRES_HOST:localhost} - port: ${env.POSTGRES_PORT:5432} - db: ${env.POSTGRES_DB:llamastack} - user: ${env.POSTGRES_USER:llamastack} - password: ${env.POSTGRES_PASSWORD:llamastack} - telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - service_name: ${env.OTEL_SERVICE_NAME:} - sinks: ${env.TELEMETRY_SINKS:console} - tool_runtime: - - provider_id: brave-search - provider_type: remote::brave-search - config: - api_key: ${env.BRAVE_SEARCH_API_KEY:} - max_results: 3 - - provider_id: tavily-search - provider_type: remote::tavily-search - config: - api_key: ${env.TAVILY_SEARCH_API_KEY:} - max_results: 3 - - provider_id: rag-runtime - provider_type: inline::rag-runtime - config: {} - - provider_id: model-context-protocol - provider_type: remote::model-context-protocol - config: {} -metadata_store: - type: postgres - host: ${env.POSTGRES_HOST:localhost} - port: ${env.POSTGRES_PORT:5432} - db: ${env.POSTGRES_DB:llamastack} - user: ${env.POSTGRES_USER:llamastack} - password: ${env.POSTGRES_PASSWORD:llamastack} - table_name: llamastack_kvstore -inference_store: - type: postgres - host: ${env.POSTGRES_HOST:localhost} - port: ${env.POSTGRES_PORT:5432} - db: ${env.POSTGRES_DB:llamastack} 
- user: ${env.POSTGRES_USER:llamastack} - password: ${env.POSTGRES_PASSWORD:llamastack} -models: -- metadata: {} - model_id: accounts/fireworks/models/llama-v3p1-8b-instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-8B-Instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama-v3p1-70b-instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-70B-Instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama-v3p1-405b-instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama-v3p2-3b-instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-3B-Instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-11B-Vision-Instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-90B-Vision-Instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama-v3p3-70b-instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.3-70B-Instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama-guard-3-8b - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-guard-3-8b - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-Guard-3-8B - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-guard-3-8b - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama-guard-3-11b-vision - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-Guard-3-11B-Vision - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama4-scout-instruct-basic - provider_id: fireworks - 
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic - model_type: llm -- metadata: {} - model_id: accounts/fireworks/models/llama4-maverick-instruct-basic - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct - provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic - model_type: llm -- metadata: - embedding_dimension: 768 - context_length: 8192 - model_id: nomic-ai/nomic-embed-text-v1.5 - provider_id: fireworks - provider_model_id: nomic-ai/nomic-embed-text-v1.5 - model_type: embedding -- metadata: {} - model_id: ${env.INFERENCE_MODEL} - provider_id: vllm-inference - model_type: llm -shields: -- shield_id: meta-llama/Llama-Guard-3-8B -vector_dbs: [] -datasets: [] -scoring_fns: [] -benchmarks: [] -tool_groups: -- toolgroup_id: builtin::websearch - provider_id: tavily-search -- toolgroup_id: builtin::rag - provider_id: rag-runtime -server: - port: 8321 diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml index 16fe5d4fd..4baaaf9c8 100644 --- a/llama_stack/templates/remote-vllm/build.yaml +++ b/llama_stack/templates/remote-vllm/build.yaml @@ -31,6 +31,3 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/remote-vllm/doc_template.md b/llama_stack/templates/remote-vllm/doc_template.md index 5684888da..3cede6080 100644 --- a/llama_stack/templates/remote-vllm/doc_template.md +++ b/llama_stack/templates/remote-vllm/doc_template.md @@ -220,7 +220,7 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \ llamastack/distribution-{{ name }} \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 @@ -242,7 +242,7 @@ docker run \ -v ~/.llama:/root/.llama \ -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-{{ name }} \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \ diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml index e83162a4f..6931d4ba9 100644 --- a/llama_stack/templates/remote-vllm/run-with-safety.yaml +++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml @@ -50,9 +50,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference @@ -118,9 +115,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db -inference_store: - type: sqlite - db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index 4cdf88c6b..05671165d 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -43,9 +43,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference @@ -111,9 +108,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index 14b1c8974..81d90f420 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -1,6 +1,6 @@ version: '2' distribution_spec: - description: Use SambaNova for running LLM inference and safety + description: Use SambaNova for running LLM inference providers: inference: - remote::sambanova @@ -10,7 +10,7 @@ distribution_spec: - remote::chromadb - remote::pgvector safety: - - remote::sambanova + - inline::llama-guard agents: - inline::meta-reference telemetry: @@ -22,6 +22,3 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/sambanova/doc_template.md b/llama_stack/templates/sambanova/doc_template.md index 1dc76fd3f..42d9efb66 100644 --- a/llama_stack/templates/sambanova/doc_template.md +++ b/llama_stack/templates/sambanova/doc_template.md @@ -37,44 +37,33 @@ The following models are available by default: ### Prerequisite: API Keys -Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup). +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/). ## Running Llama Stack with SambaNova You can do this via Conda (build code) or Docker which has a pre-built image. - ### Via Docker +This method allows you to get started quickly without having to build the distribution code. 
+ ```bash LLAMA_STACK_PORT=8321 -llama stack build --template sambanova --image-type container docker run \ -it \ + --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ~/.llama:/root/.llama \ - distribution-{{ name }} \ + llamastack/distribution-{{ name }} \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` - -### Via Venv - -```bash -llama stack build --template sambanova --image-type venv -llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \ - --port $LLAMA_STACK_PORT \ - --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY -``` - - ### Via Conda ```bash llama stack build --template sambanova --image-type conda -llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \ +llama stack run ./run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 8c2a933ab..620d50307 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -38,11 +38,10 @@ providers: user: ${env.PGVECTOR_USER:} password: ${env.PGVECTOR_PASSWORD:} safety: - - provider_id: sambanova - provider_type: remote::sambanova + - provider_id: llama-guard + provider_type: inline::llama-guard config: - url: https://api.sambanova.ai/v1 - api_key: ${env.SAMBANOVA_API_KEY} + excluded_categories: [] agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -51,9 +50,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -85,9 +81,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/inference_store.db models: - metadata: {} model_id: sambanova/Meta-Llama-3.1-8B-Instruct @@ -196,9 +189,6 @@ models: model_type: embedding shields: - shield_id: meta-llama/Llama-Guard-3-8B - provider_shield_id: sambanova/Meta-Llama-Guard-3-8B -- shield_id: sambanova/Meta-Llama-Guard-3-8B - provider_shield_id: sambanova/Meta-Llama-Guard-3-8B vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 54a49423d..2f8a0b08a 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -34,7 +34,7 @@ def get_distribution_template() -> DistributionTemplate: providers = { "inference": ["remote::sambanova", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["remote::sambanova"], + "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], "tool_runtime": [ @@ -110,7 +110,7 @@ def get_distribution_template() -> DistributionTemplate: return DistributionTemplate( name=name, distro_type="self_hosted", - description="Use SambaNova for running LLM inference and safety", + description="Use SambaNova for running LLM inference", container_image=None, template_path=Path(__file__).parent / "doc_template.md", providers=providers, @@ -122,15 +122,7 @@ def get_distribution_template() -> 
DistributionTemplate: "vector_io": vector_io_providers, }, default_models=default_models + [embedding_model], - default_shields=[ - ShieldInput( - shield_id="meta-llama/Llama-Guard-3-8B", provider_shield_id="sambanova/Meta-Llama-Guard-3-8B" - ), - ShieldInput( - shield_id="sambanova/Meta-Llama-Guard-3-8B", - provider_shield_id="sambanova/Meta-Llama-Guard-3-8B", - ), - ], + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], default_tool_groups=default_tool_groups, ), }, diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml index ec97c7d3e..35bd0c713 100644 --- a/llama_stack/templates/starter/build.yaml +++ b/llama_stack/templates/starter/build.yaml @@ -35,6 +35,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 04425ed35..402695850 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -72,9 +72,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -136,9 +133,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/inference_store.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index 4013f08f9..e4d28d904 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -28,8 +28,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry -from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig def get_model_registry( @@ -64,8 +63,6 @@ class RunConfigSettings(BaseModel): default_tool_groups: list[ToolGroupInput] | None = None default_datasets: list[DatasetInput] | None = None default_benchmarks: list[BenchmarkInput] | None = None - metadata_store: KVStoreConfig | None = None - inference_store: SqlStoreConfig | None = None def run_config( self, @@ -116,16 +113,10 @@ class RunConfigSettings(BaseModel): container_image=container_image, apis=apis, providers=provider_configs, - metadata_store=self.metadata_store - or SqliteKVStoreConfig.sample_run_config( + metadata_store=SqliteKVStoreConfig.sample_run_config( __distro_dir__=f"~/.llama/distributions/{name}", db_name="registry.db", ), - inference_store=self.inference_store - or SqliteSqlStoreConfig.sample_run_config( - __distro_dir__=f"~/.llama/distributions/{name}", - db_name="inference_store.db", - ), models=self.default_models or [], shields=self.default_shields or [], tool_groups=self.default_tool_groups or [], @@ -155,20 +146,14 @@ class DistributionTemplate(BaseModel): 
available_models_by_provider: dict[str, list[ProviderModelEntry]] | None = None def build_config(self) -> BuildConfig: - additional_pip_packages: list[str] = [] - for run_config in self.run_configs.values(): - run_config_ = run_config.run_config(self.name, self.providers, self.container_image) - if run_config_.inference_store: - additional_pip_packages.extend(run_config_.inference_store.pip_packages) - return BuildConfig( + name=self.name, distribution_spec=DistributionSpec( description=self.description, container_image=self.container_image, providers=self.providers, ), image_type="conda", # default to conda, can be overridden - additional_pip_packages=sorted(set(additional_pip_packages)), ) def generate_markdown_docs(self) -> str: diff --git a/llama_stack/templates/tgi/build.yaml b/llama_stack/templates/tgi/build.yaml index 361b0b680..d2ba1c3e9 100644 --- a/llama_stack/templates/tgi/build.yaml +++ b/llama_stack/templates/tgi/build.yaml @@ -30,6 +30,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/tgi/doc_template.md b/llama_stack/templates/tgi/doc_template.md index 68b475893..b69ccaa56 100644 --- a/llama_stack/templates/tgi/doc_template.md +++ b/llama_stack/templates/tgi/doc_template.md @@ -105,7 +105,7 @@ docker run \ -v ~/.llama:/root/.llama \ -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-{{ name }} \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \ diff --git a/llama_stack/templates/tgi/run-with-safety.yaml b/llama_stack/templates/tgi/run-with-safety.yaml index c797b93aa..3255e9c0b 100644 --- a/llama_stack/templates/tgi/run-with-safety.yaml +++ b/llama_stack/templates/tgi/run-with-safety.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -105,9 +102,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/tgi/run.yaml b/llama_stack/templates/tgi/run.yaml index 7e91d20bd..179087258 100644 --- a/llama_stack/templates/tgi/run.yaml +++ b/llama_stack/templates/tgi/run.yaml @@ -40,9 +40,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -104,9 +101,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index 5ffeac873..b7338795c 
100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -31,6 +31,3 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml index 190a0400b..fe8c8e397 100644 --- a/llama_stack/templates/together/run-with-safety.yaml +++ b/llama_stack/templates/together/run-with-safety.yaml @@ -46,9 +46,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -114,9 +111,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/inference_store.db models: - metadata: {} model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index ce9542130..b903fc659 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -41,9 +41,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -109,9 +106,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/inference_store.db models: - metadata: {} model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo diff --git a/llama_stack/templates/verification/build.yaml b/llama_stack/templates/verification/build.yaml index ce083dbba..aae24c3ca 100644 --- a/llama_stack/templates/verification/build.yaml +++ b/llama_stack/templates/verification/build.yaml @@ -35,6 +35,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/verification/run.yaml b/llama_stack/templates/verification/run.yaml index 58b3c576c..11af41da9 100644 --- a/llama_stack/templates/verification/run.yaml +++ b/llama_stack/templates/verification/run.yaml @@ -74,9 +74,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -138,9 +135,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/inference_store.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/vllm-gpu/build.yaml b/llama_stack/templates/vllm-gpu/build.yaml index d5ff0f1f4..53e257f22 100644 
--- a/llama_stack/templates/vllm-gpu/build.yaml +++ b/llama_stack/templates/vllm-gpu/build.yaml @@ -30,6 +30,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/vllm-gpu/run.yaml b/llama_stack/templates/vllm-gpu/run.yaml index 6937e2bac..5d3482528 100644 --- a/llama_stack/templates/vllm-gpu/run.yaml +++ b/llama_stack/templates/vllm-gpu/run.yaml @@ -45,9 +45,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -109,9 +106,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/watsonx/build.yaml b/llama_stack/templates/watsonx/build.yaml index e68ace183..638b16029 100644 --- a/llama_stack/templates/watsonx/build.yaml +++ b/llama_stack/templates/watsonx/build.yaml @@ -28,6 +28,3 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/watsonx/doc_template.md b/llama_stack/templates/watsonx/doc_template.md index f28dbf0bf..af0ae15a8 100644 --- a/llama_stack/templates/watsonx/doc_template.md +++ b/llama_stack/templates/watsonx/doc_template.md @@ -56,7 +56,7 @@ docker run \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ./run.yaml:/root/my-run.yaml \ llamastack/distribution-{{ name }} \ - --config /root/my-run.yaml \ + --yaml-config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ --env WATSONX_API_KEY=$WATSONX_API_KEY \ --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \ diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml index e7222fd57..8de6a2b6c 100644 --- a/llama_stack/templates/watsonx/run.yaml +++ b/llama_stack/templates/watsonx/run.yaml @@ -42,9 +42,6 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -106,9 +103,6 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/inference_store.db models: - metadata: {} model_id: meta-llama/llama-3-3-70b-instruct diff --git a/llama_stack/ui/.prettierignore b/llama_stack/ui/.prettierignore deleted file mode 100644 index 1b8ac8894..000000000 --- a/llama_stack/ui/.prettierignore +++ /dev/null @@ -1,3 +0,0 @@ -# Ignore artifacts: -build -coverage diff --git a/llama_stack/ui/.prettierrc b/llama_stack/ui/.prettierrc deleted file mode 100644 index 0967ef424..000000000 --- a/llama_stack/ui/.prettierrc +++ /dev/null @@ -1 +0,0 @@ -{} diff --git a/llama_stack/ui/README.md b/llama_stack/ui/README.md index b6f803509..e3e21bf0b 100644 --- a/llama_stack/ui/README.md +++ 
b/llama_stack/ui/README.md @@ -1,5 +1,6 @@ ## This is WIP. + We use shadcdn/ui [Shadcn UI](https://ui.shadcn.com/) for the UI components. ## Getting Started @@ -7,7 +8,7 @@ We use shadcdn/ui [Shadcn UI](https://ui.shadcn.com/) for the UI components. First, install dependencies: ```bash -npm install +npm install next react react-dom ``` Then, run the development server: @@ -22,4 +23,4 @@ pnpm dev bun dev ``` -Open [http://localhost:8322](http://localhost:8322) with your browser to see the result. +Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. diff --git a/llama_stack/ui/app/layout.tsx b/llama_stack/ui/app/layout.tsx index ed8a6cd5d..f029002dd 100644 --- a/llama_stack/ui/app/layout.tsx +++ b/llama_stack/ui/app/layout.tsx @@ -20,7 +20,7 @@ export const metadata: Metadata = { }; import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; -import { AppSidebar } from "@/components/layout/app-sidebar"; +import { AppSidebar } from "@/components/app-sidebar"; export default function Layout({ children }: { children: React.ReactNode }) { return ( diff --git a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx deleted file mode 100644 index e6feef363..000000000 --- a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx +++ /dev/null @@ -1,58 +0,0 @@ -"use client"; - -import { useEffect, useState } from "react"; -import { useParams } from "next/navigation"; -import { ChatCompletion } from "@/lib/types"; -import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail"; -import { client } from "@/lib/client"; - -export default function ChatCompletionDetailPage() { - const params = useParams(); - const id = params.id as string; - - const [completionDetail, setCompletionDetail] = - useState(null); - const [isLoading, setIsLoading] = useState(true); - const [error, setError] = useState(null); - - useEffect(() => { - if (!id) { - setError(new Error("Completion ID is missing.")); - setIsLoading(false); - return; - } - - const fetchCompletionDetail = async () => { - setIsLoading(true); - setError(null); - setCompletionDetail(null); - try { - const response = await client.chat.completions.retrieve(id); - setCompletionDetail(response as ChatCompletion); - } catch (err) { - console.error( - `Error fetching chat completion detail for ID ${id}:`, - err, - ); - setError( - err instanceof Error - ? 
err - : new Error("Failed to fetch completion detail"), - ); - } finally { - setIsLoading(false); - } - }; - - fetchCompletionDetail(); - }, [id]); - - return ( - - ); -} diff --git a/llama_stack/ui/app/logs/chat-completions/layout.tsx b/llama_stack/ui/app/logs/chat-completions/layout.tsx deleted file mode 100644 index f4dbfc782..000000000 --- a/llama_stack/ui/app/logs/chat-completions/layout.tsx +++ /dev/null @@ -1,19 +0,0 @@ -"use client"; - -import React from "react"; -import LogsLayout from "@/components/layout/logs-layout"; - -export default function ChatCompletionsLayout({ - children, -}: { - children: React.ReactNode; -}) { - return ( - - {children} - - ); -} diff --git a/llama_stack/ui/app/logs/chat-completions/page.tsx b/llama_stack/ui/app/logs/chat-completions/page.tsx index 5bbfcce94..84cceb8b7 100644 --- a/llama_stack/ui/app/logs/chat-completions/page.tsx +++ b/llama_stack/ui/app/logs/chat-completions/page.tsx @@ -1,51 +1,7 @@ -"use client"; - -import { useEffect, useState } from "react"; -import { ChatCompletion } from "@/lib/types"; -import { ChatCompletionsTable } from "@/components/chat-completions/chat-completions-table"; -import { client } from "@/lib/client"; - -export default function ChatCompletionsPage() { - const [completions, setCompletions] = useState([]); - const [isLoading, setIsLoading] = useState(true); - const [error, setError] = useState(null); - - useEffect(() => { - const fetchCompletions = async () => { - setIsLoading(true); - setError(null); - try { - const response = await client.chat.completions.list(); - const data = Array.isArray(response) - ? response - : (response as { data: ChatCompletion[] }).data; - - if (Array.isArray(data)) { - setCompletions(data); - } else { - console.error("Unexpected response structure:", response); - setError(new Error("Unexpected response structure")); - setCompletions([]); - } - } catch (err) { - console.error("Error fetching chat completions:", err); - setError( - err instanceof Error ? err : new Error("Failed to fetch completions"), - ); - setCompletions([]); - } finally { - setIsLoading(false); - } - }; - - fetchCompletions(); - }, []); - +export default function ChatCompletions() { return ( - +
+

Under Construction

+
); } diff --git a/llama_stack/ui/app/logs/responses/[id]/page.tsx b/llama_stack/ui/app/logs/responses/[id]/page.tsx deleted file mode 100644 index efe6f0ff3..000000000 --- a/llama_stack/ui/app/logs/responses/[id]/page.tsx +++ /dev/null @@ -1,125 +0,0 @@ -"use client"; - -import { useEffect, useState } from "react"; -import { useParams } from "next/navigation"; -import type { ResponseObject } from "llama-stack-client/resources/responses/responses"; -import { OpenAIResponse, InputItemListResponse } from "@/lib/types"; -import { ResponseDetailView } from "@/components/responses/responses-detail"; -import { client } from "@/lib/client"; - -export default function ResponseDetailPage() { - const params = useParams(); - const id = params.id as string; - - const [responseDetail, setResponseDetail] = useState( - null, - ); - const [inputItems, setInputItems] = useState( - null, - ); - const [isLoading, setIsLoading] = useState(true); - const [isLoadingInputItems, setIsLoadingInputItems] = useState(true); - const [error, setError] = useState(null); - const [inputItemsError, setInputItemsError] = useState(null); - - // Helper function to convert ResponseObject to OpenAIResponse - const convertResponseObject = ( - responseData: ResponseObject, - ): OpenAIResponse => { - return { - id: responseData.id, - created_at: responseData.created_at, - model: responseData.model, - object: responseData.object, - status: responseData.status, - output: responseData.output as OpenAIResponse["output"], - input: [], // ResponseObject doesn't include input; component uses inputItems prop instead - error: responseData.error, - parallel_tool_calls: responseData.parallel_tool_calls, - previous_response_id: responseData.previous_response_id, - temperature: responseData.temperature, - top_p: responseData.top_p, - truncation: responseData.truncation, - user: responseData.user, - }; - }; - - useEffect(() => { - if (!id) { - setError(new Error("Response ID is missing.")); - setIsLoading(false); - return; - } - - const fetchResponseDetail = async () => { - setIsLoading(true); - setIsLoadingInputItems(true); - setError(null); - setInputItemsError(null); - setResponseDetail(null); - setInputItems(null); - - try { - const [responseResult, inputItemsResult] = await Promise.allSettled([ - client.responses.retrieve(id), - client.responses.inputItems.list(id, { order: "asc" }), - ]); - - // Handle response detail result - if (responseResult.status === "fulfilled") { - const convertedResponse = convertResponseObject(responseResult.value); - setResponseDetail(convertedResponse); - } else { - console.error( - `Error fetching response detail for ID ${id}:`, - responseResult.reason, - ); - setError( - responseResult.reason instanceof Error - ? responseResult.reason - : new Error("Failed to fetch response detail"), - ); - } - - // Handle input items result - if (inputItemsResult.status === "fulfilled") { - const inputItemsData = - inputItemsResult.value as unknown as InputItemListResponse; - setInputItems(inputItemsData); - } else { - console.error( - `Error fetching input items for response ID ${id}:`, - inputItemsResult.reason, - ); - setInputItemsError( - inputItemsResult.reason instanceof Error - ? inputItemsResult.reason - : new Error("Failed to fetch input items"), - ); - } - } catch (err) { - console.error(`Unexpected error fetching data for ID ${id}:`, err); - setError( - err instanceof Error ? 
err : new Error("Unexpected error occurred"), - ); - } finally { - setIsLoading(false); - setIsLoadingInputItems(false); - } - }; - - fetchResponseDetail(); - }, [id]); - - return ( - - ); -} diff --git a/llama_stack/ui/app/logs/responses/layout.tsx b/llama_stack/ui/app/logs/responses/layout.tsx deleted file mode 100644 index 1fe116e5e..000000000 --- a/llama_stack/ui/app/logs/responses/layout.tsx +++ /dev/null @@ -1,16 +0,0 @@ -"use client"; - -import React from "react"; -import LogsLayout from "@/components/layout/logs-layout"; - -export default function ResponsesLayout({ - children, -}: { - children: React.ReactNode; -}) { - return ( - - {children} - - ); -} diff --git a/llama_stack/ui/app/logs/responses/page.tsx b/llama_stack/ui/app/logs/responses/page.tsx index dab0c735f..cdc165d08 100644 --- a/llama_stack/ui/app/logs/responses/page.tsx +++ b/llama_stack/ui/app/logs/responses/page.tsx @@ -1,66 +1,7 @@ -"use client"; - -import { useEffect, useState } from "react"; -import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses"; -import { OpenAIResponse } from "@/lib/types"; -import { ResponsesTable } from "@/components/responses/responses-table"; -import { client } from "@/lib/client"; - -export default function ResponsesPage() { - const [responses, setResponses] = useState([]); - const [isLoading, setIsLoading] = useState(true); - const [error, setError] = useState(null); - - // Helper function to convert ResponseListResponse.Data to OpenAIResponse - const convertResponseListData = ( - responseData: ResponseListResponse.Data, - ): OpenAIResponse => { - return { - id: responseData.id, - created_at: responseData.created_at, - model: responseData.model, - object: responseData.object, - status: responseData.status, - output: responseData.output as OpenAIResponse["output"], - input: responseData.input as OpenAIResponse["input"], - error: responseData.error, - parallel_tool_calls: responseData.parallel_tool_calls, - previous_response_id: responseData.previous_response_id, - temperature: responseData.temperature, - top_p: responseData.top_p, - truncation: responseData.truncation, - user: responseData.user, - }; - }; - - useEffect(() => { - const fetchResponses = async () => { - setIsLoading(true); - setError(null); - try { - const response = await client.responses.list(); - const responseListData = response as ResponseListResponse; - - const convertedResponses: OpenAIResponse[] = responseListData.data.map( - convertResponseListData, - ); - - setResponses(convertedResponses); - } catch (err) { - console.error("Error fetching responses:", err); - setError( - err instanceof Error ? err : new Error("Failed to fetch responses"), - ); - setResponses([]); - } finally { - setIsLoading(false); - } - }; - - fetchResponses(); - }, []); - +export default function Responses() { return ( - +
+

Under Construction

+
); } diff --git a/llama_stack/ui/components/layout/app-sidebar.tsx b/llama_stack/ui/components/app-sidebar.tsx similarity index 50% rename from llama_stack/ui/components/layout/app-sidebar.tsx rename to llama_stack/ui/components/app-sidebar.tsx index 1c53d6cc5..3d541856f 100644 --- a/llama_stack/ui/components/layout/app-sidebar.tsx +++ b/llama_stack/ui/components/app-sidebar.tsx @@ -1,9 +1,5 @@ -"use client"; - import { MessageSquareText, MessagesSquare, MoveUpRight } from "lucide-react"; import Link from "next/link"; -import { usePathname } from "next/navigation"; -import { cn } from "@/lib/utils"; import { Sidebar, @@ -36,8 +32,6 @@ const logItems = [ ]; export function AppSidebar() { - const pathname = usePathname(); - return ( @@ -48,31 +42,16 @@ export function AppSidebar() { Logs - {logItems.map((item) => { - const isActive = pathname.startsWith(item.url); - return ( - - - - - {item.title} - - - - ); - })} + {logItems.map((item) => ( + + + + + {item.title} + + + + ))} diff --git a/llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx b/llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx deleted file mode 100644 index 5348dbc3a..000000000 --- a/llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx +++ /dev/null @@ -1,193 +0,0 @@ -import React from "react"; -import { render, screen } from "@testing-library/react"; -import "@testing-library/jest-dom"; -import { ChatCompletionDetailView } from "./chat-completion-detail"; -import { ChatCompletion } from "@/lib/types"; - -// Initial test file setup for ChatCompletionDetailView - -describe("ChatCompletionDetailView", () => { - test("renders skeleton UI when isLoading is true", () => { - const { container } = render( - , - ); - // Use the data-slot attribute for Skeletons - const skeletons = container.querySelectorAll('[data-slot="skeleton"]'); - expect(skeletons.length).toBeGreaterThan(0); - }); - - test("renders error message when error prop is provided", () => { - render( - , - ); - expect( - screen.getByText(/Error loading details for ID err-id: Network Error/), - ).toBeInTheDocument(); - }); - - test("renders default error message when error.message is empty", () => { - render( - , - ); - // Use regex to match the error message regardless of whitespace - expect( - screen.getByText(/Error loading details for ID\s*err-id\s*:/), - ).toBeInTheDocument(); - }); - - test("renders error message when error prop is an object without message", () => { - render( - , - ); - // Use regex to match the error message regardless of whitespace - expect( - screen.getByText(/Error loading details for ID\s*err-id\s*:/), - ).toBeInTheDocument(); - }); - - test("renders not found message when completion is null and not loading/error", () => { - render( - , - ); - expect( - screen.getByText("No details found for ID: notfound-id."), - ).toBeInTheDocument(); - }); - - test("renders input, output, and properties for valid completion", () => { - const mockCompletion: ChatCompletion = { - id: "comp_123", - object: "chat.completion", - created: 1710000000, - model: "llama-test-model", - choices: [ - { - index: 0, - message: { role: "assistant", content: "Test output" }, - finish_reason: "stop", - }, - ], - input_messages: [{ role: "user", content: "Test input" }], - }; - render( - , - ); - // Input - expect(screen.getByText("Input")).toBeInTheDocument(); - expect(screen.getByText("Test input")).toBeInTheDocument(); - // Output - expect(screen.getByText("Output")).toBeInTheDocument(); - 
expect(screen.getByText("Test output")).toBeInTheDocument(); - // Properties - expect(screen.getByText("Properties")).toBeInTheDocument(); - expect(screen.getByText("Created:")).toBeInTheDocument(); - expect( - screen.getByText(new Date(1710000000 * 1000).toLocaleString()), - ).toBeInTheDocument(); - expect(screen.getByText("ID:")).toBeInTheDocument(); - expect(screen.getByText("comp_123")).toBeInTheDocument(); - expect(screen.getByText("Model:")).toBeInTheDocument(); - expect(screen.getByText("llama-test-model")).toBeInTheDocument(); - expect(screen.getByText("Finish Reason:")).toBeInTheDocument(); - expect(screen.getByText("stop")).toBeInTheDocument(); - }); - - test("renders tool call in output and properties when present", () => { - const toolCall = { - function: { name: "search", arguments: '{"query":"llama"}' }, - }; - const mockCompletion: ChatCompletion = { - id: "comp_tool", - object: "chat.completion", - created: 1710001000, - model: "llama-tool-model", - choices: [ - { - index: 0, - message: { - role: "assistant", - content: "Tool output", - tool_calls: [toolCall], - }, - finish_reason: "stop", - }, - ], - input_messages: [{ role: "user", content: "Tool input" }], - }; - render( - , - ); - // Output should include the tool call block (should be present twice: input and output) - const toolCallLabels = screen.getAllByText("Tool Call"); - expect(toolCallLabels.length).toBeGreaterThanOrEqual(1); // At least one, but could be two - // The tool call block should contain the formatted tool call string in both input and output - const toolCallBlocks = screen.getAllByText('search({"query":"llama"})'); - expect(toolCallBlocks.length).toBe(2); - // Properties should include the tool call name - expect(screen.getByText("Functions/Tools Called:")).toBeInTheDocument(); - expect(screen.getByText("search")).toBeInTheDocument(); - }); - - test("handles missing/empty fields gracefully", () => { - const mockCompletion: ChatCompletion = { - id: "comp_edge", - object: "chat.completion", - created: 1710002000, - model: "llama-edge-model", - choices: [], // No choices - input_messages: [], // No input messages - }; - render( - , - ); - // Input section should be present but empty - expect(screen.getByText("Input")).toBeInTheDocument(); - // Output section should show fallback message - expect( - screen.getByText("No message found in assistant's choice."), - ).toBeInTheDocument(); - // Properties should show N/A for finish reason - expect(screen.getByText("Finish Reason:")).toBeInTheDocument(); - expect(screen.getByText("N/A")).toBeInTheDocument(); - }); -}); diff --git a/llama_stack/ui/components/chat-completions/chat-completion-detail.tsx b/llama_stack/ui/components/chat-completions/chat-completion-detail.tsx deleted file mode 100644 index 200807864..000000000 --- a/llama_stack/ui/components/chat-completions/chat-completion-detail.tsx +++ /dev/null @@ -1,145 +0,0 @@ -"use client"; - -import { ChatMessage, ChatCompletion } from "@/lib/types"; -import { ChatMessageItem } from "@/components/chat-completions/chat-messasge-item"; -import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; -import { - DetailLoadingView, - DetailErrorView, - DetailNotFoundView, - DetailLayout, - PropertiesCard, - PropertyItem, -} from "@/components/layout/detail-layout"; - -interface ChatCompletionDetailViewProps { - completion: ChatCompletion | null; - isLoading: boolean; - error: Error | null; - id: string; -} - -export function ChatCompletionDetailView({ - completion, - isLoading, - error, - id, 
-}: ChatCompletionDetailViewProps) { - const title = "Chat Completion Details"; - - if (error) { - return ; - } - - if (isLoading) { - return ; - } - - if (!completion) { - return ; - } - - // Main content cards - const mainContent = ( - <> - - - Input - - - {completion.input_messages?.map((msg, index) => ( - - ))} - {completion.choices?.[0]?.message?.tool_calls && - Array.isArray(completion.choices[0].message.tool_calls) && - !completion.input_messages?.some( - (im) => - im.role === "assistant" && - im.tool_calls && - Array.isArray(im.tool_calls) && - im.tool_calls.length > 0, - ) - ? completion.choices[0].message.tool_calls.map( - (toolCall: any, index: number) => { - const assistantToolCallMessage: ChatMessage = { - role: "assistant", - tool_calls: [toolCall], - content: "", // Ensure content is defined, even if empty - }; - return ( - - ); - }, - ) - : null} - - - - - - Output - - - {completion.choices?.[0]?.message ? ( - - ) : ( -

- No message found in assistant's choice. -

- )} -
-
- - ); - - // Properties sidebar - const sidebar = ( - - - - - - {(() => { - const toolCalls = completion.choices?.[0]?.message?.tool_calls; - if (toolCalls && Array.isArray(toolCalls) && toolCalls.length > 0) { - return ( - -
    - {toolCalls.map((toolCall: any, index: number) => ( -
  • - - {toolCall.function?.name || "N/A"} - -
  • - ))} -
- - } - hasBorder - /> - ); - } - return null; - })()} -
- ); - - return ( - - ); -} diff --git a/llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx b/llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx deleted file mode 100644 index c8a55b100..000000000 --- a/llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx +++ /dev/null @@ -1,347 +0,0 @@ -import React from "react"; -import { render, screen, fireEvent } from "@testing-library/react"; -import "@testing-library/jest-dom"; -import { ChatCompletionsTable } from "./chat-completions-table"; -import { ChatCompletion } from "@/lib/types"; - -// Mock next/navigation -const mockPush = jest.fn(); -jest.mock("next/navigation", () => ({ - useRouter: () => ({ - push: mockPush, - }), -})); - -// Mock helper functions -jest.mock("@/lib/truncate-text"); -jest.mock("@/lib/format-message-content"); - -// Import the mocked functions to set up default or specific implementations -import { truncateText as originalTruncateText } from "@/lib/truncate-text"; -import { - extractTextFromContentPart as originalExtractTextFromContentPart, - extractDisplayableText as originalExtractDisplayableText, -} from "@/lib/format-message-content"; - -// Cast to jest.Mock for typings -const truncateText = originalTruncateText as jest.Mock; -const extractTextFromContentPart = - originalExtractTextFromContentPart as jest.Mock; -const extractDisplayableText = originalExtractDisplayableText as jest.Mock; - -describe("ChatCompletionsTable", () => { - const defaultProps = { - data: [] as ChatCompletion[], - isLoading: false, - error: null, - }; - - beforeEach(() => { - // Reset all mocks before each test - mockPush.mockClear(); - truncateText.mockClear(); - extractTextFromContentPart.mockClear(); - extractDisplayableText.mockClear(); - - // Default pass-through implementations - truncateText.mockImplementation((text: string | undefined) => text); - extractTextFromContentPart.mockImplementation((content: unknown) => - typeof content === "string" ? 
content : "extracted text", - ); - extractDisplayableText.mockImplementation( - (message: unknown) => - (message as { content?: string })?.content || "extracted output", - ); - }); - - test("renders without crashing with default props", () => { - render(); - expect(screen.getByText("No chat completions found.")).toBeInTheDocument(); - }); - - test("click on a row navigates to the correct URL", () => { - const mockCompletion: ChatCompletion = { - id: "comp_123", - object: "chat.completion", - created: Math.floor(Date.now() / 1000), - model: "llama-test-model", - choices: [ - { - index: 0, - message: { role: "assistant", content: "Test output" }, - finish_reason: "stop", - }, - ], - input_messages: [{ role: "user", content: "Test input" }], - }; - - // Set up mocks to return expected values - extractTextFromContentPart.mockReturnValue("Test input"); - extractDisplayableText.mockReturnValue("Test output"); - - render(); - - const row = screen.getByText("Test input").closest("tr"); - if (row) { - fireEvent.click(row); - expect(mockPush).toHaveBeenCalledWith("/logs/chat-completions/comp_123"); - } else { - throw new Error('Row with "Test input" not found for router mock test.'); - } - }); - - describe("Loading State", () => { - test("renders skeleton UI when isLoading is true", () => { - const { container } = render( - , - ); - - // Check for skeleton in the table caption - const tableCaption = container.querySelector("caption"); - expect(tableCaption).toBeInTheDocument(); - if (tableCaption) { - const captionSkeleton = tableCaption.querySelector( - '[data-slot="skeleton"]', - ); - expect(captionSkeleton).toBeInTheDocument(); - } - - // Check for skeletons in the table body cells - const tableBody = container.querySelector("tbody"); - expect(tableBody).toBeInTheDocument(); - if (tableBody) { - const bodySkeletons = tableBody.querySelectorAll( - '[data-slot="skeleton"]', - ); - expect(bodySkeletons.length).toBeGreaterThan(0); - } - }); - }); - - describe("Error State", () => { - test("renders error message when error prop is provided", () => { - const errorMessage = "Network Error"; - render( - , - ); - expect( - screen.getByText(`Error fetching data: ${errorMessage}`), - ).toBeInTheDocument(); - }); - - test("renders default error message when error.message is not available", () => { - render( - , - ); - expect( - screen.getByText("Error fetching data: An unknown error occurred"), - ).toBeInTheDocument(); - }); - - test("renders default error message when error prop is an object without message", () => { - render(); - expect( - screen.getByText("Error fetching data: An unknown error occurred"), - ).toBeInTheDocument(); - }); - }); - - describe("Empty State", () => { - test('renders "No chat completions found." 
and no table when data array is empty', () => { - render(); - expect( - screen.getByText("No chat completions found."), - ).toBeInTheDocument(); - - // Ensure that the table structure is NOT rendered in the empty state - const table = screen.queryByRole("table"); - expect(table).not.toBeInTheDocument(); - }); - }); - - describe("Data Rendering", () => { - test("renders table caption, headers, and completion data correctly", () => { - const mockCompletions = [ - { - id: "comp_1", - object: "chat.completion", - created: 1710000000, - model: "llama-test-model", - choices: [ - { - index: 0, - message: { role: "assistant", content: "Test output" }, - finish_reason: "stop", - }, - ], - input_messages: [{ role: "user", content: "Test input" }], - }, - { - id: "comp_2", - object: "chat.completion", - created: 1710001000, - model: "llama-another-model", - choices: [ - { - index: 0, - message: { role: "assistant", content: "Another output" }, - finish_reason: "stop", - }, - ], - input_messages: [{ role: "user", content: "Another input" }], - }, - ]; - - // Set up mocks to return expected values - extractTextFromContentPart.mockImplementation((content: unknown) => { - if (content === "Test input") return "Test input"; - if (content === "Another input") return "Another input"; - return "extracted text"; - }); - extractDisplayableText.mockImplementation((message: unknown) => { - const msg = message as { content?: string }; - if (msg?.content === "Test output") return "Test output"; - if (msg?.content === "Another output") return "Another output"; - return "extracted output"; - }); - - render( - , - ); - - // Table caption - expect( - screen.getByText("A list of your recent chat completions."), - ).toBeInTheDocument(); - - // Table headers - expect(screen.getByText("Input")).toBeInTheDocument(); - expect(screen.getByText("Output")).toBeInTheDocument(); - expect(screen.getByText("Model")).toBeInTheDocument(); - expect(screen.getByText("Created")).toBeInTheDocument(); - - // Data rows - expect(screen.getByText("Test input")).toBeInTheDocument(); - expect(screen.getByText("Test output")).toBeInTheDocument(); - expect(screen.getByText("llama-test-model")).toBeInTheDocument(); - expect( - screen.getByText(new Date(1710000000 * 1000).toLocaleString()), - ).toBeInTheDocument(); - - expect(screen.getByText("Another input")).toBeInTheDocument(); - expect(screen.getByText("Another output")).toBeInTheDocument(); - expect(screen.getByText("llama-another-model")).toBeInTheDocument(); - expect( - screen.getByText(new Date(1710001000 * 1000).toLocaleString()), - ).toBeInTheDocument(); - }); - }); - - describe("Text Truncation and Content Extraction", () => { - test("truncates long input and output text", () => { - // Specific mock implementation for this test - truncateText.mockImplementation( - (text: string | undefined, maxLength?: number) => { - const defaultTestMaxLength = 10; - const effectiveMaxLength = maxLength ?? defaultTestMaxLength; - return typeof text === "string" && text.length > effectiveMaxLength - ? text.slice(0, effectiveMaxLength) + "..." 
- : text; - }, - ); - - const longInput = - "This is a very long input message that should be truncated."; - const longOutput = - "This is a very long output message that should also be truncated."; - - extractTextFromContentPart.mockReturnValue(longInput); - extractDisplayableText.mockReturnValue(longOutput); - - const mockCompletions = [ - { - id: "comp_trunc", - object: "chat.completion", - created: 1710002000, - model: "llama-trunc-model", - choices: [ - { - index: 0, - message: { role: "assistant", content: longOutput }, - finish_reason: "stop", - }, - ], - input_messages: [{ role: "user", content: longInput }], - }, - ]; - - render( - , - ); - - // The truncated text should be present for both input and output - const truncatedTexts = screen.getAllByText( - longInput.slice(0, 10) + "...", - ); - expect(truncatedTexts.length).toBe(2); // one for input, one for output - truncatedTexts.forEach((textElement) => - expect(textElement).toBeInTheDocument(), - ); - }); - - test("uses content extraction functions correctly", () => { - const mockCompletion = { - id: "comp_extract", - object: "chat.completion", - created: 1710003000, - model: "llama-extract-model", - choices: [ - { - index: 0, - message: { role: "assistant", content: "Extracted output" }, - finish_reason: "stop", - }, - ], - input_messages: [{ role: "user", content: "Extracted input" }], - }; - - extractTextFromContentPart.mockReturnValue("Extracted input"); - extractDisplayableText.mockReturnValue("Extracted output"); - - render( - , - ); - - // Verify the extraction functions were called - expect(extractTextFromContentPart).toHaveBeenCalledWith( - "Extracted input", - ); - expect(extractDisplayableText).toHaveBeenCalledWith({ - role: "assistant", - content: "Extracted output", - }); - - // Verify the extracted content is displayed - expect(screen.getByText("Extracted input")).toBeInTheDocument(); - expect(screen.getByText("Extracted output")).toBeInTheDocument(); - }); - }); -}); diff --git a/llama_stack/ui/components/chat-completions/chat-completions-table.tsx b/llama_stack/ui/components/chat-completions/chat-completions-table.tsx deleted file mode 100644 index 5f1d2f03d..000000000 --- a/llama_stack/ui/components/chat-completions/chat-completions-table.tsx +++ /dev/null @@ -1,43 +0,0 @@ -"use client"; - -import { ChatCompletion } from "@/lib/types"; -import { LogsTable, LogTableRow } from "@/components/logs/logs-table"; -import { - extractTextFromContentPart, - extractDisplayableText, -} from "@/lib/format-message-content"; - -interface ChatCompletionsTableProps { - data: ChatCompletion[]; - isLoading: boolean; - error: Error | null; -} - -function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow { - return { - id: completion.id, - input: extractTextFromContentPart(completion.input_messages?.[0]?.content), - output: extractDisplayableText(completion.choices?.[0]?.message), - model: completion.model, - createdTime: new Date(completion.created * 1000).toLocaleString(), - detailPath: `/logs/chat-completions/${completion.id}`, - }; -} - -export function ChatCompletionsTable({ - data, - isLoading, - error, -}: ChatCompletionsTableProps) { - const formattedData = data.map(formatChatCompletionToRow); - - return ( - - ); -} diff --git a/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx b/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx deleted file mode 100644 index 2e8593bfb..000000000 --- a/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx +++ /dev/null @@ -1,76 
+0,0 @@ -"use client"; - -import { ChatMessage } from "@/lib/types"; -import React from "react"; -import { formatToolCallToString } from "@/lib/format-tool-call"; -import { extractTextFromContentPart } from "@/lib/format-message-content"; -import { - MessageBlock, - ToolCallBlock, -} from "@/components/ui/message-components"; - -interface ChatMessageItemProps { - message: ChatMessage; -} -export function ChatMessageItem({ message }: ChatMessageItemProps) { - switch (message.role) { - case "system": - return ( - - ); - case "user": - return ( - - ); - - case "assistant": - if ( - message.tool_calls && - Array.isArray(message.tool_calls) && - message.tool_calls.length > 0 - ) { - return ( - <> - {message.tool_calls.map((toolCall: any, index: number) => { - const formattedToolCall = formatToolCallToString(toolCall); - const toolCallContent = ( - - {formattedToolCall || "Error: Could not display tool call"} - - ); - return ( - - ); - })} - - ); - } else { - return ( - - ); - } - case "tool": - const toolOutputContent = ( - - {extractTextFromContentPart(message.content)} - - ); - return ( - - ); - } - return null; -} diff --git a/llama_stack/ui/components/layout/detail-layout.tsx b/llama_stack/ui/components/layout/detail-layout.tsx deleted file mode 100644 index 58b912703..000000000 --- a/llama_stack/ui/components/layout/detail-layout.tsx +++ /dev/null @@ -1,141 +0,0 @@ -import React from "react"; -import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; -import { Skeleton } from "@/components/ui/skeleton"; - -export function DetailLoadingView({ title }: { title: string }) { - return ( - <> - {/* Title Skeleton */} -
-
- {[...Array(2)].map((_, i) => ( - - - - - - - - - - - - - ))} -
-
-
- {" "} - {/* Properties Title Skeleton */} - {[...Array(5)].map((_, i) => ( -
- - -
- ))} -
-
-
- - ); -} - -export function DetailErrorView({ - title, - id, - error, -}: { - title: string; - id: string; - error: Error; -}) { - return ( - <> -

{title}

-

- Error loading details for ID {id}: {error.message} -

- - ); -} - -export function DetailNotFoundView({ - title, - id, -}: { - title: string; - id: string; -}) { - return ( - <> -

{title}

-

No details found for ID: {id}.

- - ); -} - -export interface PropertyItemProps { - label: string; - value: React.ReactNode; - className?: string; - hasBorder?: boolean; -} - -export function PropertyItem({ - label, - value, - className = "", - hasBorder = false, -}: PropertyItemProps) { - return ( -
- {label}:{" "} - {typeof value === "string" || typeof value === "number" ? ( - {value} - ) : ( - value - )} -
- ); -} - -export interface PropertiesCardProps { - children: React.ReactNode; -} - -export function PropertiesCard({ children }: PropertiesCardProps) { - return ( - - - Properties - - -
      {children}
    -
    -
    - ); -} - -export interface DetailLayoutProps { - title: string; - mainContent: React.ReactNode; - sidebar: React.ReactNode; -} - -export function DetailLayout({ - title, - mainContent, - sidebar, -}: DetailLayoutProps) { - return ( - <> -

    {title}

    -
    -
    {mainContent}
    -
    {sidebar}
    -
    - - ); -} diff --git a/llama_stack/ui/components/layout/logs-layout.tsx b/llama_stack/ui/components/layout/logs-layout.tsx deleted file mode 100644 index 468ad6e9a..000000000 --- a/llama_stack/ui/components/layout/logs-layout.tsx +++ /dev/null @@ -1,49 +0,0 @@ -"use client"; - -import React from "react"; -import { usePathname, useParams } from "next/navigation"; -import { - PageBreadcrumb, - BreadcrumbSegment, -} from "@/components/layout/page-breadcrumb"; -import { truncateText } from "@/lib/truncate-text"; - -interface LogsLayoutProps { - children: React.ReactNode; - sectionLabel: string; - basePath: string; -} - -export default function LogsLayout({ - children, - sectionLabel, - basePath, -}: LogsLayoutProps) { - const pathname = usePathname(); - const params = useParams(); - - let segments: BreadcrumbSegment[] = []; - - if (pathname === basePath) { - segments = [{ label: sectionLabel }]; - } - - const idParam = params?.id; - if (idParam && typeof idParam === "string") { - segments = [ - { label: sectionLabel, href: basePath }, - { label: `Details (${truncateText(idParam, 20)})` }, - ]; - } - - return ( -
    - <> - {segments.length > 0 && ( - - )} - {children} - -
    - ); -} diff --git a/llama_stack/ui/components/layout/page-breadcrumb.tsx b/llama_stack/ui/components/layout/page-breadcrumb.tsx deleted file mode 100644 index fdb561d68..000000000 --- a/llama_stack/ui/components/layout/page-breadcrumb.tsx +++ /dev/null @@ -1,49 +0,0 @@ -"use client"; - -import Link from "next/link"; -import React from "react"; -import { - Breadcrumb, - BreadcrumbItem, - BreadcrumbLink, - BreadcrumbList, - BreadcrumbPage, - BreadcrumbSeparator, -} from "@/components/ui/breadcrumb"; - -export interface BreadcrumbSegment { - label: string; - href?: string; -} - -interface PageBreadcrumbProps { - segments: BreadcrumbSegment[]; - className?: string; -} - -export function PageBreadcrumb({ segments, className }: PageBreadcrumbProps) { - if (!segments || segments.length === 0) { - return null; - } - - return ( - - - {segments.map((segment, index) => ( - - - {segment.href ? ( - - {segment.label} - - ) : ( - {segment.label} - )} - - {index < segments.length - 1 && } - - ))} - - - ); -} diff --git a/llama_stack/ui/components/logs/logs-table.test.tsx b/llama_stack/ui/components/logs/logs-table.test.tsx deleted file mode 100644 index 88263b2fc..000000000 --- a/llama_stack/ui/components/logs/logs-table.test.tsx +++ /dev/null @@ -1,350 +0,0 @@ -import React from "react"; -import { render, screen, fireEvent } from "@testing-library/react"; -import "@testing-library/jest-dom"; -import { LogsTable, LogTableRow } from "./logs-table"; - -// Mock next/navigation -const mockPush = jest.fn(); -jest.mock("next/navigation", () => ({ - useRouter: () => ({ - push: mockPush, - }), -})); - -// Mock helper functions -jest.mock("@/lib/truncate-text"); - -// Import the mocked functions -import { truncateText as originalTruncateText } from "@/lib/truncate-text"; - -// Cast to jest.Mock for typings -const truncateText = originalTruncateText as jest.Mock; - -describe("LogsTable", () => { - const defaultProps = { - data: [] as LogTableRow[], - isLoading: false, - error: null, - caption: "Test table caption", - emptyMessage: "No data found", - }; - - beforeEach(() => { - // Reset all mocks before each test - mockPush.mockClear(); - truncateText.mockClear(); - - // Default pass-through implementation - truncateText.mockImplementation((text: string | undefined) => text); - }); - - test("renders without crashing with default props", () => { - render(); - expect(screen.getByText("No data found")).toBeInTheDocument(); - }); - - test("click on a row navigates to the correct URL", () => { - const mockData: LogTableRow[] = [ - { - id: "row_123", - input: "Test input", - output: "Test output", - model: "test-model", - createdTime: "2024-01-01 12:00:00", - detailPath: "/test/path/row_123", - }, - ]; - - render(); - - const row = screen.getByText("Test input").closest("tr"); - if (row) { - fireEvent.click(row); - expect(mockPush).toHaveBeenCalledWith("/test/path/row_123"); - } else { - throw new Error('Row with "Test input" not found for router mock test.'); - } - }); - - describe("Loading State", () => { - test("renders skeleton UI when isLoading is true", () => { - const { container } = render( - , - ); - - // Check for skeleton in the table caption - const tableCaption = container.querySelector("caption"); - expect(tableCaption).toBeInTheDocument(); - if (tableCaption) { - const captionSkeleton = tableCaption.querySelector( - '[data-slot="skeleton"]', - ); - expect(captionSkeleton).toBeInTheDocument(); - } - - // Check for skeletons in the table body cells - const tableBody = container.querySelector("tbody"); - 
expect(tableBody).toBeInTheDocument(); - if (tableBody) { - const bodySkeletons = tableBody.querySelectorAll( - '[data-slot="skeleton"]', - ); - expect(bodySkeletons.length).toBeGreaterThan(0); - } - - // Check that table headers are still rendered - expect(screen.getByText("Input")).toBeInTheDocument(); - expect(screen.getByText("Output")).toBeInTheDocument(); - expect(screen.getByText("Model")).toBeInTheDocument(); - expect(screen.getByText("Created")).toBeInTheDocument(); - }); - - test("renders correct number of skeleton rows", () => { - const { container } = render( - , - ); - - const skeletonRows = container.querySelectorAll("tbody tr"); - expect(skeletonRows.length).toBe(3); // Should render 3 skeleton rows - }); - }); - - describe("Error State", () => { - test("renders error message when error prop is provided", () => { - const errorMessage = "Network Error"; - render( - , - ); - expect( - screen.getByText(`Error fetching data: ${errorMessage}`), - ).toBeInTheDocument(); - }); - - test("renders default error message when error.message is not available", () => { - render( - , - ); - expect( - screen.getByText("Error fetching data: An unknown error occurred"), - ).toBeInTheDocument(); - }); - - test("renders default error message when error prop is an object without message", () => { - render(); - expect( - screen.getByText("Error fetching data: An unknown error occurred"), - ).toBeInTheDocument(); - }); - - test("does not render table when in error state", () => { - render( - , - ); - const table = screen.queryByRole("table"); - expect(table).not.toBeInTheDocument(); - }); - }); - - describe("Empty State", () => { - test("renders custom empty message when data array is empty", () => { - render( - , - ); - expect(screen.getByText("Custom empty message")).toBeInTheDocument(); - - // Ensure that the table structure is NOT rendered in the empty state - const table = screen.queryByRole("table"); - expect(table).not.toBeInTheDocument(); - }); - }); - - describe("Data Rendering", () => { - test("renders table caption, headers, and data correctly", () => { - const mockData: LogTableRow[] = [ - { - id: "row_1", - input: "First input", - output: "First output", - model: "model-1", - createdTime: "2024-01-01 12:00:00", - detailPath: "/path/1", - }, - { - id: "row_2", - input: "Second input", - output: "Second output", - model: "model-2", - createdTime: "2024-01-02 13:00:00", - detailPath: "/path/2", - }, - ]; - - render( - , - ); - - // Table caption - expect(screen.getByText("Custom table caption")).toBeInTheDocument(); - - // Table headers - expect(screen.getByText("Input")).toBeInTheDocument(); - expect(screen.getByText("Output")).toBeInTheDocument(); - expect(screen.getByText("Model")).toBeInTheDocument(); - expect(screen.getByText("Created")).toBeInTheDocument(); - - // Data rows - expect(screen.getByText("First input")).toBeInTheDocument(); - expect(screen.getByText("First output")).toBeInTheDocument(); - expect(screen.getByText("model-1")).toBeInTheDocument(); - expect(screen.getByText("2024-01-01 12:00:00")).toBeInTheDocument(); - - expect(screen.getByText("Second input")).toBeInTheDocument(); - expect(screen.getByText("Second output")).toBeInTheDocument(); - expect(screen.getByText("model-2")).toBeInTheDocument(); - expect(screen.getByText("2024-01-02 13:00:00")).toBeInTheDocument(); - }); - - test("applies correct CSS classes to table rows", () => { - const mockData: LogTableRow[] = [ - { - id: "row_1", - input: "Test input", - output: "Test output", - model: "test-model", - 
createdTime: "2024-01-01 12:00:00", - detailPath: "/test/path", - }, - ]; - - render(); - - const row = screen.getByText("Test input").closest("tr"); - expect(row).toHaveClass("cursor-pointer"); - expect(row).toHaveClass("hover:bg-muted/50"); - }); - - test("applies correct alignment to Created column", () => { - const mockData: LogTableRow[] = [ - { - id: "row_1", - input: "Test input", - output: "Test output", - model: "test-model", - createdTime: "2024-01-01 12:00:00", - detailPath: "/test/path", - }, - ]; - - render(); - - const createdCell = screen.getByText("2024-01-01 12:00:00").closest("td"); - expect(createdCell).toHaveClass("text-right"); - }); - }); - - describe("Text Truncation", () => { - test("truncates input and output text using truncateText function", () => { - // Mock truncateText to return truncated versions - truncateText.mockImplementation((text: string | undefined) => { - if (typeof text === "string" && text.length > 10) { - return text.slice(0, 10) + "..."; - } - return text; - }); - - const longInput = - "This is a very long input text that should be truncated"; - const longOutput = - "This is a very long output text that should be truncated"; - - const mockData: LogTableRow[] = [ - { - id: "row_1", - input: longInput, - output: longOutput, - model: "test-model", - createdTime: "2024-01-01 12:00:00", - detailPath: "/test/path", - }, - ]; - - render(); - - // Verify truncateText was called - expect(truncateText).toHaveBeenCalledWith(longInput); - expect(truncateText).toHaveBeenCalledWith(longOutput); - - // Verify truncated text is displayed - const truncatedTexts = screen.getAllByText("This is a ..."); - expect(truncatedTexts).toHaveLength(2); // one for input, one for output - truncatedTexts.forEach((textElement) => - expect(textElement).toBeInTheDocument(), - ); - }); - - test("does not truncate model names", () => { - const mockData: LogTableRow[] = [ - { - id: "row_1", - input: "Test input", - output: "Test output", - model: "very-long-model-name-that-should-not-be-truncated", - createdTime: "2024-01-01 12:00:00", - detailPath: "/test/path", - }, - ]; - - render(); - - // Model name should not be passed to truncateText - expect(truncateText).not.toHaveBeenCalledWith( - "very-long-model-name-that-should-not-be-truncated", - ); - - // Full model name should be displayed - expect( - screen.getByText("very-long-model-name-that-should-not-be-truncated"), - ).toBeInTheDocument(); - }); - }); - - describe("Accessibility", () => { - test("table has proper role and structure", () => { - const mockData: LogTableRow[] = [ - { - id: "row_1", - input: "Test input", - output: "Test output", - model: "test-model", - createdTime: "2024-01-01 12:00:00", - detailPath: "/test/path", - }, - ]; - - render(); - - const table = screen.getByRole("table"); - expect(table).toBeInTheDocument(); - - const columnHeaders = screen.getAllByRole("columnheader"); - expect(columnHeaders).toHaveLength(4); - - const rows = screen.getAllByRole("row"); - expect(rows).toHaveLength(2); // 1 header row + 1 data row - }); - }); -}); diff --git a/llama_stack/ui/components/logs/logs-table.tsx b/llama_stack/ui/components/logs/logs-table.tsx deleted file mode 100644 index 33afea61b..000000000 --- a/llama_stack/ui/components/logs/logs-table.tsx +++ /dev/null @@ -1,113 +0,0 @@ -"use client"; - -import { useRouter } from "next/navigation"; -import { truncateText } from "@/lib/truncate-text"; -import { - Table, - TableBody, - TableCaption, - TableCell, - TableHead, - TableHeader, - TableRow, -} from 
"@/components/ui/table"; -import { Skeleton } from "@/components/ui/skeleton"; - -// Generic table row data interface -export interface LogTableRow { - id: string; - input: string; - output: string; - model: string; - createdTime: string; - detailPath: string; -} - -interface LogsTableProps { - data: LogTableRow[]; - isLoading: boolean; - error: Error | null; - caption: string; - emptyMessage: string; -} - -export function LogsTable({ - data, - isLoading, - error, - caption, - emptyMessage, -}: LogsTableProps) { - const router = useRouter(); - - const tableHeader = ( - - - Input - Output - Model - Created - - - ); - - if (isLoading) { - return ( - - - - - {tableHeader} - - {[...Array(3)].map((_, i) => ( - - - - - - - - - - - - - - - ))} - -
    - ); - } - - if (error) { - return ( -

    Error fetching data: {error.message || "An unknown error occurred"}

    - ); - } - - if (data.length === 0) { - return

    {emptyMessage}

    ; - } - - return ( - - {caption} - {tableHeader} - - {data.map((row) => ( - router.push(row.detailPath)} - className="cursor-pointer hover:bg-muted/50" - > - {truncateText(row.input)} - {truncateText(row.output)} - {row.model} - {row.createdTime} - - ))} - -
    - ); -} diff --git a/llama_stack/ui/components/responses/grouping/grouped-items-display.tsx b/llama_stack/ui/components/responses/grouping/grouped-items-display.tsx deleted file mode 100644 index 6ddc0eacc..000000000 --- a/llama_stack/ui/components/responses/grouping/grouped-items-display.tsx +++ /dev/null @@ -1,56 +0,0 @@ -import { useFunctionCallGrouping } from "../hooks/function-call-grouping"; -import { ItemRenderer } from "../items/item-renderer"; -import { GroupedFunctionCallItemComponent } from "../items/grouped-function-call-item"; -import { - isFunctionCallItem, - isFunctionCallOutputItem, - AnyResponseItem, -} from "../utils/item-types"; - -interface GroupedItemsDisplayProps { - items: AnyResponseItem[]; - keyPrefix: string; - defaultRole?: string; -} - -export function GroupedItemsDisplay({ - items, - keyPrefix, - defaultRole = "unknown", -}: GroupedItemsDisplayProps) { - const groupedItems = useFunctionCallGrouping(items); - - return ( - <> - {groupedItems.map((groupedItem) => { - // If this is a function call with an output, render the grouped component - if ( - groupedItem.outputItem && - isFunctionCallItem(groupedItem.item) && - isFunctionCallOutputItem(groupedItem.outputItem) - ) { - return ( - - ); - } - - // Otherwise, render the individual item - return ( - - ); - })} - - ); -} diff --git a/llama_stack/ui/components/responses/hooks/function-call-grouping.ts b/llama_stack/ui/components/responses/hooks/function-call-grouping.ts deleted file mode 100644 index 2994354d5..000000000 --- a/llama_stack/ui/components/responses/hooks/function-call-grouping.ts +++ /dev/null @@ -1,92 +0,0 @@ -import { useMemo } from "react"; -import { - isFunctionCallOutputItem, - AnyResponseItem, - FunctionCallOutputItem, -} from "../utils/item-types"; - -export interface GroupedItem { - item: AnyResponseItem; - index: number; - outputItem?: AnyResponseItem; - outputIndex?: number; -} - -/** - * Hook to group function calls with their corresponding outputs - * @param items Array of items to group - * @returns Array of grouped items with their outputs - */ -export function useFunctionCallGrouping( - items: AnyResponseItem[], -): GroupedItem[] { - return useMemo(() => { - const groupedItems: GroupedItem[] = []; - const processedIndices = new Set(); - - // Build a map of call_id to indices for function_call_output items - const callIdToIndices = new Map(); - - for (let i = 0; i < items.length; i++) { - const item = items[i]; - if (isFunctionCallOutputItem(item)) { - if (!callIdToIndices.has(item.call_id)) { - callIdToIndices.set(item.call_id, []); - } - callIdToIndices.get(item.call_id)!.push(i); - } - } - - // Process items and group function calls with their outputs - for (let i = 0; i < items.length; i++) { - if (processedIndices.has(i)) { - continue; - } - - const currentItem = items[i]; - - if ( - currentItem.type === "function_call" && - "name" in currentItem && - "call_id" in currentItem - ) { - const functionCallId = currentItem.call_id as string; - let outputIndex = -1; - let outputItem: FunctionCallOutputItem | null = null; - - const relatedIndices = callIdToIndices.get(functionCallId) || []; - for (const idx of relatedIndices) { - const potentialOutput = items[idx]; - outputIndex = idx; - outputItem = potentialOutput as FunctionCallOutputItem; - break; - } - - if (outputItem && outputIndex !== -1) { - // Group function call with its function_call_output - groupedItems.push({ - item: currentItem, - index: i, - outputItem, - outputIndex, - }); - - // Mark both items as processed - 
processedIndices.add(i); - processedIndices.add(outputIndex); - - // Matching function call and output found, skip to next item - continue; - } - } - // render normally - groupedItems.push({ - item: currentItem, - index: i, - }); - processedIndices.add(i); - } - - return groupedItems; - }, [items]); -} diff --git a/llama_stack/ui/components/responses/items/function-call-item.tsx b/llama_stack/ui/components/responses/items/function-call-item.tsx deleted file mode 100644 index beca935f0..000000000 --- a/llama_stack/ui/components/responses/items/function-call-item.tsx +++ /dev/null @@ -1,29 +0,0 @@ -import { - MessageBlock, - ToolCallBlock, -} from "@/components/ui/message-components"; -import { FunctionCallItem } from "../utils/item-types"; - -interface FunctionCallItemProps { - item: FunctionCallItem; - index: number; - keyPrefix: string; -} - -export function FunctionCallItemComponent({ - item, - index, - keyPrefix, -}: FunctionCallItemProps) { - const name = item.name || "unknown"; - const args = item.arguments || "{}"; - const formattedFunctionCall = `${name}(${args})`; - - return ( - {formattedFunctionCall}} - /> - ); -} diff --git a/llama_stack/ui/components/responses/items/generic-item.tsx b/llama_stack/ui/components/responses/items/generic-item.tsx deleted file mode 100644 index 6b6f56603..000000000 --- a/llama_stack/ui/components/responses/items/generic-item.tsx +++ /dev/null @@ -1,37 +0,0 @@ -import { - MessageBlock, - ToolCallBlock, -} from "@/components/ui/message-components"; -import { BaseItem } from "../utils/item-types"; - -interface GenericItemProps { - item: BaseItem; - index: number; - keyPrefix: string; -} - -export function GenericItemComponent({ - item, - index, - keyPrefix, -}: GenericItemProps) { - // Handle other types like function calls, tool outputs, etc. - const itemData = item as Record; - - const content = itemData.content - ? typeof itemData.content === "string" - ? itemData.content - : JSON.stringify(itemData.content, null, 2) - : JSON.stringify(itemData, null, 2); - - const label = keyPrefix === "input" ? "Input" : "Output"; - - return ( - {content}} - /> - ); -} diff --git a/llama_stack/ui/components/responses/items/grouped-function-call-item.tsx b/llama_stack/ui/components/responses/items/grouped-function-call-item.tsx deleted file mode 100644 index ded0ced71..000000000 --- a/llama_stack/ui/components/responses/items/grouped-function-call-item.tsx +++ /dev/null @@ -1,54 +0,0 @@ -import { - MessageBlock, - ToolCallBlock, -} from "@/components/ui/message-components"; -import { FunctionCallItem, FunctionCallOutputItem } from "../utils/item-types"; - -interface GroupedFunctionCallItemProps { - functionCall: FunctionCallItem; - output: FunctionCallOutputItem; - index: number; - keyPrefix: string; -} - -export function GroupedFunctionCallItemComponent({ - functionCall, - output, - index, - keyPrefix, -}: GroupedFunctionCallItemProps) { - const name = functionCall.name || "unknown"; - const args = functionCall.arguments || "{}"; - - // Extract the output content from function_call_output - let outputContent = ""; - if (output.output) { - outputContent = - typeof output.output === "string" - ? output.output - : JSON.stringify(output.output); - } else { - outputContent = JSON.stringify(output, null, 2); - } - - const functionCallContent = ( -
    -
    - Arguments - {`${name}(${args})`} -
    -
    - Output - {outputContent} -
    -
    - ); - - return ( - - ); -} diff --git a/llama_stack/ui/components/responses/items/index.ts b/llama_stack/ui/components/responses/items/index.ts deleted file mode 100644 index d7bcc2ea4..000000000 --- a/llama_stack/ui/components/responses/items/index.ts +++ /dev/null @@ -1,6 +0,0 @@ -export { MessageItemComponent } from "./message-item"; -export { FunctionCallItemComponent } from "./function-call-item"; -export { WebSearchItemComponent } from "./web-search-item"; -export { GenericItemComponent } from "./generic-item"; -export { GroupedFunctionCallItemComponent } from "./grouped-function-call-item"; -export { ItemRenderer } from "./item-renderer"; diff --git a/llama_stack/ui/components/responses/items/item-renderer.tsx b/llama_stack/ui/components/responses/items/item-renderer.tsx deleted file mode 100644 index 8f65d50c4..000000000 --- a/llama_stack/ui/components/responses/items/item-renderer.tsx +++ /dev/null @@ -1,60 +0,0 @@ -import { - isMessageItem, - isFunctionCallItem, - isWebSearchCallItem, - AnyResponseItem, -} from "../utils/item-types"; -import { MessageItemComponent } from "./message-item"; -import { FunctionCallItemComponent } from "./function-call-item"; -import { WebSearchItemComponent } from "./web-search-item"; -import { GenericItemComponent } from "./generic-item"; - -interface ItemRendererProps { - item: AnyResponseItem; - index: number; - keyPrefix: string; - defaultRole?: string; -} - -export function ItemRenderer({ - item, - index, - keyPrefix, - defaultRole = "unknown", -}: ItemRendererProps) { - if (isMessageItem(item)) { - return ( - - ); - } - - if (isFunctionCallItem(item)) { - return ( - - ); - } - - if (isWebSearchCallItem(item)) { - return ( - - ); - } - - // Fallback to generic item for unknown types - return ( - - ); -} diff --git a/llama_stack/ui/components/responses/items/message-item.tsx b/llama_stack/ui/components/responses/items/message-item.tsx deleted file mode 100644 index 532fddfaa..000000000 --- a/llama_stack/ui/components/responses/items/message-item.tsx +++ /dev/null @@ -1,41 +0,0 @@ -import { MessageBlock } from "@/components/ui/message-components"; -import { MessageItem } from "../utils/item-types"; - -interface MessageItemProps { - item: MessageItem; - index: number; - keyPrefix: string; - defaultRole?: string; -} - -export function MessageItemComponent({ - item, - index, - keyPrefix, - defaultRole = "unknown", -}: MessageItemProps) { - let content = ""; - - if (typeof item.content === "string") { - content = item.content; - } else if (Array.isArray(item.content)) { - content = item.content - .map((c) => { - return c.type === "input_text" || c.type === "output_text" - ? 
c.text - : JSON.stringify(c); - }) - .join(" "); - } - - const role = item.role || defaultRole; - const label = role.charAt(0).toUpperCase() + role.slice(1); - - return ( - - ); -} diff --git a/llama_stack/ui/components/responses/items/web-search-item.tsx b/llama_stack/ui/components/responses/items/web-search-item.tsx deleted file mode 100644 index aaa5741ce..000000000 --- a/llama_stack/ui/components/responses/items/web-search-item.tsx +++ /dev/null @@ -1,28 +0,0 @@ -import { - MessageBlock, - ToolCallBlock, -} from "@/components/ui/message-components"; -import { WebSearchCallItem } from "../utils/item-types"; - -interface WebSearchItemProps { - item: WebSearchCallItem; - index: number; - keyPrefix: string; -} - -export function WebSearchItemComponent({ - item, - index, - keyPrefix, -}: WebSearchItemProps) { - const formattedWebSearch = `web_search_call(status: ${item.status})`; - - return ( - {formattedWebSearch}} - /> - ); -} diff --git a/llama_stack/ui/components/responses/responses-detail.test.tsx b/llama_stack/ui/components/responses/responses-detail.test.tsx deleted file mode 100644 index f426dc059..000000000 --- a/llama_stack/ui/components/responses/responses-detail.test.tsx +++ /dev/null @@ -1,777 +0,0 @@ -import React from "react"; -import { render, screen } from "@testing-library/react"; -import "@testing-library/jest-dom"; -import { ResponseDetailView } from "./responses-detail"; -import { OpenAIResponse, InputItemListResponse } from "@/lib/types"; - -describe("ResponseDetailView", () => { - const defaultProps = { - response: null, - inputItems: null, - isLoading: false, - isLoadingInputItems: false, - error: null, - inputItemsError: null, - id: "test_id", - }; - - describe("Loading State", () => { - test("renders loading skeleton when isLoading is true", () => { - const { container } = render( - , - ); - - // Check for skeleton elements - const skeletons = container.querySelectorAll('[data-slot="skeleton"]'); - expect(skeletons.length).toBeGreaterThan(0); - - // The title is replaced by a skeleton when loading, so we shouldn't expect the text - }); - }); - - describe("Error State", () => { - test("renders error message when error prop is provided", () => { - const errorMessage = "Network Error"; - render( - , - ); - - expect(screen.getByText("Responses Details")).toBeInTheDocument(); - // The error message is split across elements, so we check for parts - expect( - screen.getByText(/Error loading details for ID/), - ).toBeInTheDocument(); - expect(screen.getByText(/test_id/)).toBeInTheDocument(); - expect(screen.getByText(/Network Error/)).toBeInTheDocument(); - }); - - test("renders default error message when error.message is not available", () => { - render( - , - ); - - expect( - screen.getByText(/Error loading details for ID/), - ).toBeInTheDocument(); - expect(screen.getByText(/test_id/)).toBeInTheDocument(); - }); - }); - - describe("Not Found State", () => { - test("renders not found message when response is null and not loading/error", () => { - render(); - - expect(screen.getByText("Responses Details")).toBeInTheDocument(); - // The message is split across elements - expect(screen.getByText(/No details found for ID:/)).toBeInTheDocument(); - expect(screen.getByText(/test_id/)).toBeInTheDocument(); - }); - }); - - describe("Response Data Rendering", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "llama-test-model", - status: "completed", - output: [ - { - type: "message", - role: "assistant", - 
content: "Test response output", - }, - ], - input: [ - { - type: "message", - role: "user", - content: "Test input message", - }, - ], - temperature: 0.7, - top_p: 0.9, - parallel_tool_calls: true, - previous_response_id: "prev_resp_456", - }; - - test("renders response data with input and output sections", () => { - render(); - - // Check main sections - expect(screen.getByText("Responses Details")).toBeInTheDocument(); - expect(screen.getByText("Input")).toBeInTheDocument(); - expect(screen.getByText("Output")).toBeInTheDocument(); - - // Check input content - expect(screen.getByText("Test input message")).toBeInTheDocument(); - expect(screen.getByText("User")).toBeInTheDocument(); - - // Check output content - expect(screen.getByText("Test response output")).toBeInTheDocument(); - expect(screen.getByText("Assistant")).toBeInTheDocument(); - }); - - test("renders properties sidebar with all response metadata", () => { - render(); - - // Check properties - use regex to handle text split across elements - expect(screen.getByText(/Created/)).toBeInTheDocument(); - expect( - screen.getByText(new Date(1710000000 * 1000).toLocaleString()), - ).toBeInTheDocument(); - - // Check for the specific ID label (not Previous Response ID) - expect( - screen.getByText((content, element) => { - return element?.tagName === "STRONG" && content === "ID:"; - }), - ).toBeInTheDocument(); - expect(screen.getByText("resp_123")).toBeInTheDocument(); - - expect(screen.getByText(/Model/)).toBeInTheDocument(); - expect(screen.getByText("llama-test-model")).toBeInTheDocument(); - - expect(screen.getByText(/Status/)).toBeInTheDocument(); - expect(screen.getByText("completed")).toBeInTheDocument(); - - expect(screen.getByText(/Temperature/)).toBeInTheDocument(); - expect(screen.getByText("0.7")).toBeInTheDocument(); - - expect(screen.getByText(/Top P/)).toBeInTheDocument(); - expect(screen.getByText("0.9")).toBeInTheDocument(); - - expect(screen.getByText(/Parallel Tool Calls/)).toBeInTheDocument(); - expect(screen.getByText("Yes")).toBeInTheDocument(); - - expect(screen.getByText(/Previous Response ID/)).toBeInTheDocument(); - expect(screen.getByText("prev_resp_456")).toBeInTheDocument(); - }); - - test("handles optional properties correctly", () => { - const minimalResponse: OpenAIResponse = { - id: "resp_minimal", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [], - input: [], - }; - - render( - , - ); - - // Should show required properties - expect(screen.getByText("resp_minimal")).toBeInTheDocument(); - expect(screen.getByText("test-model")).toBeInTheDocument(); - expect(screen.getByText("completed")).toBeInTheDocument(); - - // Should not show optional properties - expect(screen.queryByText("Temperature")).not.toBeInTheDocument(); - expect(screen.queryByText("Top P")).not.toBeInTheDocument(); - expect(screen.queryByText("Parallel Tool Calls")).not.toBeInTheDocument(); - expect( - screen.queryByText("Previous Response ID"), - ).not.toBeInTheDocument(); - }); - - test("renders error information when response has error", () => { - const errorResponse: OpenAIResponse = { - ...mockResponse, - error: { - code: "invalid_request", - message: "The request was invalid", - }, - }; - - render(); - - // The error is shown in the properties sidebar, not as a separate "Error" label - expect( - screen.getByText("invalid_request: The request was invalid"), - ).toBeInTheDocument(); - }); - }); - - describe("Input Items Handling", () => { - const mockResponse: OpenAIResponse 
= { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [{ type: "message", role: "assistant", content: "output" }], - input: [{ type: "message", role: "user", content: "fallback input" }], - }; - - test("shows loading state for input items", () => { - render( - , - ); - - // Check for skeleton loading in input items section - const { container } = render( - , - ); - - const skeletons = container.querySelectorAll('[data-slot="skeleton"]'); - expect(skeletons.length).toBeGreaterThan(0); - }); - - test("shows error message for input items with fallback", () => { - render( - , - ); - - expect( - screen.getByText( - "Error loading input items: Failed to load input items", - ), - ).toBeInTheDocument(); - expect( - screen.getByText("Falling back to response input data."), - ).toBeInTheDocument(); - - // Should still show fallback input data - expect(screen.getByText("fallback input")).toBeInTheDocument(); - }); - - test("uses input items data when available", () => { - const mockInputItems: InputItemListResponse = { - object: "list", - data: [ - { - type: "message", - role: "user", - content: "input from items API", - }, - ], - }; - - render( - , - ); - - // Should show input items data, not response.input - expect(screen.getByText("input from items API")).toBeInTheDocument(); - expect(screen.queryByText("fallback input")).not.toBeInTheDocument(); - }); - - test("falls back to response.input when input items is empty", () => { - const emptyInputItems: InputItemListResponse = { - object: "list", - data: [], - }; - - render( - , - ); - - // Should show fallback input data - expect(screen.getByText("fallback input")).toBeInTheDocument(); - }); - - test("shows no input message when no data available", () => { - const responseWithoutInput: OpenAIResponse = { - ...mockResponse, - input: [], - }; - - render( - , - ); - - expect(screen.getByText("No input data available.")).toBeInTheDocument(); - }); - }); - - describe("Input Display Components", () => { - test("renders string content input correctly", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [], - input: [ - { - type: "message", - role: "user", - content: "Simple string input", - }, - ], - }; - - render(); - - expect(screen.getByText("Simple string input")).toBeInTheDocument(); - expect(screen.getByText("User")).toBeInTheDocument(); - }); - - test("renders array content input correctly", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [], - input: [ - { - type: "message", - role: "user", - content: [ - { type: "input_text", text: "First part" }, - { type: "output_text", text: "Second part" }, - ], - }, - ], - }; - - render(); - - expect(screen.getByText("First part Second part")).toBeInTheDocument(); - expect(screen.getByText("User")).toBeInTheDocument(); - }); - - test("renders non-message input types correctly", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [], - input: [ - { - type: "function_call", - content: "function call content", - }, - ], - }; - - render(); - - expect(screen.getByText("function call content")).toBeInTheDocument(); - // Use getAllByText to find the specific "Input" with the type detail - 
const inputElements = screen.getAllByText("Input"); - expect(inputElements.length).toBeGreaterThan(0); - expect(screen.getByText("(function_call)")).toBeInTheDocument(); - }); - - test("handles input with object content", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [], - input: [ - { - type: "custom_type", - content: JSON.stringify({ key: "value", nested: { data: "test" } }), - }, - ], - }; - - render(); - - // Should show JSON stringified content (without quotes around keys in the rendered output) - expect(screen.getByText(/key.*value/)).toBeInTheDocument(); - // Use getAllByText to find the specific "Input" with the type detail - const inputElements = screen.getAllByText("Input"); - expect(inputElements.length).toBeGreaterThan(0); - expect(screen.getByText("(custom_type)")).toBeInTheDocument(); - }); - - test("renders function call input correctly", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [], - input: [ - { - type: "function_call", - id: "call_456", - status: "completed", - name: "input_function", - arguments: '{"param": "value"}', - }, - ], - }; - - render(); - - expect( - screen.getByText('input_function({"param": "value"})'), - ).toBeInTheDocument(); - expect(screen.getByText("Function Call")).toBeInTheDocument(); - }); - - test("renders web search call input correctly", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [], - input: [ - { - type: "web_search_call", - id: "search_789", - status: "completed", - }, - ], - }; - - render(); - - expect( - screen.getByText("web_search_call(status: completed)"), - ).toBeInTheDocument(); - expect(screen.getByText("Function Call")).toBeInTheDocument(); - expect(screen.getByText("(Web Search)")).toBeInTheDocument(); - }); - }); - - describe("Output Display Components", () => { - test("renders message output with string content", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [ - { - type: "message", - role: "assistant", - content: "Simple string output", - }, - ], - input: [], - }; - - render(); - - expect(screen.getByText("Simple string output")).toBeInTheDocument(); - expect(screen.getByText("Assistant")).toBeInTheDocument(); - }); - - test("renders message output with array content", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [ - { - type: "message", - role: "assistant", - content: [ - { type: "output_text", text: "First output" }, - { type: "input_text", text: "Second output" }, - ], - }, - ], - input: [], - }; - - render(); - - expect( - screen.getByText("First output Second output"), - ).toBeInTheDocument(); - expect(screen.getByText("Assistant")).toBeInTheDocument(); - }); - - test("renders function call output correctly", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [ - { - type: "function_call", - id: "call_123", - status: "completed", - name: "search_function", - arguments: '{"query": 
"test"}', - }, - ], - input: [], - }; - - render(); - - expect( - screen.getByText('search_function({"query": "test"})'), - ).toBeInTheDocument(); - expect(screen.getByText("Function Call")).toBeInTheDocument(); - }); - - test("renders function call output without arguments", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [ - { - type: "function_call", - id: "call_123", - status: "completed", - name: "simple_function", - }, - ], - input: [], - }; - - render(); - - expect(screen.getByText("simple_function({})")).toBeInTheDocument(); - expect(screen.getByText(/Function Call/)).toBeInTheDocument(); - }); - - test("renders web search call output correctly", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [ - { - type: "web_search_call", - id: "search_123", - status: "completed", - }, - ], - input: [], - }; - - render(); - - expect( - screen.getByText("web_search_call(status: completed)"), - ).toBeInTheDocument(); - expect(screen.getByText(/Function Call/)).toBeInTheDocument(); - expect(screen.getByText("(Web Search)")).toBeInTheDocument(); - }); - - test("renders unknown output types with JSON fallback", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [ - { - type: "unknown_type", - custom_field: "custom_value", - data: { nested: "object" }, - } as any, - ], - input: [], - }; - - render(); - - // Should show JSON stringified content - expect( - screen.getByText(/custom_field.*custom_value/), - ).toBeInTheDocument(); - expect(screen.getByText("(unknown_type)")).toBeInTheDocument(); - }); - - test("shows no output message when output array is empty", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [], - input: [], - }; - - render(); - - expect(screen.getByText("No output data available.")).toBeInTheDocument(); - }); - - test("groups function call with its output correctly", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [ - { - type: "function_call", - id: "call_123", - status: "completed", - name: "get_weather", - arguments: '{"city": "Tokyo"}', - }, - { - type: "message", - role: "assistant", - call_id: "call_123", - content: "sunny and warm", - } as any, // Using any to bypass the type restriction for this test - ], - input: [], - }; - - render(); - - // Should show the function call and message as separate items (not grouped) - expect(screen.getByText("Function Call")).toBeInTheDocument(); - expect( - screen.getByText('get_weather({"city": "Tokyo"})'), - ).toBeInTheDocument(); - expect(screen.getByText("Assistant")).toBeInTheDocument(); - expect(screen.getByText("sunny and warm")).toBeInTheDocument(); - - // Should NOT have the grouped "Arguments" and "Output" labels - expect(screen.queryByText("Arguments")).not.toBeInTheDocument(); - }); - - test("groups function call with function_call_output correctly", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [ - { - 
type: "function_call", - call_id: "call_123", - status: "completed", - name: "get_weather", - arguments: '{"city": "Tokyo"}', - }, - { - type: "function_call_output", - id: "fc_68364957013081...", - status: "completed", - call_id: "call_123", - output: "sunny and warm", - } as any, // Using any to bypass the type restriction for this test - ], - input: [], - }; - - render(); - - // Should show the function call grouped with its clean output - expect(screen.getByText("Function Call")).toBeInTheDocument(); - expect(screen.getByText("Arguments")).toBeInTheDocument(); - expect( - screen.getByText('get_weather({"city": "Tokyo"})'), - ).toBeInTheDocument(); - // Use getAllByText since there are multiple "Output" elements (card title and output label) - const outputElements = screen.getAllByText("Output"); - expect(outputElements.length).toBeGreaterThan(0); - expect(screen.getByText("sunny and warm")).toBeInTheDocument(); - }); - }); - - describe("Edge Cases and Error Handling", () => { - test("handles missing role in message input", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [], - input: [ - { - type: "message", - content: "Message without role", - }, - ], - }; - - render(); - - expect(screen.getByText("Message without role")).toBeInTheDocument(); - expect(screen.getByText("Unknown")).toBeInTheDocument(); // Default role - }); - - test("handles missing name in function call output", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [ - { - type: "function_call", - id: "call_123", - status: "completed", - }, - ], - input: [], - }; - - render(); - - // When name is missing, it falls back to JSON.stringify of the entire output - const functionCallElements = screen.getAllByText(/function_call/); - expect(functionCallElements.length).toBeGreaterThan(0); - expect(screen.getByText(/call_123/)).toBeInTheDocument(); - }); - }); -}); diff --git a/llama_stack/ui/components/responses/responses-detail.tsx b/llama_stack/ui/components/responses/responses-detail.tsx deleted file mode 100644 index c8c447ba4..000000000 --- a/llama_stack/ui/components/responses/responses-detail.tsx +++ /dev/null @@ -1,171 +0,0 @@ -"use client"; - -import { OpenAIResponse, InputItemListResponse } from "@/lib/types"; -import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; -import { Skeleton } from "@/components/ui/skeleton"; -import { - DetailLoadingView, - DetailErrorView, - DetailNotFoundView, - DetailLayout, - PropertiesCard, - PropertyItem, -} from "@/components/layout/detail-layout"; -import { GroupedItemsDisplay } from "./grouping/grouped-items-display"; - -interface ResponseDetailViewProps { - response: OpenAIResponse | null; - inputItems: InputItemListResponse | null; - isLoading: boolean; - isLoadingInputItems: boolean; - error: Error | null; - inputItemsError: Error | null; - id: string; -} - -export function ResponseDetailView({ - response, - inputItems, - isLoading, - isLoadingInputItems, - error, - inputItemsError, - id, -}: ResponseDetailViewProps) { - const title = "Responses Details"; - - if (error) { - return ; - } - - if (isLoading) { - return ; - } - - if (!response) { - return ; - } - - // Main content cards - const mainContent = ( - <> - - - Input - - - {/* Show loading state for input items */} - {isLoadingInputItems ? ( -
    - - - -
    - ) : inputItemsError ? ( -
    - Error loading input items: {inputItemsError.message} -
    - - Falling back to response input data. - -
    - ) : null} - - {/* Display input items if available, otherwise fall back to response.input */} - {(() => { - const dataToDisplay = - inputItems?.data && inputItems.data.length > 0 - ? inputItems.data - : response.input; - - if (dataToDisplay && dataToDisplay.length > 0) { - return ( - - ); - } else { - return ( -

    - No input data available. -

    - ); - } - })()} -
    -
    - - - - Output - - - {response.output?.length > 0 ? ( - - ) : ( -

    - No output data available. -

    - )} -
    -
    - - ); - - // Properties sidebar - const sidebar = ( - - - - - - {response.temperature && ( - - )} - {response.top_p && } - {response.parallel_tool_calls && ( - - )} - {response.previous_response_id && ( - {response.previous_response_id} - } - hasBorder - /> - )} - {response.error && ( - - {response.error.code}: {response.error.message} - - } - className="pt-1 mt-1 border-t border-red-200" - /> - )} - - ); - - return ( - - ); -} diff --git a/llama_stack/ui/components/responses/responses-table.test.tsx b/llama_stack/ui/components/responses/responses-table.test.tsx deleted file mode 100644 index 7c45c57d3..000000000 --- a/llama_stack/ui/components/responses/responses-table.test.tsx +++ /dev/null @@ -1,537 +0,0 @@ -import React from "react"; -import { render, screen, fireEvent } from "@testing-library/react"; -import "@testing-library/jest-dom"; -import { ResponsesTable } from "./responses-table"; -import { OpenAIResponse } from "@/lib/types"; - -// Mock next/navigation -const mockPush = jest.fn(); -jest.mock("next/navigation", () => ({ - useRouter: () => ({ - push: mockPush, - }), -})); - -// Mock helper functions -jest.mock("@/lib/truncate-text"); - -// Import the mocked functions -import { truncateText as originalTruncateText } from "@/lib/truncate-text"; - -// Cast to jest.Mock for typings -const truncateText = originalTruncateText as jest.Mock; - -describe("ResponsesTable", () => { - const defaultProps = { - data: [] as OpenAIResponse[], - isLoading: false, - error: null, - }; - - beforeEach(() => { - // Reset all mocks before each test - mockPush.mockClear(); - truncateText.mockClear(); - - // Default pass-through implementation - truncateText.mockImplementation((text: string | undefined) => text); - }); - - test("renders without crashing with default props", () => { - render(); - expect(screen.getByText("No responses found.")).toBeInTheDocument(); - }); - - test("click on a row navigates to the correct URL", () => { - const mockResponse: OpenAIResponse = { - id: "resp_123", - object: "response", - created_at: Math.floor(Date.now() / 1000), - model: "llama-test-model", - status: "completed", - output: [ - { - type: "message", - role: "assistant", - content: "Test output", - }, - ], - input: [ - { - type: "message", - role: "user", - content: "Test input", - }, - ], - }; - - render(); - - const row = screen.getByText("Test input").closest("tr"); - if (row) { - fireEvent.click(row); - expect(mockPush).toHaveBeenCalledWith("/logs/responses/resp_123"); - } else { - throw new Error('Row with "Test input" not found for router mock test.'); - } - }); - - describe("Loading State", () => { - test("renders skeleton UI when isLoading is true", () => { - const { container } = render( - , - ); - - // Check for skeleton in the table caption - const tableCaption = container.querySelector("caption"); - expect(tableCaption).toBeInTheDocument(); - if (tableCaption) { - const captionSkeleton = tableCaption.querySelector( - '[data-slot="skeleton"]', - ); - expect(captionSkeleton).toBeInTheDocument(); - } - - // Check for skeletons in the table body cells - const tableBody = container.querySelector("tbody"); - expect(tableBody).toBeInTheDocument(); - if (tableBody) { - const bodySkeletons = tableBody.querySelectorAll( - '[data-slot="skeleton"]', - ); - expect(bodySkeletons.length).toBeGreaterThan(0); - } - }); - }); - - describe("Error State", () => { - test("renders error message when error prop is provided", () => { - const errorMessage = "Network Error"; - render( - , - ); - expect( - 
screen.getByText(`Error fetching data: ${errorMessage}`), - ).toBeInTheDocument(); - }); - - test("renders default error message when error.message is not available", () => { - render( - , - ); - expect( - screen.getByText("Error fetching data: An unknown error occurred"), - ).toBeInTheDocument(); - }); - - test("renders default error message when error prop is an object without message", () => { - render(); - expect( - screen.getByText("Error fetching data: An unknown error occurred"), - ).toBeInTheDocument(); - }); - }); - - describe("Empty State", () => { - test('renders "No responses found." and no table when data array is empty', () => { - render(); - expect(screen.getByText("No responses found.")).toBeInTheDocument(); - - // Ensure that the table structure is NOT rendered in the empty state - const table = screen.queryByRole("table"); - expect(table).not.toBeInTheDocument(); - }); - }); - - describe("Data Rendering", () => { - test("renders table caption, headers, and response data correctly", () => { - const mockResponses = [ - { - id: "resp_1", - object: "response" as const, - created_at: 1710000000, - model: "llama-test-model", - status: "completed", - output: [ - { - type: "message" as const, - role: "assistant" as const, - content: "Test output", - }, - ], - input: [ - { - type: "message", - role: "user", - content: "Test input", - }, - ], - }, - { - id: "resp_2", - object: "response" as const, - created_at: 1710001000, - model: "llama-another-model", - status: "completed", - output: [ - { - type: "message" as const, - role: "assistant" as const, - content: "Another output", - }, - ], - input: [ - { - type: "message", - role: "user", - content: "Another input", - }, - ], - }, - ]; - - render( - , - ); - - // Table caption - expect( - screen.getByText("A list of your recent responses."), - ).toBeInTheDocument(); - - // Table headers - expect(screen.getByText("Input")).toBeInTheDocument(); - expect(screen.getByText("Output")).toBeInTheDocument(); - expect(screen.getByText("Model")).toBeInTheDocument(); - expect(screen.getByText("Created")).toBeInTheDocument(); - - // Data rows - expect(screen.getByText("Test input")).toBeInTheDocument(); - expect(screen.getByText("Test output")).toBeInTheDocument(); - expect(screen.getByText("llama-test-model")).toBeInTheDocument(); - expect( - screen.getByText(new Date(1710000000 * 1000).toLocaleString()), - ).toBeInTheDocument(); - - expect(screen.getByText("Another input")).toBeInTheDocument(); - expect(screen.getByText("Another output")).toBeInTheDocument(); - expect(screen.getByText("llama-another-model")).toBeInTheDocument(); - expect( - screen.getByText(new Date(1710001000 * 1000).toLocaleString()), - ).toBeInTheDocument(); - }); - }); - - describe("Input Text Extraction", () => { - test("extracts text from string content", () => { - const mockResponse: OpenAIResponse = { - id: "resp_string", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [{ type: "message", role: "assistant", content: "output" }], - input: [ - { - type: "message", - role: "user", - content: "Simple string input", - }, - ], - }; - - render( - , - ); - expect(screen.getByText("Simple string input")).toBeInTheDocument(); - }); - - test("extracts text from array content with input_text type", () => { - const mockResponse: OpenAIResponse = { - id: "resp_array", - object: "response", - created_at: 1710000000, - model: "test-model", - status: "completed", - output: [{ type: "message", role: "assistant", content: "output" 
-        }],
-        input: [
-          {
-            type: "message",
-            role: "user",
-            content: [
-              { type: "input_text", text: "Array input text" },
-              { type: "input_text", text: "Should not be used" },
-            ],
-          },
-        ],
-      };
-
-      render(
-        ,
-      );
-      expect(screen.getByText("Array input text")).toBeInTheDocument();
-    });
-
-    test("returns empty string when no message input found", () => {
-      const mockResponse: OpenAIResponse = {
-        id: "resp_no_input",
-        object: "response",
-        created_at: 1710000000,
-        model: "test-model",
-        status: "completed",
-        output: [{ type: "message", role: "assistant", content: "output" }],
-        input: [
-          {
-            type: "other_type",
-            content: "Not a message",
-          },
-        ],
-      };
-
-      const { container } = render(
-        ,
-      );
-
-      // Find the input cell (first cell in the data row) and verify it's empty
-      const inputCell = container.querySelector("tbody tr td:first-child");
-      expect(inputCell).toBeInTheDocument();
-      expect(inputCell).toHaveTextContent("");
-    });
-  });
-
-  describe("Output Text Extraction", () => {
-    test("extracts text from string message content", () => {
-      const mockResponse: OpenAIResponse = {
-        id: "resp_string_output",
-        object: "response",
-        created_at: 1710000000,
-        model: "test-model",
-        status: "completed",
-        output: [
-          {
-            type: "message",
-            role: "assistant",
-            content: "Simple string output",
-          },
-        ],
-        input: [{ type: "message", content: "input" }],
-      };
-
-      render(
-        ,
-      );
-      expect(screen.getByText("Simple string output")).toBeInTheDocument();
-    });
-
-    test("extracts text from array message content with output_text type", () => {
-      const mockResponse: OpenAIResponse = {
-        id: "resp_array_output",
-        object: "response",
-        created_at: 1710000000,
-        model: "test-model",
-        status: "completed",
-        output: [
-          {
-            type: "message",
-            role: "assistant",
-            content: [
-              { type: "output_text", text: "Array output text" },
-              { type: "output_text", text: "Should not be used" },
-            ],
-          },
-        ],
-        input: [{ type: "message", content: "input" }],
-      };
-
-      render(
-        ,
-      );
-      expect(screen.getByText("Array output text")).toBeInTheDocument();
-    });
-
-    test("formats function call output", () => {
-      const mockResponse: OpenAIResponse = {
-        id: "resp_function_call",
-        object: "response",
-        created_at: 1710000000,
-        model: "test-model",
-        status: "completed",
-        output: [
-          {
-            type: "function_call",
-            id: "call_123",
-            status: "completed",
-            name: "search_function",
-            arguments: '{"query": "test"}',
-          },
-        ],
-        input: [{ type: "message", content: "input" }],
-      };
-
-      render(
-        ,
-      );
-      expect(
-        screen.getByText('search_function({"query": "test"})'),
-      ).toBeInTheDocument();
-    });
-
-    test("formats function call output without arguments", () => {
-      const mockResponse: OpenAIResponse = {
-        id: "resp_function_no_args",
-        object: "response",
-        created_at: 1710000000,
-        model: "test-model",
-        status: "completed",
-        output: [
-          {
-            type: "function_call",
-            id: "call_123",
-            status: "completed",
-            name: "simple_function",
-          },
-        ],
-        input: [{ type: "message", content: "input" }],
-      };
-
-      render(
-        ,
-      );
-      expect(screen.getByText("simple_function({})")).toBeInTheDocument();
-    });
-
-    test("formats web search call output", () => {
-      const mockResponse: OpenAIResponse = {
-        id: "resp_web_search",
-        object: "response",
-        created_at: 1710000000,
-        model: "test-model",
-        status: "completed",
-        output: [
-          {
-            type: "web_search_call",
-            id: "search_123",
-            status: "completed",
-          },
-        ],
-        input: [{ type: "message", content: "input" }],
-      };
-
-      render(
-        ,
-      );
-      expect(
-        screen.getByText("web_search_call(status: completed)"),
-      ).toBeInTheDocument();
-    });
-
-    test("falls back to JSON.stringify for unknown tool call types", () => {
-      const mockResponse: OpenAIResponse = {
-        id: "resp_unknown_tool",
-        object: "response",
-        created_at: 1710000000,
-        model: "test-model",
-        status: "completed",
-        output: [
-          {
-            type: "unknown_call",
-            id: "unknown_123",
-            status: "completed",
-            custom_field: "custom_value",
-          } as any,
-        ],
-        input: [{ type: "message", content: "input" }],
-      };
-
-      render(
-        ,
-      );
-      // Should contain the JSON stringified version
-      expect(screen.getByText(/unknown_call/)).toBeInTheDocument();
-    });
-
-    test("falls back to JSON.stringify for entire output when no message or tool call found", () => {
-      const mockResponse: OpenAIResponse = {
-        id: "resp_fallback",
-        object: "response",
-        created_at: 1710000000,
-        model: "test-model",
-        status: "completed",
-        output: [
-          {
-            type: "unknown_type",
-            data: "some data",
-          } as any,
-        ],
-        input: [{ type: "message", content: "input" }],
-      };
-
-      render(
-        ,
-      );
-      // Should contain the JSON stringified version of the output array
-      expect(screen.getByText(/unknown_type/)).toBeInTheDocument();
-    });
-  });
-
-  describe("Text Truncation", () => {
-    test("truncates long input and output text", () => {
-      // Specific mock implementation for this test
-      truncateText.mockImplementation(
-        (text: string | undefined, maxLength?: number) => {
-          const defaultTestMaxLength = 10;
-          const effectiveMaxLength = maxLength ?? defaultTestMaxLength;
-          return typeof text === "string" && text.length > effectiveMaxLength
-            ? text.slice(0, effectiveMaxLength) + "..."
-            : text;
-        },
-      );
-
-      const longInput =
-        "This is a very long input message that should be truncated.";
-      const longOutput =
-        "This is a very long output message that should also be truncated.";
-
-      const mockResponse: OpenAIResponse = {
-        id: "resp_trunc",
-        object: "response",
-        created_at: 1710002000,
-        model: "llama-trunc-model",
-        status: "completed",
-        output: [
-          {
-            type: "message",
-            role: "assistant",
-            content: longOutput,
-          },
-        ],
-        input: [
-          {
-            type: "message",
-            role: "user",
-            content: longInput,
-          },
-        ],
-      };
-
-      render(
-        ,
-      );
-
-      // The truncated text should be present for both input and output
-      const truncatedTexts = screen.getAllByText(
-        longInput.slice(0, 10) + "...",
-      );
-      expect(truncatedTexts.length).toBe(2); // one for input, one for output
-      truncatedTexts.forEach((textElement) =>
-        expect(textElement).toBeInTheDocument(),
-      );
-    });
-  });
-});
diff --git a/llama_stack/ui/components/responses/responses-table.tsx b/llama_stack/ui/components/responses/responses-table.tsx
deleted file mode 100644
index 352450d18..000000000
--- a/llama_stack/ui/components/responses/responses-table.tsx
+++ /dev/null
@@ -1,117 +0,0 @@
-"use client";
-
-import {
-  OpenAIResponse,
-  ResponseInput,
-  ResponseInputMessageContent,
-} from "@/lib/types";
-import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
-import {
-  isMessageInput,
-  isMessageItem,
-  isFunctionCallItem,
-  isWebSearchCallItem,
-  MessageItem,
-  FunctionCallItem,
-  WebSearchCallItem,
-} from "./utils/item-types";
-
-interface ResponsesTableProps {
-  data: OpenAIResponse[];
-  isLoading: boolean;
-  error: Error | null;
-}
-
-function getInputText(response: OpenAIResponse): string {
-  const firstInput = response.input.find(isMessageInput);
-  if (firstInput) {
-    return extractContentFromItem(firstInput);
-  }
-  return "";
-}
-
-function getOutputText(response: OpenAIResponse): string {
-  const firstMessage = response.output.find((item) =>
-    isMessageItem(item as any),
-  );
-  if (firstMessage) {
-    const content = extractContentFromItem(firstMessage as MessageItem);
-    if (content) {
-      return content;
-    }
-  }
-
-  const functionCall = response.output.find((item) =>
-    isFunctionCallItem(item as any),
-  );
-  if (functionCall) {
-    return formatFunctionCall(functionCall as FunctionCallItem);
-  }
-
-  const webSearchCall = response.output.find((item) =>
-    isWebSearchCallItem(item as any),
-  );
-  if (webSearchCall) {
-    return formatWebSearchCall(webSearchCall as WebSearchCallItem);
-  }
-
-  return JSON.stringify(response.output);
-}
-
-function extractContentFromItem(item: {
-  content?: string | ResponseInputMessageContent[];
-}): string {
-  if (!item.content) {
-    return "";
-  }
-
-  if (typeof item.content === "string") {
-    return item.content;
-  } else if (Array.isArray(item.content)) {
-    const textContent = item.content.find(
-      (c: ResponseInputMessageContent) =>
-        c.type === "input_text" || c.type === "output_text",
-    );
-    return textContent?.text || "";
-  }
-  return "";
-}
-
-function formatFunctionCall(functionCall: FunctionCallItem): string {
-  const args = functionCall.arguments || "{}";
-  const name = functionCall.name || "unknown";
-  return `${name}(${args})`;
-}
-
-function formatWebSearchCall(webSearchCall: WebSearchCallItem): string {
-  return `web_search_call(status: ${webSearchCall.status})`;
-}
-
-function formatResponseToRow(response: OpenAIResponse): LogTableRow {
-  return {
-    id: response.id,
-    input: getInputText(response),
-    output: getOutputText(response),
-    model: response.model,
-    createdTime: new Date(response.created_at * 1000).toLocaleString(),
-    detailPath: `/logs/responses/${response.id}`,
-  };
-}
-
-export function ResponsesTable({
-  data,
-  isLoading,
-  error,
-}: ResponsesTableProps) {
-  const formattedData = data.map(formatResponseToRow);
-
-  return (
-
-  );
-}
diff --git a/llama_stack/ui/components/responses/utils/item-types.ts b/llama_stack/ui/components/responses/utils/item-types.ts
deleted file mode 100644
index 2bde49119..000000000
--- a/llama_stack/ui/components/responses/utils/item-types.ts
+++ /dev/null
@@ -1,61 +0,0 @@
-/**
- * Type guards for different item types in responses
- */
-
-import type {
-  ResponseInput,
-  ResponseOutput,
-  ResponseMessage,
-  ResponseToolCall,
-} from "@/lib/types";
-
-export interface BaseItem {
-  type: string;
-  [key: string]: unknown;
-}
-
-export type MessageItem = ResponseMessage;
-export type FunctionCallItem = ResponseToolCall & { type: "function_call" };
-export type WebSearchCallItem = ResponseToolCall & { type: "web_search_call" };
-export type FunctionCallOutputItem = BaseItem & {
-  type: "function_call_output";
-  call_id: string;
-  output?: string | object;
-};
-
-export type AnyResponseItem =
-  | ResponseInput
-  | ResponseOutput
-  | FunctionCallOutputItem;
-
-export function isMessageInput(
-  item: ResponseInput,
-): item is ResponseInput & { type: "message" } {
-  return item.type === "message";
-}
-
-export function isMessageItem(item: AnyResponseItem): item is MessageItem {
-  return item.type === "message" && "content" in item;
-}
-
-export function isFunctionCallItem(
-  item: AnyResponseItem,
-): item is FunctionCallItem {
-  return item.type === "function_call" && "name" in item;
-}
-
-export function isWebSearchCallItem(
-  item: AnyResponseItem,
-): item is WebSearchCallItem {
-  return item.type === "web_search_call";
-}
-
-export function isFunctionCallOutputItem(
-  item: AnyResponseItem,
-): item is FunctionCallOutputItem {
-  return (
-    item.type === "function_call_output" &&
-    "call_id" in item &&
-    typeof (item as any).call_id === "string"
-  );
-}
diff --git a/llama_stack/ui/components/ui/breadcrumb.tsx b/llama_stack/ui/components/ui/breadcrumb.tsx
deleted file mode 100644
index f63ae19af..000000000
--- a/llama_stack/ui/components/ui/breadcrumb.tsx
+++ /dev/null
@@ -1,109 +0,0 @@
-import * as React from "react";
-import { Slot } from "@radix-ui/react-slot";
-import { ChevronRight, MoreHorizontal } from "lucide-react";
-
-import { cn } from "@/lib/utils";
-
-function Breadcrumb({ ...props }: React.ComponentProps<"nav">) {
-  return