From 5e2093883231eec1e239e443c90590616b319b09 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 30 Oct 2025 09:13:04 -0700 Subject: [PATCH 01/20] fix: remove LLAMA_STACK_TEST_FORCE_SERVER_RESTART setting in fixture (#3982) # What does this PR do? This is meant to be a manual flag, so the fixture should not set it. ## Test Plan CI --- tests/integration/telemetry/conftest.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/integration/telemetry/conftest.py b/tests/integration/telemetry/conftest.py index dfb400ce7..58ac4e0df 100644 --- a/tests/integration/telemetry/conftest.py +++ b/tests/integration/telemetry/conftest.py @@ -33,12 +33,10 @@ def telemetry_test_collector(): } previous_env = {key: os.environ.get(key) for key in env_overrides} - previous_force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART") for key, value in env_overrides.items(): os.environ[key] = value - os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = "1" telemetry_module._TRACER_PROVIDER = None try: @@ -50,10 +48,6 @@ def telemetry_test_collector(): os.environ.pop(key, None) else: os.environ[key] = prior - if previous_force_restart is None: - os.environ.pop("LLAMA_STACK_TEST_FORCE_SERVER_RESTART", None) - else: - os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = previous_force_restart else: manager = InMemoryTelemetryManager() try: From 77c8bc6fa7389d0e82495b203fa32e79c9eec6a7 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 30 Oct 2025 11:02:59 -0700 Subject: [PATCH 02/20] fix(ci): add back server:ci-tests to replay tests (#3976) It is useful for local debugging: if both the server and docker jobs are failing, you can just run the server locally, which is much easier to debug. --- .github/workflows/integration-tests.yml | 2 +- scripts/integration-tests.sh | 9 ++++++++ tests/integration/fixtures/common.py | 1 + tests/integration/telemetry/conftest.py | 28 +++++++++---------------- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 2b8965aad..067f49abd 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -47,7 +47,7 @@ jobs: strategy: fail-fast: false matrix: - client-type: [library, docker] + client-type: [library, docker, server] # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12 python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index a09dc8621..ed3934a5b 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -208,6 +208,15 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then echo "=== Starting Llama Stack Server ===" export LLAMA_STACK_LOG_WIDTH=120 + # Configure telemetry collector for server mode + # Use a fixed port for the OTEL collector so the server can connect to it + COLLECTOR_PORT=4317 + export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}" + export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:${COLLECTOR_PORT}" + export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf" + export OTEL_BSP_SCHEDULE_DELAY="200" + export OTEL_BSP_EXPORT_TIMEOUT="2000" + # remove "server:" from STACK_CONFIG stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://') nohup llama
stack run $stack_config > server.log 2>&1 & diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 41822f850..e68f9dc9e 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -230,6 +230,7 @@ def instantiate_llama_stack_client(session): force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART") == "1" if force_restart: + print(f"Forcing restart of the server on port {port}") stop_server_on_port(port) # Check if port is available diff --git a/tests/integration/telemetry/conftest.py b/tests/integration/telemetry/conftest.py index 58ac4e0df..fd9224ae4 100644 --- a/tests/integration/telemetry/conftest.py +++ b/tests/integration/telemetry/conftest.py @@ -10,7 +10,6 @@ import os import pytest -import llama_stack.core.telemetry.telemetry as telemetry_module from llama_stack.testing.api_recorder import patch_httpx_for_test_id from tests.integration.fixtures.common import instantiate_llama_stack_client from tests.integration.telemetry.collectors import InMemoryTelemetryManager, OtlpHttpTestCollector @@ -21,33 +20,26 @@ def telemetry_test_collector(): stack_mode = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") if stack_mode == "server": + # In server mode, the collector must be started and the server is already running. + # The integration test script (scripts/integration-tests.sh) should have set + # LLAMA_STACK_TEST_COLLECTOR_PORT and OTEL_EXPORTER_OTLP_ENDPOINT before starting the server. try: collector = OtlpHttpTestCollector() except RuntimeError as exc: pytest.skip(str(exc)) - env_overrides = { - "OTEL_EXPORTER_OTLP_ENDPOINT": collector.endpoint, - "OTEL_EXPORTER_OTLP_PROTOCOL": "http/protobuf", - "OTEL_BSP_SCHEDULE_DELAY": "200", - "OTEL_BSP_EXPORT_TIMEOUT": "2000", - } - previous_env = {key: os.environ.get(key) for key in env_overrides} - - for key, value in env_overrides.items(): - os.environ[key] = value - - telemetry_module._TRACER_PROVIDER = None + # Verify the collector is listening on the expected endpoint + expected_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + if expected_endpoint and collector.endpoint != expected_endpoint: + pytest.skip( + f"Collector endpoint mismatch: expected {expected_endpoint}, got {collector.endpoint}. " + "Server was likely started before collector." + ) try: yield collector finally: collector.shutdown() - for key, prior in previous_env.items(): - if prior is None: - os.environ.pop(key, None) - else: - os.environ[key] = prior else: manager = InMemoryTelemetryManager() try: From c2ae42b3436c2a7a1b9bdd08b12a57d7a011ca78 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 30 Oct 2025 11:48:20 -0700 Subject: [PATCH 03/20] fix(ci): show pre-commit output easily on failure (#3985) Right now, the failed step that GH opens by default just makes me go up, click around, and scroll through the whole log for no reason.
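In short, the workflow now runs pre-commit directly instead of through the pre-commit/action wrapper, tees the output to a log file, and records the hook exit status so a later step can dump only the failed output. A sketch of the new step (copied from the diff below):

```sh
# Stream pre-commit output to both the console and a log file.
# `tee` would mask the exit status, so capture it from PIPESTATUS.
pre-commit run --show-diff-on-failure --color=always --all-files 2>&1 | tee /tmp/precommit.log
status=${PIPESTATUS[0]}   # pre-commit's status, not tee's
echo "status=$status" >> $GITHUB_OUTPUT
```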
--- .github/workflows/pre-commit.yml | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 485009578..d10161d93 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -50,19 +50,34 @@ jobs: run: npm ci working-directory: src/llama_stack/ui + - name: Install pre-commit + run: python -m pip install pre-commit + + - name: Cache pre-commit + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 + with: + path: ~/.cache/pre-commit + key: pre-commit-3|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }} + - name: Run pre-commit id: precommit - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 - continue-on-error: true + run: | + set +e + pre-commit run --show-diff-on-failure --color=always --all-files 2>&1 | tee /tmp/precommit.log + status=${PIPESTATUS[0]} + echo "status=$status" >> $GITHUB_OUTPUT + exit 0 env: SKIP: no-commit-to-branch,mypy RUFF_OUTPUT_FORMAT: github - name: Check pre-commit results - if: steps.precommit.outcome == 'failure' + if: steps.precommit.outputs.status != '0' run: | echo "::error::Pre-commit hooks failed. Please run 'pre-commit run --all-files' locally and commit the fixes." - echo "::warning::Some pre-commit hooks failed. Check the output above for details." + echo "" + echo "Failed hooks output:" + cat /tmp/precommit.log exit 1 - name: Debug From 90234d697350e94d2b4ccfc0065df577acedf2f8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 30 Oct 2025 15:20:34 -0700 Subject: [PATCH 04/20] ci: support release branches and match client branch (#3990) - Update workflows to trigger on release-X.Y.x-maint branches - When PR targets release branch, fetch matching branch from llama-stack-client-python - Falls back to main if matching client branch doesn't exist - Updated workflows: - integration-tests.yml - integration-auth-tests.yml - integration-sql-store-tests.yml - integration-vector-io-tests.yml - unit-tests.yml - backward-compat.yml - pre-commit.yml --- .../actions/run-and-record-tests/action.yml | 2 +- .../actions/setup-test-environment/action.yml | 22 +++++++++++++++++-- .github/workflows/backward-compat.yml | 4 +++- .github/workflows/integration-auth-tests.yml | 8 +++++-- .../workflows/integration-sql-store-tests.yml | 8 +++++-- .github/workflows/integration-tests.yml | 8 +++++-- .../workflows/integration-vector-io-tests.yml | 8 +++++-- .github/workflows/pre-commit.yml | 4 +++- .github/workflows/unit-tests.yml | 8 +++++-- 9 files changed, 57 insertions(+), 15 deletions(-) diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index ac600d570..ec4d7f977 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -94,7 +94,7 @@ runs: if: ${{ always() }} uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: - name: logs-${{ github.run_id }}-${{ github.run_attempt || '' }}-${{ strategy.job-index }} + name: logs-${{ github.run_id }}-${{ github.run_attempt || '1' }}-${{ strategy.job-index || github.job }}-${{ github.action }} path: | *.log retention-days: 1 diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index ee9011ed8..542610337 100644 --- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -44,8 +44,26 
@@ runs: run: | # Install llama-stack-client-python based on the client-version input if [ "${{ inputs.client-version }}" = "latest" ]; then - echo "Installing latest llama-stack-client-python from main branch" - export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main + # Check if PR is targeting a release branch + TARGET_BRANCH="${{ github.base_ref }}" + + if [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x-maint$ ]]; then + echo "PR targets release branch: $TARGET_BRANCH" + echo "Checking if matching branch exists in llama-stack-client-python..." + + # Check if the branch exists in the client repo + if git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$TARGET_BRANCH" > /dev/null 2>&1; then + echo "Installing llama-stack-client-python from matching branch: $TARGET_BRANCH" + export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@$TARGET_BRANCH + else + echo "::error::Branch $TARGET_BRANCH not found in llama-stack-client-python repository" + echo "::error::Please create the matching release branch in llama-stack-client-python before testing" + exit 1 + fi + else + echo "Installing latest llama-stack-client-python from main branch" + export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main + fi elif [ "${{ inputs.client-version }}" = "published" ]; then echo "Installing published llama-stack-client-python from PyPI" unset LLAMA_STACK_CLIENT_DIR diff --git a/.github/workflows/backward-compat.yml b/.github/workflows/backward-compat.yml index 72d2b0c27..88a3db503 100644 --- a/.github/workflows/backward-compat.yml +++ b/.github/workflows/backward-compat.yml @@ -4,7 +4,9 @@ run-name: Check backward compatibility for run.yaml configs on: pull_request: - branches: [main] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' paths: - 'src/llama_stack/core/datatypes.py' - 'src/llama_stack/providers/datatypes.py' diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 2de3fe9df..ee9d53f22 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -4,9 +4,13 @@ run-name: Run the integration test suite with Kubernetes authentication on: push: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' pull_request: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' paths: - 'distributions/**' - 'src/llama_stack/**' diff --git a/.github/workflows/integration-sql-store-tests.yml b/.github/workflows/integration-sql-store-tests.yml index 0653b3fa8..429357c1f 100644 --- a/.github/workflows/integration-sql-store-tests.yml +++ b/.github/workflows/integration-sql-store-tests.yml @@ -4,9 +4,13 @@ run-name: Run the integration test suite with SqlStore on: push: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' pull_request: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' paths: - 'src/llama_stack/providers/utils/sqlstore/**' - 'tests/integration/sqlstore/**' diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 067f49abd..9f3ffc769 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -4,9 +4,13 @@ run-name: Run the integration test suites from tests/integration in replay mode on: push: - branches: [ main ] + branches: + - main + - 
'release-[0-9]+.[0-9]+.x-maint' pull_request: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' types: [opened, synchronize, reopened] paths: - 'src/llama_stack/**' diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 0b4e174bc..790c2cf8b 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -4,9 +4,13 @@ run-name: Run the integration test suite with various VectorIO providers on: push: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' pull_request: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' paths: - 'src/llama_stack/**' - '!src/llama_stack/ui/**' diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index d10161d93..77a041d8e 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -5,7 +5,9 @@ run-name: Run pre-commit checks on: pull_request: push: - branches: [main] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' concurrency: group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 182643721..881803dbb 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -4,9 +4,13 @@ run-name: Run the unit test suite on: push: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' pull_request: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x-maint' paths: - 'src/llama_stack/**' - '!src/llama_stack/ui/**' From 6f90a7af4b67b3fc94e14afbff0085c23d0bec64 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 30 Oct 2025 16:27:13 -0700 Subject: [PATCH 05/20] ci: target release-X.Y.x branches instead of release-X.Y.x-maint (#3995) We will be updating our release procedure to be more "normal" or "sane". We will - create release branches like normal people - land cherry-picks onto those branches - run releases off of those branches - no more "rc" branch pollution either Given that, this PR cleans things up a bit - Remove `-maint` suffix from release branch patterns in CI workflows - Update branch matching to `release-X.Y.x` format --- .github/actions/setup-test-environment/action.yml | 2 +- .github/workflows/backward-compat.yml | 4 +++- .github/workflows/integration-auth-tests.yml | 8 ++++++-- .github/workflows/integration-sql-store-tests.yml | 8 ++++++-- .github/workflows/integration-tests.yml | 8 ++++++-- .github/workflows/integration-vector-io-tests.yml | 8 ++++++-- .github/workflows/pre-commit.yml | 4 +++- .github/workflows/unit-tests.yml | 8 ++++++-- 8 files changed, 37 insertions(+), 13 deletions(-) diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index 542610337..81b6d0178 100644 --- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -47,7 +47,7 @@ runs: # Check if PR is targeting a release branch TARGET_BRANCH="${{ github.base_ref }}" - if [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x-maint$ ]]; then + if [[ "$TARGET_BRANCH" =~ ^release-([0-9]+\.){1,3}[0-9]+$ ]]; then echo "PR targets release branch: $TARGET_BRANCH" echo "Checking if matching branch exists in llama-stack-client-python..." 
diff --git a/.github/workflows/backward-compat.yml b/.github/workflows/backward-compat.yml index 88a3db503..cf91b851e 100644 --- a/.github/workflows/backward-compat.yml +++ b/.github/workflows/backward-compat.yml @@ -6,7 +6,9 @@ on: pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' paths: - 'src/llama_stack/core/datatypes.py' - 'src/llama_stack/providers/datatypes.py' diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index ee9d53f22..4157ead35 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -6,11 +6,15 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' paths: - 'distributions/**' - 'src/llama_stack/**' diff --git a/.github/workflows/integration-sql-store-tests.yml b/.github/workflows/integration-sql-store-tests.yml index 429357c1f..fae675be3 100644 --- a/.github/workflows/integration-sql-store-tests.yml +++ b/.github/workflows/integration-sql-store-tests.yml @@ -6,11 +6,15 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' paths: - 'src/llama_stack/providers/utils/sqlstore/**' - 'tests/integration/sqlstore/**' diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 9f3ffc769..a9876d06a 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -6,11 +6,15 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' types: [opened, synchronize, reopened] paths: - 'src/llama_stack/**' diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 790c2cf8b..eee7bde70 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -6,11 +6,15 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' paths: - 'src/llama_stack/**' - '!src/llama_stack/ui/**' diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 77a041d8e..049911d8b 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -7,7 +7,9 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' 
concurrency: group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 881803dbb..7e59e7df4 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -6,11 +6,15 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.x-maint' + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' paths: - 'src/llama_stack/**' - '!src/llama_stack/ui/**' From 0e384a55a105380338fc596c14a8fbcda0415bad Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 30 Oct 2025 16:34:12 -0700 Subject: [PATCH 06/20] feat: support `workers` in run config (#3992) # What does this PR do? ## Test Plan Set workers: 4 in run.yaml. Start server and observe logs multiple times. --- src/llama_stack/cli/stack/run.py | 3 ++- src/llama_stack/core/datatypes.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/llama_stack/cli/stack/run.py b/src/llama_stack/cli/stack/run.py index 2882500ce..044ce49c9 100644 --- a/src/llama_stack/cli/stack/run.py +++ b/src/llama_stack/cli/stack/run.py @@ -127,7 +127,7 @@ class StackRun(Subcommand): config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents))) port = args.port or config.server.port - host = config.server.host or ["::", "0.0.0.0"] + host = config.server.host or "0.0.0.0" # Set the config file in environment so create_app can find it os.environ["LLAMA_STACK_CONFIG"] = str(config_file) @@ -139,6 +139,7 @@ class StackRun(Subcommand): "lifespan": "on", "log_level": logger.getEffectiveLevel(), "log_config": logger_config, + "workers": config.server.workers, } keyfile = config.server.tls_keyfile diff --git a/src/llama_stack/core/datatypes.py b/src/llama_stack/core/datatypes.py index 95907adcf..2182ea4e5 100644 --- a/src/llama_stack/core/datatypes.py +++ b/src/llama_stack/core/datatypes.py @@ -473,6 +473,10 @@ class ServerConfig(BaseModel): "- true: Enable localhost CORS for development\n" "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration", ) + workers: int = Field( + default=1, + description="Number of workers to use for the server", + ) class StackRunConfig(BaseModel): From ff2b270e2f2c24d7f379bda1819e6fd915758acc Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Thu, 30 Oct 2025 23:55:23 +0000 Subject: [PATCH 07/20] fix: relax structured output test assertions to handle whitespace and… (#3997) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … case variations The ollama/llama3.2:3b-instruct-fp16 model returns string values with trailing whitespace in structured JSON output. Updated test assertions to use case-insensitive substring matching instead of exact equality.
- Use .lower() for case-insensitive comparison
- Check if the expected value is contained in the actual value (handles whitespace)

Closes: #3996 Signed-off-by: Derek Higgins --- tests/integration/inference/test_openai_completion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 964d19c1d..18406610f 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -721,6 +721,6 @@ def test_openai_chat_completion_structured_output(openai_client, text_model_id, print(response.choices[0].message.content) answer = AnswerFormat.model_validate_json(response.choices[0].message.content) expected = tc["expected"] - assert answer.first_name == expected["first_name"] - assert answer.last_name == expected["last_name"] + assert expected["first_name"].lower() in answer.first_name.lower() + assert expected["last_name"].lower() in answer.last_name.lower() assert answer.year_of_birth == expected["year_of_birth"] From e8cd8508b5e6f819f186f26da583690caec7537b Mon Sep 17 00:00:00 2001 From: Doug Edgar Date: Thu, 30 Oct 2025 17:01:31 -0700 Subject: [PATCH 08/20] fix: handle missing external_providers_dir (#3974) # What does this PR do? This PR fixes the handling of the external_providers_dir configuration field to align with its ongoing deprecation, in favor of the provider `module` specification approach. It addresses the issue in #3950, where using the default provided run.yaml config resulted in the `external_providers_dir` parameter being set to the literal string `None`, and crashing the llama-stack server when starting. Closes #3950 ## Test Plan - Built a new container image from `podman build . -f containers/Containerfile --build-arg DISTRO_NAME=starter --tag llama-stack:starter` - Tested it locally with `podman run -it localhost/llama-stack:starter` - Tested it on an OpenShift 4.19 cluster, deployed via the llama-stack-k8s-operator.
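The core of the fix is a truthiness guard before the directory is created. A minimal sketch (the helper name is illustrative; the actual change is inline in the `run.py` diff below):

```python
import os


def ensure_external_providers_dir(external_providers_dir: str | None) -> None:
    # Previously the guard was only `not os.path.exists(...)`, so a missing
    # (None) value was stringified and a literal "None" directory got created.
    if external_providers_dir and not os.path.exists(str(external_providers_dir)):
        os.makedirs(str(external_providers_dir), exist_ok=True)
```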
Signed-off-by: Doug Edgar --- src/llama_stack/cli/stack/run.py | 3 ++- src/llama_stack/core/configure.py | 9 --------- tests/unit/cli/test_stack_config.py | 23 +++++++++++++++++++++++ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/llama_stack/cli/stack/run.py b/src/llama_stack/cli/stack/run.py index 044ce49c9..c9334b9e9 100644 --- a/src/llama_stack/cli/stack/run.py +++ b/src/llama_stack/cli/stack/run.py @@ -106,7 +106,8 @@ class StackRun(Subcommand): try: config = parse_and_maybe_upgrade_config(config_dict) - if not os.path.exists(str(config.external_providers_dir)): + # Create external_providers_dir if it's specified and doesn't exist + if config.external_providers_dir and not os.path.exists(str(config.external_providers_dir)): os.makedirs(str(config.external_providers_dir), exist_ok=True) except AttributeError as e: self.parser.error(f"failed to parse config file '{config_file}':\n {e}") diff --git a/src/llama_stack/core/configure.py b/src/llama_stack/core/configure.py index 734839ea9..5d4a54184 100644 --- a/src/llama_stack/core/configure.py +++ b/src/llama_stack/core/configure.py @@ -17,7 +17,6 @@ from llama_stack.core.distribution import ( get_provider_registry, ) from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars -from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.prompt_for_config import prompt_for_config from llama_stack.log import get_logger @@ -194,19 +193,11 @@ def upgrade_from_routing_table( def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig: - version = config_dict.get("version", None) - if version == LLAMA_STACK_RUN_CONFIG_VERSION: - processed_config_dict = replace_env_vars(config_dict) - return StackRunConfig(**cast_image_name_to_string(processed_config_dict)) - if "routing_table" in config_dict: logger.info("Upgrading config...") config_dict = upgrade_from_routing_table(config_dict) config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION - if not config_dict.get("external_providers_dir", None): - config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR - processed_config_dict = replace_env_vars(config_dict) return StackRunConfig(**cast_image_name_to_string(processed_config_dict)) diff --git a/tests/unit/cli/test_stack_config.py b/tests/unit/cli/test_stack_config.py index 0977a1e43..5d54c2257 100644 --- a/tests/unit/cli/test_stack_config.py +++ b/tests/unit/cli/test_stack_config.py @@ -206,3 +206,26 @@ def test_parse_and_maybe_upgrade_config_invalid(invalid_config): def test_parse_and_maybe_upgrade_config_image_name_int(config_with_image_name_int): result = parse_and_maybe_upgrade_config(config_with_image_name_int) assert isinstance(result.image_name, str) + + +def test_parse_and_maybe_upgrade_config_sets_external_providers_dir(up_to_date_config): + """Test that external_providers_dir is None when not specified (deprecated field).""" + # Ensure the config doesn't have external_providers_dir set + assert "external_providers_dir" not in up_to_date_config + + result = parse_and_maybe_upgrade_config(up_to_date_config) + + # Verify external_providers_dir is None (not set to default) + # This aligns with the deprecation of external_providers_dir + assert result.external_providers_dir is None + + +def test_parse_and_maybe_upgrade_config_preserves_custom_external_providers_dir(up_to_date_config): + """Test that custom external_providers_dir values are preserved.""" + custom_dir = 
"/custom/providers/dir" + up_to_date_config["external_providers_dir"] = custom_dir + + result = parse_and_maybe_upgrade_config(up_to_date_config) + + # Verify the custom value was preserved + assert str(result.external_providers_dir) == custom_dir From c396de57a4783e3f4a199f6bf763a5ebb217f415 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 30 Oct 2025 21:33:32 -0700 Subject: [PATCH 09/20] ci: standardize release branch pattern to release-X.Y.x (#3999) Standardize CI workflows to use `release-X.Y.x` branch pattern instead of multiple numeric variants. That's the pattern we are settling on. See https://github.com/llamastack/llama-stack-ops/pull/20 for reference. --- .github/actions/setup-test-environment/action.yml | 2 +- .github/workflows/integration-auth-tests.yml | 8 ++------ .github/workflows/integration-sql-store-tests.yml | 8 ++------ .github/workflows/integration-tests.yml | 8 ++------ .github/workflows/integration-vector-io-tests.yml | 8 ++------ .github/workflows/pre-commit.yml | 4 +--- .github/workflows/unit-tests.yml | 8 ++------ 7 files changed, 12 insertions(+), 34 deletions(-) diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index 81b6d0178..27d0943fe 100644 --- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -47,7 +47,7 @@ runs: # Check if PR is targeting a release branch TARGET_BRANCH="${{ github.base_ref }}" - if [[ "$TARGET_BRANCH" =~ ^release-([0-9]+\.){1,3}[0-9]+$ ]]; then + if [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then echo "PR targets release branch: $TARGET_BRANCH" echo "Checking if matching branch exists in llama-stack-client-python..." diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 4157ead35..560ab4293 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -6,15 +6,11 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' paths: - 'distributions/**' - 'src/llama_stack/**' diff --git a/.github/workflows/integration-sql-store-tests.yml b/.github/workflows/integration-sql-store-tests.yml index fae675be3..8c3e51dd4 100644 --- a/.github/workflows/integration-sql-store-tests.yml +++ b/.github/workflows/integration-sql-store-tests.yml @@ -6,15 +6,11 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' paths: - 'src/llama_stack/providers/utils/sqlstore/**' - 'tests/integration/sqlstore/**' diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index a9876d06a..ac70f0960 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -6,15 +6,11 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 
'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' types: [opened, synchronize, reopened] paths: - 'src/llama_stack/**' diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index eee7bde70..952141f3b 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -6,15 +6,11 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' paths: - 'src/llama_stack/**' - '!src/llama_stack/ui/**' diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 049911d8b..695a4f9e2 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -7,9 +7,7 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' concurrency: group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 7e59e7df4..92c0a6a19 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -6,15 +6,11 @@ on: push: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' pull_request: branches: - main - - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+.[0-9]+' - - 'release-[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.x' paths: - 'src/llama_stack/**' - '!src/llama_stack/ui/**' From fa7699d2c3db55f214a794be8139789174e09cb0 Mon Sep 17 00:00:00 2001 From: Jiayi Ni Date: Thu, 30 Oct 2025 21:42:09 -0700 Subject: [PATCH 10/20] feat: Add rerank API for NVIDIA Inference Provider (#3329) # What does this PR do? Add rerank API for NVIDIA Inference Provider. Closes #3278 ## Test Plan Unit test: ``` pytest tests/unit/providers/nvidia/test_rerank_inference.py ``` Integration test: ``` pytest -s -v tests/integration/inference/test_rerank.py --stack-config="inference=nvidia" --rerank-model=nvidia/nvidia/nv-rerankqa-mistral-4b-v3 --env NVIDIA_API_KEY="" --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com" ``` --- .../providers/inference/remote_nvidia.mdx | 1 + .../remote/inference/nvidia/NVIDIA.md | 19 ++ .../remote/inference/nvidia/config.py | 9 + .../remote/inference/nvidia/nvidia.py | 111 ++++++++ tests/integration/conftest.py | 5 + tests/integration/fixtures/common.py | 13 +- tests/integration/inference/test_rerank.py | 214 +++++++++++++++ .../providers/nvidia/test_rerank_inference.py | 251 ++++++++++++++++++ 8 files changed, 622 insertions(+), 1 deletion(-) create mode 100644 tests/integration/inference/test_rerank.py create mode 100644 tests/unit/providers/nvidia/test_rerank_inference.py diff --git a/docs/docs/providers/inference/remote_nvidia.mdx b/docs/docs/providers/inference/remote_nvidia.mdx index b4e04176c..57c64ab46 100644 --- a/docs/docs/providers/inference/remote_nvidia.mdx +++ b/docs/docs/providers/inference/remote_nvidia.mdx @@ -20,6 +20,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services. 
| `url` | `` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM | | `timeout` | `` | No | 60 | Timeout for the HTTP requests | | `append_api_version` | `` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. | | `rerank_model_to_url` | `dict[str, str]` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. | ## Sample Configuration diff --git a/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md index f1a828413..97fa95a1f 100644 --- a/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md +++ b/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md @@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create( print(f"VLM Response: {vlm_response.choices[0].message.content}") ``` + +### Rerank Example + +The following example shows how to rerank documents using an NVIDIA NIM. + +```python +rerank_response = client.alpha.inference.rerank( + model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2", + query="query", + items=[ + "item_1", + "item_2", + "item_3", + ], +) + +for i, result in enumerate(rerank_response): + print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]") +``` \ No newline at end of file diff --git a/src/llama_stack/providers/remote/inference/nvidia/config.py b/src/llama_stack/providers/remote/inference/nvidia/config.py index 3545d2b11..618bbe078 100644 --- a/src/llama_stack/providers/remote/inference/nvidia/config.py +++ b/src/llama_stack/providers/remote/inference/nvidia/config.py @@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig): Attributes: url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000 api_key (str): The access key for the hosted NIM endpoints + rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints There are two ways to access NVIDIA NIMs - 0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com @@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig): default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false", description="When set to false, the API version will not be appended to the base_url. By default, it is true.", ) + rerank_model_to_url: dict[str, str] = Field( + default_factory=lambda: { + "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking", + "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", + "nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking", + }, + description="Mapping of rerank model identifiers to their API endpoints.
", + ) @classmethod def sample_run_config( diff --git a/src/llama_stack/providers/remote/inference/nvidia/nvidia.py b/src/llama_stack/providers/remote/inference/nvidia/nvidia.py index ea11b49cd..bc5aa7953 100644 --- a/src/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/src/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -5,6 +5,19 @@ # the root directory of this source tree. +from collections.abc import Iterable + +import aiohttp + +from llama_stack.apis.inference import ( + RerankData, + RerankResponse, +) +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, +) +from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin): :return: The NVIDIA API base URL """ return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url + + async def list_provider_model_ids(self) -> Iterable[str]: + """ + Return both dynamic model IDs and statically configured rerank model IDs. + """ + dynamic_ids: Iterable[str] = [] + try: + dynamic_ids = await super().list_provider_model_ids() + except Exception: + # If the dynamic listing fails, proceed with just configured rerank IDs + dynamic_ids = [] + + configured_rerank_ids = list(self.config.rerank_model_to_url.keys()) + return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids)) # remove duplicates + + def construct_model_from_identifier(self, identifier: str) -> Model: + """ + Classify rerank models from config; otherwise use the base behavior. + """ + if identifier in self.config.rerank_model_to_url: + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.rerank, + ) + return super().construct_model_from_identifier(identifier) + + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + provider_model_id = await self._get_provider_model_id(model) + + ranking_url = self.get_base_url() + + if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url: + ranking_url = self.config.rerank_model_to_url[provider_model_id] + + logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}") + + # Convert query to text format + if isinstance(query, str): + query_text = query + elif isinstance(query, OpenAIChatCompletionContentPartTextParam): + query_text = query.text + else: + raise ValueError("Query must be a string or text content part") + + # Convert items to text format + passages = [] + for item in items: + if isinstance(item, str): + passages.append({"text": item}) + elif isinstance(item, OpenAIChatCompletionContentPartTextParam): + passages.append({"text": item.text}) + else: + raise ValueError("Items must be strings or text content parts") + + payload = { + "model": provider_model_id, + "query": {"text": query_text}, + "passages": passages, + } + + headers = { + "Authorization": f"Bearer {self.get_api_key()}", + "Content-Type": "application/json", + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post(ranking_url, 
headers=headers, json=payload) as response: + if response.status != 200: + response_text = await response.text() + raise ConnectionError( + f"NVIDIA rerank API request failed with status {response.status}: {response_text}" + ) + + result = await response.json() + rankings = result.get("rankings", []) + + # Convert to RerankData format + rerank_data = [] + for ranking in rankings: + rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"])) + + # Apply max_num_results limit + if max_num_results is not None: + rerank_data = rerank_data[:max_num_results] + + return RerankResponse(data=rerank_data) + + except aiohttp.ClientError as e: + raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index aaedd8476..e5ae72fc1 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -171,6 +171,10 @@ def pytest_addoption(parser): "--embedding-model", help="comma-separated list of embedding models. Fixture name: embedding_model_id", ) + parser.addoption( + "--rerank-model", + help="comma-separated list of rerank models. Fixture name: rerank_model_id", + ) parser.addoption( "--safety-shield", help="comma-separated list of safety shields. Fixture name: shield_id", @@ -249,6 +253,7 @@ def pytest_generate_tests(metafunc): "shield_id": ("--safety-shield", "shield"), "judge_model_id": ("--judge-model", "judge"), "embedding_dimension": ("--embedding-dimension", "dim"), + "rerank_model_id": ("--rerank-model", "rerank"), } # Collect all parameters and their values diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index e68f9dc9e..57775ce25 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -153,6 +153,7 @@ def client_with_models( vision_model_id, embedding_model_id, judge_model_id, + rerank_model_id, ): client = llama_stack_client @@ -170,6 +171,9 @@ def client_with_models( if embedding_model_id and embedding_model_id not in model_ids: raise ValueError(f"embedding_model_id {embedding_model_id} not found") + + if rerank_model_id and rerank_model_id not in model_ids: + raise ValueError(f"rerank_model_id {rerank_model_id} not found") return client @@ -185,7 +189,14 @@ def model_providers(llama_stack_client): @pytest.fixture(autouse=True) def skip_if_no_model(request): - model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"] + model_fixtures = [ + "text_model_id", + "vision_model_id", + "embedding_model_id", + "judge_model_id", + "shield_id", + "rerank_model_id", + ] test_func = request.node.function actual_params = inspect.signature(test_func).parameters.keys() diff --git a/tests/integration/inference/test_rerank.py b/tests/integration/inference/test_rerank.py new file mode 100644 index 000000000..82f35cd27 --- /dev/null +++ b/tests/integration/inference/test_rerank.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import pytest +from llama_stack_client import BadRequestError as LlamaStackBadRequestError +from llama_stack_client.types.alpha import InferenceRerankResponse +from llama_stack_client.types.shared.interleaved_content import ( + ImageContentItem, + ImageContentItemImage, + ImageContentItemImageURL, + TextContentItem, +) + +from llama_stack.core.library_client import LlamaStackAsLibraryClient + +# Test data +DUMMY_STRING = "string_1" +DUMMY_STRING2 = "string_2" +DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text") +DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text") +DUMMY_IMAGE_URL = ImageContentItem( + image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image" +) +DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image") + +PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models + + +def skip_if_provider_doesnt_support_rerank(inference_provider_type): + supported_providers = {"remote::nvidia"} + if inference_provider_type not in supported_providers: + pytest.skip(f"{inference_provider_type} doesn't support rerank models") + + +def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None: + """ + Validate that a rerank response has the correct structure and ordering. + + Args: + response: The InferenceRerankResponse to validate + items: The original items list that was ranked + + Raises: + AssertionError: If any validation fails + """ + seen = set() + last_score = float("inf") + for d in response: + assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items" + assert d.index not in seen, f"Duplicate index {d.index} found" + seen.add(d.index) + assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}" + assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}" + last_score = d.relevance_score + + +def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None: + """ + Validate that the expected most relevant item ranks first. + + Args: + response: The InferenceRerankResponse to validate + items: The original items list that was ranked + expected_first_item: The expected first item in the ranking + + Raises: + AssertionError: If any validation fails + """ + if not response: + raise AssertionError("No ranking data returned in response") + + actual_first_index = response[0].index + actual_first_item = items[actual_first_index] + assert actual_first_item == expected_first_item, ( + f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead." + ) + + +@pytest.mark.parametrize( + "query,items", + [ + (DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]), + (DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]), + (DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]), + (DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]), + ], + ids=[ + "string-query-string-items", + "text-query-text-items", + "mixed-content-1", + "mixed-content-2", + ], +) +def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) + assert isinstance(response, list) + # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client. 
+ assert len(response) <= len(items) + _validate_rerank_response(response, items) + + +@pytest.mark.parametrize( + "query,items", + [ + (DUMMY_IMAGE_URL, [DUMMY_STRING]), + (DUMMY_IMAGE_BASE64, [DUMMY_TEXT]), + (DUMMY_TEXT, [DUMMY_IMAGE_URL]), + (DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]), + (DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]), + ], + ids=[ + "image-query-url", + "image-query-base64", + "text-query-image-item", + "mixed-content-1", + "mixed-content-2", + ], +) +def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA: + error_type = ( + ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError + ) + with pytest.raises(error_type): + client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) + else: + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) + + assert isinstance(response, list) + assert len(response) <= len(items) + _validate_rerank_response(response, items) + + +def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2] + max_num_results = 2 + + response = client_with_models.alpha.inference.rerank( + model=rerank_model_id, + query=DUMMY_STRING, + items=items, + max_num_results=max_num_results, + ) + + assert isinstance(response, list) + assert len(response) == max_num_results + _validate_rerank_response(response, items) + + +def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + items = [DUMMY_STRING, DUMMY_STRING2] + response = client_with_models.alpha.inference.rerank( + model=rerank_model_id, + query=DUMMY_STRING, + items=items, + max_num_results=10, # Larger than items length + ) + + assert isinstance(response, list) + assert len(response) <= len(items) # Should return at most len(items) + + +@pytest.mark.parametrize( + "query,items,expected_first_item", + [ + ( + "What is a reranking model? ", + [ + "A reranking model reranks a list of items based on the query. ", + "Machine learning algorithms learn patterns from data. ", + "Python is a programming language. ", + ], + "A reranking model reranks a list of items based on the query. ", + ), + ( + "What is C++?", + [ + "Learning new things is interesting. ", + "C++ is a programming language. ", + "Books provide knowledge and entertainment. ", + ], + "C++ is a programming language. ", + ), + ( + "What are good learning habits? ", + [ + "Cooking pasta is a fun activity. ", + "Plants need water and sunlight. ", + "Good learning habits include reading daily and taking notes. ", + ], + "Good learning habits include reading daily and taking notes. 
", + ), + ], +) +def test_rerank_semantic_correctness( + client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type +): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) + + _validate_rerank_response(response, items) + _validate_semantic_ranking(response, items, expected_first_item) diff --git a/tests/unit/providers/nvidia/test_rerank_inference.py b/tests/unit/providers/nvidia/test_rerank_inference.py new file mode 100644 index 000000000..2793b5f44 --- /dev/null +++ b/tests/unit/providers/nvidia/test_rerank_inference.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from llama_stack.apis.models import ModelType +from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig +from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin + + +class MockResponse: + def __init__(self, status=200, json_data=None, text_data="OK"): + self.status = status + self._json_data = json_data or {"rankings": []} + self._text_data = text_data + + async def json(self): + return self._json_data + + async def text(self): + return self._text_data + + +class MockSession: + def __init__(self, response): + self.response = response + self.post_calls = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + def post(self, url, **kwargs): + self.post_calls.append((url, kwargs)) + + class PostContext: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + return PostContext(self.response) + + +def create_adapter(config=None, rerank_endpoints=None): + if config is None: + config = NVIDIAConfig(api_key="test-key") + + adapter = NVIDIAInferenceAdapter(config=config) + + class MockModel: + provider_resource_id = "test-model" + metadata = {} + + adapter.model_store = AsyncMock() + adapter.model_store.get_model = AsyncMock(return_value=MockModel()) + + if rerank_endpoints is not None: + adapter.config.rerank_model_to_url = rerank_endpoints + + return adapter + + +async def test_rerank_basic_functionality(): + adapter = create_adapter() + mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]}) + mock_session = MockSession(mock_response) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"]) + + assert len(result.data) == 1 + assert result.data[0].index == 0 + assert result.data[0].relevance_score == 0.5 + + url, kwargs = mock_session.post_calls[0] + payload = kwargs["json"] + assert payload["model"] == "test-model" + assert payload["query"] == {"text": "test query"} + assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}] + + +async def test_missing_rankings_key(): + adapter = create_adapter() + mock_session = MockSession(MockResponse(json_data={})) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await 
adapter.rerank(model="test-model", query="q", items=["a"]) + + assert len(result.data) == 0 + + +async def test_hosted_with_endpoint(): + adapter = create_adapter( + config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"} + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert url == "https://model.endpoint/rerank" + + +async def test_hosted_without_endpoint(): + adapter = create_adapter( + config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com). + rerank_endpoints={}, # No endpoint mapping for test-model + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert "https://integrate.api.nvidia.com" in url + + +async def test_hosted_model_not_in_endpoint_mapping(): + adapter = create_adapter( + config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"} + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert "https://integrate.api.nvidia.com" in url + assert url != "https://other.endpoint/rerank" + + +async def test_self_hosted_ignores_endpoint(): + adapter = create_adapter( + config=NVIDIAConfig(url="http://localhost:8000", api_key=None), + rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted. 
+ ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert "http://localhost:8000" in url + assert "model.endpoint/rerank" not in url + + +async def test_max_num_results(): + adapter = create_adapter() + rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}] + mock_session = MockSession(MockResponse(json_data={"rankings": rankings})) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1) + + assert len(result.data) == 1 + assert result.data[0].index == 0 + assert result.data[0].relevance_score == 0.8 + + +async def test_http_error(): + adapter = create_adapter() + mock_session = MockSession(MockResponse(status=500, text_data="Server Error")) + + with patch("aiohttp.ClientSession", return_value=mock_session): + with pytest.raises(ConnectionError, match="status 500.*Server Error"): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + +async def test_client_error(): + adapter = create_adapter() + mock_session = AsyncMock() + mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error") + + with patch("aiohttp.ClientSession", return_value=mock_session): + with pytest.raises(ConnectionError, match="Failed to connect.*Network error"): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + +async def test_list_models_includes_configured_rerank_models(): + """Test that list_models adds rerank models to the dynamic model list.""" + adapter = create_adapter() + adapter.__provider_id__ = "nvidia" + adapter.__provider_spec__ = MagicMock() + + dynamic_ids = ["llm-1", "embedding-1"] + with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)): + result = await adapter.list_models() + + assert result is not None + + # Check that the rerank models are added + model_ids = [m.identifier for m in result] + assert "nv-rerank-qa-mistral-4b:1" in model_ids + assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids + assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids + + rerank_models = [m for m in result if m.model_type == ModelType.rerank] + + assert len(rerank_models) == 3 + + for m in rerank_models: + assert m.provider_id == "nvidia" + assert m.model_type == ModelType.rerank + assert m.metadata == {} + assert m.identifier in adapter._model_cache + + +async def test_list_provider_model_ids_has_no_duplicates(): + adapter = create_adapter() + + dynamic_ids = [ + "llm-1", + "nvidia/nv-rerankqa-mistral-4b-v3", # overlaps configured rerank ids + "embedding-1", + "llm-1", + ] + + with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)): + ids = list(await adapter.list_provider_model_ids()) + + assert len(ids) == len(set(ids)) + assert ids.count("nvidia/nv-rerankqa-mistral-4b-v3") == 1 + assert "nv-rerank-qa-mistral-4b:1" in ids + assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in ids + + +async def test_list_provider_model_ids_uses_configured_on_dynamic_failure(): + adapter = create_adapter() + + # Simulate dynamic listing failure + with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(side_effect=Exception)): + ids = list(await adapter.list_provider_model_ids()) + + # Should still return configured rerank ids + configured_ids = list(adapter.config.rerank_model_to_url.keys()) 
+ assert set(ids) == set(configured_ids) From 6d80ca4bf70f21bad0691b59555c93c9fbe6a033 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 30 Oct 2025 22:09:25 -0700 Subject: [PATCH 11/20] fix(ci): replace unused LLAMA_STACK_CLIENT_DIR with direct install (#4000) Replace unused `LLAMA_STACK_CLIENT_DIR` env var (from old `llama stack build`) with direct `uv pip install` for release branch client installation. cc @ehhuang --- .github/actions/setup-test-environment/action.yml | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index 27d0943fe..992b25803 100644 --- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -54,23 +54,16 @@ runs: # Check if the branch exists in the client repo if git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$TARGET_BRANCH" > /dev/null 2>&1; then echo "Installing llama-stack-client-python from matching branch: $TARGET_BRANCH" - export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@$TARGET_BRANCH + uv pip install --force-reinstall git+https://github.com/llamastack/llama-stack-client-python.git@$TARGET_BRANCH else echo "::error::Branch $TARGET_BRANCH not found in llama-stack-client-python repository" echo "::error::Please create the matching release branch in llama-stack-client-python before testing" exit 1 fi - else - echo "Installing latest llama-stack-client-python from main branch" - export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main fi - elif [ "${{ inputs.client-version }}" = "published" ]; then - echo "Installing published llama-stack-client-python from PyPI" - unset LLAMA_STACK_CLIENT_DIR - else - echo "Invalid client-version: ${{ inputs.client-version }}" - exit 1 + # For main branch, client is already installed by setup-runner fi + # For published version, client is already installed by setup-runner echo "Building Llama Stack" From 5f95c1f8cc16d16f48143bcdeff1fa5c73569222 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 31 Oct 2025 06:16:20 -0700 Subject: [PATCH 12/20] fix(ci): install client from release branch before uv sync (#4001) Fixes CI failures on release branches where uv sync can't resolve RC dependencies. The problem: on release branches like `release-0.3.x`, pyproject.toml requires `llama-stack-client>=0.3.1rc1`. But RC versions only exist on test.pypi, not PyPI. So uv sync fails before we even get a chance to install the client from git. The fix is simple - on release branches, pre-install the client from the matching git branch first, then run uv sync. This satisfies the RC requirement and lets dependency resolution succeed. Modified setup-runner and pre-commit workflows to do this. Also cleaned up some duplicate logic in setup-test-environment that's now handled centrally. 
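To make the resulting flow concrete, here is a minimal Python sketch of the branch-to-install-plan mapping that the new composite action implements in bash below; the function and dictionary keys are hypothetical, for illustration only:

```python
# Illustrative sketch only: mirrors the branch -> install-plan decision the
# composite GitHub Action makes in bash. Names here are hypothetical.
import re

RELEASE_BRANCH = re.compile(r"^release-\d+\.\d+\.x$")
CLIENT_REPO = "https://github.com/llamastack/llama-stack-client-python.git"


def plan_client_install(branch: str, client_version: str = "") -> dict:
    """Map a branch name (plus optional client-version input) to an install plan."""
    if RELEASE_BRANCH.match(branch):
        # Release branch: resolve RC pins from test.pypi during `uv sync`,
        # then force-install the client from the matching git branch.
        return {
            "uv_index_url": "https://test.pypi.org/simple/",
            "uv_extra_index_url": "https://pypi.org/simple/",
            "install_after_sync": f"git+{CLIENT_REPO}@{branch}",
        }
    if client_version == "latest":
        return {"install_after_sync": f"git+{CLIENT_REPO}@main"}
    if client_version in ("published", ""):
        return {}  # `uv sync` already installs the published PyPI client
    raise ValueError(f"Invalid client-version: {client_version!r}")


assert "test.pypi" in plan_client_install("release-0.3.x")["uv_index_url"]
assert plan_client_install("main", "latest")["install_after_sync"].endswith("@main")
```

On release branches the plan both points `uv sync` at test.pypi (so RC pins can resolve) and schedules a post-sync install of the client from the matching git branch; everywhere else the published or main-branch client suffices.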
Example failure: https://github.com/llamastack/llama-stack/actions/runs/18963190991/job/54154788350 --- .../install-llama-stack-client/action.yml | 64 +++++++++++++++++++ .github/actions/setup-runner/action.yml | 23 ++++--- .../actions/setup-test-environment/action.yml | 24 +------ .github/workflows/pre-commit.yml | 16 ++++- 4 files changed, 93 insertions(+), 34 deletions(-) create mode 100644 .github/actions/install-llama-stack-client/action.yml diff --git a/.github/actions/install-llama-stack-client/action.yml b/.github/actions/install-llama-stack-client/action.yml new file mode 100644 index 000000000..553d82f01 --- /dev/null +++ b/.github/actions/install-llama-stack-client/action.yml @@ -0,0 +1,64 @@ +name: Install llama-stack-client +description: Install llama-stack-client based on branch context and client-version input + +inputs: + client-version: + description: 'Client version to install on non-release branches (latest or published). Ignored on release branches.' + required: false + default: "" + +outputs: + uv-index-url: + description: 'UV_INDEX_URL to use (set for release branches)' + value: ${{ steps.configure.outputs.uv-index-url }} + uv-extra-index-url: + description: 'UV_EXTRA_INDEX_URL to use (set for release branches)' + value: ${{ steps.configure.outputs.uv-extra-index-url }} + install-after-sync: + description: 'Whether to install client after uv sync' + value: ${{ steps.configure.outputs.install-after-sync }} + install-source: + description: 'Where to install client from after sync' + value: ${{ steps.configure.outputs.install-source }} + +runs: + using: "composite" + steps: + - name: Configure client installation + id: configure + shell: bash + run: | + # Determine the branch we're working with + BRANCH="${{ github.base_ref || github.ref }}" + BRANCH="${BRANCH#refs/heads/}" + + echo "Working with branch: $BRANCH" + + # On release branches: use test.pypi for uv sync, then install from git + # On non-release branches: install based on client-version after sync + if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then + echo "Detected release branch: $BRANCH" + + # Check if matching branch exists in client repo + if ! 
git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$BRANCH" > /dev/null 2>&1; then + echo "::error::Branch $BRANCH not found in llama-stack-client-python repository" + echo "::error::Please create the matching release branch in llama-stack-client-python before testing" + exit 1 + fi + + # Configure to use test.pypi for sync (to resolve RC versions) + echo "uv-index-url=https://test.pypi.org/simple/" >> $GITHUB_OUTPUT + echo "uv-extra-index-url=https://pypi.org/simple/" >> $GITHUB_OUTPUT + echo "install-after-sync=true" >> $GITHUB_OUTPUT + echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@$BRANCH" >> $GITHUB_OUTPUT + elif [ "${{ inputs.client-version }}" = "latest" ]; then + # Install from main git after sync + echo "install-after-sync=true" >> $GITHUB_OUTPUT + echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@main" >> $GITHUB_OUTPUT + elif [ "${{ inputs.client-version }}" = "published" ]; then + # Use published version from PyPI (installed by sync) + echo "install-after-sync=false" >> $GITHUB_OUTPUT + elif [ -n "${{ inputs.client-version }}" ]; then + echo "::error::Invalid client-version: ${{ inputs.client-version }}" + exit 1 + fi diff --git a/.github/actions/setup-runner/action.yml b/.github/actions/setup-runner/action.yml index 905d6b73a..52a3c4643 100644 --- a/.github/actions/setup-runner/action.yml +++ b/.github/actions/setup-runner/action.yml @@ -18,8 +18,17 @@ runs: python-version: ${{ inputs.python-version }} version: 0.7.6 + - name: Configure client installation + id: client-config + uses: ./.github/actions/install-llama-stack-client + with: + client-version: ${{ inputs.client-version }} + - name: Install dependencies shell: bash + env: + UV_INDEX_URL: ${{ steps.client-config.outputs.uv-index-url }} + UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }} run: | echo "Updating project dependencies via uv sync" uv sync --all-groups @@ -27,16 +36,10 @@ runs: echo "Installing ad-hoc dependencies" uv pip install faiss-cpu - # Install llama-stack-client-python based on the client-version input - if [ "${{ inputs.client-version }}" = "latest" ]; then - echo "Installing latest llama-stack-client-python from main branch" - uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main - elif [ "${{ inputs.client-version }}" = "published" ]; then - echo "Installing published llama-stack-client-python from PyPI" - uv pip install llama-stack-client - else - echo "Invalid client-version: ${{ inputs.client-version }}" - exit 1 + # Install specific client version after sync if needed + if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then + echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}" + uv pip install ${{ steps.client-config.outputs.install-source }} fi echo "Installed llama packages" diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index 992b25803..7b306fef5 100644 --- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -42,29 +42,7 @@ runs: - name: Build Llama Stack shell: bash run: | - # Install llama-stack-client-python based on the client-version input - if [ "${{ inputs.client-version }}" = "latest" ]; then - # Check if PR is targeting a release branch - TARGET_BRANCH="${{ github.base_ref }}" - - if [[ "$TARGET_BRANCH" =~ 
^release-[0-9]+\.[0-9]+\.x$ ]]; then - echo "PR targets release branch: $TARGET_BRANCH" - echo "Checking if matching branch exists in llama-stack-client-python..." - - # Check if the branch exists in the client repo - if git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$TARGET_BRANCH" > /dev/null 2>&1; then - echo "Installing llama-stack-client-python from matching branch: $TARGET_BRANCH" - uv pip install --force-reinstall git+https://github.com/llamastack/llama-stack-client-python.git@$TARGET_BRANCH - else - echo "::error::Branch $TARGET_BRANCH not found in llama-stack-client-python repository" - echo "::error::Please create the matching release branch in llama-stack-client-python before testing" - exit 1 - fi - fi - # For main branch, client is already installed by setup-runner - fi - # For published version, client is already installed by setup-runner - + # Client is already installed by setup-runner (handles both main and release branches) echo "Building Llama Stack" LLAMA_STACK_DIR=. \ diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 695a4f9e2..6d9f358d2 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -130,8 +130,22 @@ jobs: exit 1 fi + - name: Configure client installation + id: client-config + uses: ./.github/actions/install-llama-stack-client + - name: Sync dev + type_checking dependencies - run: uv sync --group dev --group type_checking + env: + UV_INDEX_URL: ${{ steps.client-config.outputs.uv-index-url }} + UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }} + run: | + uv sync --group dev --group type_checking + + # Install specific client version after sync if needed + if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then + echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}" + uv pip install ${{ steps.client-config.outputs.install-source }} + fi - name: Run mypy (full type_checking) run: | From c2fd17474e04b1e75565517fda115cea345f8578 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 31 Oct 2025 11:22:01 -0700 Subject: [PATCH 13/20] fix: stop printing server log, it is confusing --- scripts/integration-tests.sh | 93 +++++++++++++++++------------------- 1 file changed, 44 insertions(+), 49 deletions(-) diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index ed3934a5b..506ac12e0 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -23,7 +23,7 @@ COLLECT_ONLY=false # Function to display usage usage() { - cat << EOF + cat < /dev/null; then +if [[ "$COLLECT_ONLY" == false ]] && ! command -v llama &>/dev/null; then echo "llama could not be found, ensure llama-stack is installed" exit 1 fi -if ! command -v pytest &> /dev/null; then +if ! command -v pytest &>/dev/null; then echo "pytest could not be found, ensure pytest is installed" exit 1 fi @@ -219,7 +218,7 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then # remove "server:" from STACK_CONFIG stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://') - nohup llama stack run $stack_config > server.log 2>&1 & + nohup llama stack run $stack_config >server.log 2>&1 & echo "Waiting for Llama Stack Server to start..." 
for i in {1..30}; do @@ -248,7 +247,7 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then container_name="llama-stack-test-$DISTRO" if docker ps -a --format '{{.Names}}' | grep -q "^${container_name}$"; then echo "Dumping container logs before stopping..." - docker logs "$container_name" > "docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true + docker logs "$container_name" >"docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true echo "Stopping and removing container: $container_name" docker stop "$container_name" 2>/dev/null || true docker rm "$container_name" 2>/dev/null || true @@ -437,17 +436,13 @@ elif [ $exit_code -eq 5 ]; then else echo "āŒ Tests failed" echo "" - echo "=== Dumping last 100 lines of logs for debugging ===" - # Output server or container logs based on stack config if [[ "$STACK_CONFIG" == *"server:"* && -f "server.log" ]]; then - echo "--- Last 100 lines of server.log ---" - tail -100 server.log + echo "--- Server side failures can be located inside server.log (available from artifacts on CI) ---" elif [[ "$STACK_CONFIG" == *"docker:"* ]]; then docker_log_file="docker-${DISTRO}-${INFERENCE_MODE}.log" if [[ -f "$docker_log_file" ]]; then - echo "--- Last 100 lines of $docker_log_file ---" - tail -100 "$docker_log_file" + echo "--- Server side failures can be located inside $docker_log_file (available from artifacts on CI) ---" fi fi From 7b79cd05d587155f43e5f7b915bb70da5fe31119 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Fri, 31 Oct 2025 14:37:25 -0400 Subject: [PATCH 14/20] feat: Adding Prompts to admin UI (#3987) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? 1. Updates Llama Stack Typescript client to include `prompts`api in playground client. 2. Updates the UI to display prompts and execute basic CRUD operations for prompts. (2) adds an explicit "Preview" section when creating the prompt to show users how the Prompts API behaves as you dynamically edit the prompt content. See example here:

[Screenshot: live Preview panel updating while editing prompt content]

Some screenshots (omitted here): Prompts List with Prompts, Empty Prompts List, Create Prompt, Submit Prompt with error.
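For reference, the Preview's color coding comes down to a three-way classification of `{{variable}}` tokens (used / unused / undefined). The Python sketch below is illustrative only (the actual implementation is the `renderPreview` helper in `prompt-editor.tsx` in this diff), and the function name is hypothetical:

```python
# Illustrative sketch of the preview's token classification; hypothetical names.
import re

TOKEN = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def classify_tokens(
    prompt: str, declared: list[str], values: dict[str, str]
) -> list[tuple[str, str]]:
    """Return (variable, state) pairs mirroring the preview's color legend."""
    states = []
    for match in TOKEN.finditer(prompt):
        name = match.group(1)
        if name not in declared:
            states.append((name, "undefined"))  # red: likely a typo, not declared
        elif values.get(name, "").strip():
            states.append((name, "used"))       # green: preview shows the value
        else:
            states.append((name, "unused"))     # yellow: declared but no value yet
    return states


assert classify_tokens("Hi {{name}}, see {{place}}!", ["name"], {"name": "Ada"}) == [
    ("name", "used"),
    ("place", "undefined"),
]
```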
## Closes https://github.com/llamastack/llama-stack/issues/3322 ## Test Plan Added tests and manual testing. Signed-off-by: Francisco Javier Arceo --- .../ui/app/api/v1/[...path]/route.ts | 12 +- src/llama_stack/ui/app/prompts/page.tsx | 5 + .../ui/components/layout/app-sidebar.tsx | 6 + .../ui/components/prompts/index.ts | 4 + .../components/prompts/prompt-editor.test.tsx | 309 ++++++++++++++++ .../ui/components/prompts/prompt-editor.tsx | 346 ++++++++++++++++++ .../components/prompts/prompt-list.test.tsx | 259 +++++++++++++ .../ui/components/prompts/prompt-list.tsx | 164 +++++++++ .../prompts/prompt-management.test.tsx | 304 +++++++++++++++ .../components/prompts/prompt-management.tsx | 233 ++++++++++++ .../ui/components/prompts/types.ts | 16 + src/llama_stack/ui/components/ui/badge.tsx | 36 ++ src/llama_stack/ui/components/ui/label.tsx | 24 ++ src/llama_stack/ui/components/ui/tabs.tsx | 53 +++ src/llama_stack/ui/components/ui/textarea.tsx | 23 ++ src/llama_stack/ui/package-lock.json | 62 +++- src/llama_stack/ui/package.json | 4 +- 17 files changed, 1851 insertions(+), 9 deletions(-) create mode 100644 src/llama_stack/ui/app/prompts/page.tsx create mode 100644 src/llama_stack/ui/components/prompts/index.ts create mode 100644 src/llama_stack/ui/components/prompts/prompt-editor.test.tsx create mode 100644 src/llama_stack/ui/components/prompts/prompt-editor.tsx create mode 100644 src/llama_stack/ui/components/prompts/prompt-list.test.tsx create mode 100644 src/llama_stack/ui/components/prompts/prompt-list.tsx create mode 100644 src/llama_stack/ui/components/prompts/prompt-management.test.tsx create mode 100644 src/llama_stack/ui/components/prompts/prompt-management.tsx create mode 100644 src/llama_stack/ui/components/prompts/types.ts create mode 100644 src/llama_stack/ui/components/ui/badge.tsx create mode 100644 src/llama_stack/ui/components/ui/label.tsx create mode 100644 src/llama_stack/ui/components/ui/tabs.tsx create mode 100644 src/llama_stack/ui/components/ui/textarea.tsx diff --git a/src/llama_stack/ui/app/api/v1/[...path]/route.ts b/src/llama_stack/ui/app/api/v1/[...path]/route.ts index 51c1f8004..d1aa31014 100644 --- a/src/llama_stack/ui/app/api/v1/[...path]/route.ts +++ b/src/llama_stack/ui/app/api/v1/[...path]/route.ts @@ -51,10 +51,14 @@ async function proxyRequest(request: NextRequest, method: string) { ); // Create response with same status and headers - const proxyResponse = new NextResponse(responseText, { - status: response.status, - statusText: response.statusText, - }); + // Handle 204 No Content responses specially + const proxyResponse = + response.status === 204 + ? 
new NextResponse(null, { status: 204 }) + : new NextResponse(responseText, { + status: response.status, + statusText: response.statusText, + }); // Copy response headers (except problematic ones) response.headers.forEach((value, key) => { diff --git a/src/llama_stack/ui/app/prompts/page.tsx b/src/llama_stack/ui/app/prompts/page.tsx new file mode 100644 index 000000000..30106a056 --- /dev/null +++ b/src/llama_stack/ui/app/prompts/page.tsx @@ -0,0 +1,5 @@ +import { PromptManagement } from "@/components/prompts"; + +export default function PromptsPage() { + return ; +} diff --git a/src/llama_stack/ui/components/layout/app-sidebar.tsx b/src/llama_stack/ui/components/layout/app-sidebar.tsx index 373f0c5ae..a5df60aef 100644 --- a/src/llama_stack/ui/components/layout/app-sidebar.tsx +++ b/src/llama_stack/ui/components/layout/app-sidebar.tsx @@ -8,6 +8,7 @@ import { MessageCircle, Settings2, Compass, + FileText, } from "lucide-react"; import Link from "next/link"; import { usePathname } from "next/navigation"; @@ -50,6 +51,11 @@ const manageItems = [ url: "/logs/vector-stores", icon: Database, }, + { + title: "Prompts", + url: "/prompts", + icon: FileText, + }, { title: "Documentation", url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html", diff --git a/src/llama_stack/ui/components/prompts/index.ts b/src/llama_stack/ui/components/prompts/index.ts new file mode 100644 index 000000000..d190c5eb6 --- /dev/null +++ b/src/llama_stack/ui/components/prompts/index.ts @@ -0,0 +1,4 @@ +export { PromptManagement } from "./prompt-management"; +export { PromptList } from "./prompt-list"; +export { PromptEditor } from "./prompt-editor"; +export * from "./types"; diff --git a/src/llama_stack/ui/components/prompts/prompt-editor.test.tsx b/src/llama_stack/ui/components/prompts/prompt-editor.test.tsx new file mode 100644 index 000000000..458a5f942 --- /dev/null +++ b/src/llama_stack/ui/components/prompts/prompt-editor.test.tsx @@ -0,0 +1,309 @@ +import React from "react"; +import { render, screen, fireEvent } from "@testing-library/react"; +import "@testing-library/jest-dom"; +import { PromptEditor } from "./prompt-editor"; +import type { Prompt, PromptFormData } from "./types"; + +describe("PromptEditor", () => { + const mockOnSave = jest.fn(); + const mockOnCancel = jest.fn(); + const mockOnDelete = jest.fn(); + + const defaultProps = { + onSave: mockOnSave, + onCancel: mockOnCancel, + onDelete: mockOnDelete, + }; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe("Create Mode", () => { + test("renders create form correctly", () => { + render(); + + expect(screen.getByLabelText("Prompt Content *")).toBeInTheDocument(); + expect(screen.getByText("Variables")).toBeInTheDocument(); + expect(screen.getByText("Preview")).toBeInTheDocument(); + expect(screen.getByText("Create Prompt")).toBeInTheDocument(); + expect(screen.getByText("Cancel")).toBeInTheDocument(); + }); + + test("shows preview placeholder when no content", () => { + render(); + + expect( + screen.getByText("Enter content to preview the compiled prompt") + ).toBeInTheDocument(); + }); + + test("submits form with correct data", () => { + render(); + + const promptInput = screen.getByLabelText("Prompt Content *"); + fireEvent.change(promptInput, { + target: { value: "Hello {{name}}, welcome!" 
}, + }); + + fireEvent.click(screen.getByText("Create Prompt")); + + expect(mockOnSave).toHaveBeenCalledWith({ + prompt: "Hello {{name}}, welcome!", + variables: [], + }); + }); + + test("prevents submission with empty prompt", () => { + render(); + + fireEvent.click(screen.getByText("Create Prompt")); + + expect(mockOnSave).not.toHaveBeenCalled(); + }); + }); + + describe("Edit Mode", () => { + const mockPrompt: Prompt = { + prompt_id: "prompt_123", + prompt: "Hello {{name}}, how is {{weather}}?", + version: 1, + variables: ["name", "weather"], + is_default: true, + }; + + test("renders edit form with existing data", () => { + render(); + + expect( + screen.getByDisplayValue("Hello {{name}}, how is {{weather}}?") + ).toBeInTheDocument(); + expect(screen.getAllByText("name")).toHaveLength(2); // One in variables, one in preview + expect(screen.getAllByText("weather")).toHaveLength(2); // One in variables, one in preview + expect(screen.getByText("Update Prompt")).toBeInTheDocument(); + expect(screen.getByText("Delete Prompt")).toBeInTheDocument(); + }); + + test("submits updated data correctly", () => { + render(); + + const promptInput = screen.getByLabelText("Prompt Content *"); + fireEvent.change(promptInput, { + target: { value: "Updated: Hello {{name}}!" }, + }); + + fireEvent.click(screen.getByText("Update Prompt")); + + expect(mockOnSave).toHaveBeenCalledWith({ + prompt: "Updated: Hello {{name}}!", + variables: ["name", "weather"], + }); + }); + }); + + describe("Variables Management", () => { + test("adds new variable", () => { + render(); + + const variableInput = screen.getByPlaceholderText( + "Add variable name (e.g. user_name, topic)" + ); + fireEvent.change(variableInput, { target: { value: "testVar" } }); + fireEvent.click(screen.getByText("Add")); + + expect(screen.getByText("testVar")).toBeInTheDocument(); + }); + + test("prevents adding duplicate variables", () => { + render(); + + const variableInput = screen.getByPlaceholderText( + "Add variable name (e.g. user_name, topic)" + ); + + // Add first variable + fireEvent.change(variableInput, { target: { value: "test" } }); + fireEvent.click(screen.getByText("Add")); + + // Try to add same variable again + fireEvent.change(variableInput, { target: { value: "test" } }); + + // Button should be disabled + expect(screen.getByText("Add")).toBeDisabled(); + }); + + test("removes variable", () => { + const mockPrompt: Prompt = { + prompt_id: "prompt_123", + prompt: "Hello {{name}}", + version: 1, + variables: ["name", "location"], + is_default: true, + }; + + render(); + + // Check that both variables are present initially + expect(screen.getAllByText("name").length).toBeGreaterThan(0); + expect(screen.getAllByText("location").length).toBeGreaterThan(0); + + // Remove the location variable by clicking the X button with the specific title + const removeLocationButton = screen.getByTitle( + "Remove location variable" + ); + fireEvent.click(removeLocationButton); + + // Name should still be there, location should be gone from the variables section + expect(screen.getAllByText("name").length).toBeGreaterThan(0); + expect( + screen.queryByTitle("Remove location variable") + ).not.toBeInTheDocument(); + }); + + test("adds variable on Enter key", () => { + render(); + + const variableInput = screen.getByPlaceholderText( + "Add variable name (e.g. 
user_name, topic)" + ); + fireEvent.change(variableInput, { target: { value: "enterVar" } }); + + // Simulate Enter key press + fireEvent.keyPress(variableInput, { + key: "Enter", + code: "Enter", + charCode: 13, + preventDefault: jest.fn(), + }); + + // Check if the variable was added by looking for the badge + expect(screen.getAllByText("enterVar").length).toBeGreaterThan(0); + }); + }); + + describe("Preview Functionality", () => { + test("shows live preview with variables", () => { + render(); + + // Add prompt content + const promptInput = screen.getByLabelText("Prompt Content *"); + fireEvent.change(promptInput, { + target: { value: "Hello {{name}}, welcome to {{place}}!" }, + }); + + // Add variables + const variableInput = screen.getByPlaceholderText( + "Add variable name (e.g. user_name, topic)" + ); + fireEvent.change(variableInput, { target: { value: "name" } }); + fireEvent.click(screen.getByText("Add")); + + fireEvent.change(variableInput, { target: { value: "place" } }); + fireEvent.click(screen.getByText("Add")); + + // Check that preview area shows the content + expect(screen.getByText("Compiled Prompt")).toBeInTheDocument(); + }); + + test("shows variable value inputs in preview", () => { + const mockPrompt: Prompt = { + prompt_id: "prompt_123", + prompt: "Hello {{name}}", + version: 1, + variables: ["name"], + is_default: true, + }; + + render(); + + expect(screen.getByText("Variable Values")).toBeInTheDocument(); + expect( + screen.getByPlaceholderText("Enter value for name") + ).toBeInTheDocument(); + }); + + test("shows color legend for variable states", () => { + render(); + + // Add content to show preview + const promptInput = screen.getByLabelText("Prompt Content *"); + fireEvent.change(promptInput, { + target: { value: "Hello {{name}}" }, + }); + + expect(screen.getByText("Used")).toBeInTheDocument(); + expect(screen.getByText("Unused")).toBeInTheDocument(); + expect(screen.getByText("Undefined")).toBeInTheDocument(); + }); + }); + + describe("Error Handling", () => { + test("displays error message", () => { + const errorMessage = "Prompt contains undeclared variables"; + render(); + + expect(screen.getByText(errorMessage)).toBeInTheDocument(); + }); + }); + + describe("Delete Functionality", () => { + const mockPrompt: Prompt = { + prompt_id: "prompt_123", + prompt: "Hello {{name}}", + version: 1, + variables: ["name"], + is_default: true, + }; + + test("shows delete button in edit mode", () => { + render(); + + expect(screen.getByText("Delete Prompt")).toBeInTheDocument(); + }); + + test("hides delete button in create mode", () => { + render(); + + expect(screen.queryByText("Delete Prompt")).not.toBeInTheDocument(); + }); + + test("calls onDelete with confirmation", () => { + const originalConfirm = window.confirm; + window.confirm = jest.fn(() => true); + + render(); + + fireEvent.click(screen.getByText("Delete Prompt")); + + expect(window.confirm).toHaveBeenCalledWith( + "Are you sure you want to delete this prompt? This action cannot be undone." 
+ ); + expect(mockOnDelete).toHaveBeenCalledWith("prompt_123"); + + window.confirm = originalConfirm; + }); + + test("does not delete when confirmation is cancelled", () => { + const originalConfirm = window.confirm; + window.confirm = jest.fn(() => false); + + render(); + + fireEvent.click(screen.getByText("Delete Prompt")); + + expect(mockOnDelete).not.toHaveBeenCalled(); + + window.confirm = originalConfirm; + }); + }); + + describe("Cancel Functionality", () => { + test("calls onCancel when cancel button is clicked", () => { + render(); + + fireEvent.click(screen.getByText("Cancel")); + + expect(mockOnCancel).toHaveBeenCalled(); + }); + }); +}); diff --git a/src/llama_stack/ui/components/prompts/prompt-editor.tsx b/src/llama_stack/ui/components/prompts/prompt-editor.tsx new file mode 100644 index 000000000..efa76f757 --- /dev/null +++ b/src/llama_stack/ui/components/prompts/prompt-editor.tsx @@ -0,0 +1,346 @@ +"use client"; + +import { useState, useEffect } from "react"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Textarea } from "@/components/ui/textarea"; +import { Badge } from "@/components/ui/badge"; +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { Separator } from "@/components/ui/separator"; +import { X, Plus, Save, Trash2 } from "lucide-react"; +import { Prompt, PromptFormData } from "./types"; + +interface PromptEditorProps { + prompt?: Prompt; + onSave: (prompt: PromptFormData) => void; + onCancel: () => void; + onDelete?: (promptId: string) => void; + error?: string | null; +} + +export function PromptEditor({ + prompt, + onSave, + onCancel, + onDelete, + error, +}: PromptEditorProps) { + const [formData, setFormData] = useState({ + prompt: "", + variables: [], + }); + + const [newVariable, setNewVariable] = useState(""); + const [variableValues, setVariableValues] = useState>( + {} + ); + + useEffect(() => { + if (prompt) { + setFormData({ + prompt: prompt.prompt || "", + variables: prompt.variables || [], + }); + } + }, [prompt]); + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault(); + if (!formData.prompt.trim()) { + return; + } + onSave(formData); + }; + + const addVariable = () => { + if ( + newVariable.trim() && + !formData.variables.includes(newVariable.trim()) + ) { + setFormData(prev => ({ + ...prev, + variables: [...prev.variables, newVariable.trim()], + })); + setNewVariable(""); + } + }; + + const removeVariable = (variableToRemove: string) => { + setFormData(prev => ({ + ...prev, + variables: prev.variables.filter( + variable => variable !== variableToRemove + ), + })); + }; + + const renderPreview = () => { + const text = formData.prompt; + if (!text) return text; + + // Split text by variable patterns and process each part + const parts = text.split(/(\{\{\s*\w+\s*\}\})/g); + + return parts.map((part, index) => { + const variableMatch = part.match(/\{\{\s*(\w+)\s*\}\}/); + if (variableMatch) { + const variableName = variableMatch[1]; + const isDefined = formData.variables.includes(variableName); + const value = variableValues[variableName]; + + if (!isDefined) { + // Variable not in variables list - likely a typo/bug (RED) + return ( + + {part} + + ); + } else if (value && value.trim()) { + // Variable defined and has value - show the value (GREEN) + return ( + + {value} + + ); + } else { + // Variable defined but empty (YELLOW) + return ( + + {part} + + ); + } 
+ } + return part; + }); + }; + + const updateVariableValue = (variable: string, value: string) => { + setVariableValues(prev => ({ + ...prev, + [variable]: value, + })); + }; + + return ( +
+      {error && (
+        <div className="rounded-md border border-destructive/50 bg-destructive/10 p-3 text-sm text-destructive">
+          {error}
+        </div>
+      )}
+
+      <div className="grid gap-6 lg:grid-cols-2">
+        {/* Form Section */}
+        <Card>