diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml
index 573148e46..d4f7db7fe 100644
--- a/.github/actions/run-and-record-tests/action.yml
+++ b/.github/actions/run-and-record-tests/action.yml
@@ -52,9 +52,9 @@ runs:
         git add tests/integration/recordings/

         if [ "${{ inputs.run-vision-tests }}" == "true" ]; then
-          git commit -m "Recordings update from CI (vision)"
+          git commit -m "Recordings update from CI (vision) (${{ inputs.provider }})"
         else
-          git commit -m "Recordings update from CI"
+          git commit -m "Recordings update from CI (${{ inputs.provider }})"
         fi

         git fetch origin ${{ github.event.pull_request.head.ref }}
@@ -70,7 +70,8 @@
     if: ${{ always() }}
     shell: bash
     run: |
-      sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log || true
+      sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log 2>&1 || true
+      sudo docker logs vllm > vllm-${{ inputs.inference-mode }}.log 2>&1 || true

   - name: Upload logs
     if: ${{ always() }}
diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 9ef49fba3..e3f2a8c8e 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -20,7 +20,6 @@ on:
   schedule:
     # If changing the cron schedule, update the provider in the test-matrix job
     - cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC
-    - cron: '1 0 * * 0' # (test vllm) Weekly on Sunday at 1 AM UTC
   workflow_dispatch:
     inputs:
       test-all-client-versions:
@@ -38,28 +37,7 @@ concurrency:
   cancel-in-progress: true

 jobs:
-  discover-tests:
-    runs-on: ubuntu-latest
-    outputs:
-      test-types: ${{ steps.generate-test-types.outputs.test-types }}
-
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
-      - name: Generate test types
-        id: generate-test-types
-        run: |
-          # Get test directories dynamically, excluding non-test directories
-          # NOTE: we are excluding post_training since the tests take too long
-          TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
-            sed 's|tests/integration/||' |
-            grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
-            sort | jq -R -s -c 'split("\n")[:-1]')
-          echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
-
   run-replay-mode-tests:
-    needs: discover-tests
     runs-on: ubuntu-latest
     name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }}

@@ -68,11 +46,14 @@ jobs:
       matrix:
         client-type: [library, server]
         # Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama)
-        provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }}
+        provider: [ollama, vllm]
         # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
         python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
         client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
         run-vision-tests: [true, false]
+        exclude:
+          - provider: vllm
+            run-vision-tests: true

     steps:
       - name: Checkout repository
@@ -87,10 +68,27 @@ jobs:
           run-vision-tests: ${{ matrix.run-vision-tests }}
           inference-mode: 'replay'

+      - name: Generate test types
+        id: generate-test-types
+        run: |
+          # Only run inference tests for vllm, since those are the tests most likely to exercise the vllm provider
+          # TODO: Add agent tests for vllm
+          if [ "${{ matrix.provider }}" == "vllm" ]; then
+            echo "test-types=[\"inference\"]" >> $GITHUB_OUTPUT
+            exit 0
+          fi
+          # Get test directories dynamically, excluding non-test directories
+          # NOTE: we are excluding post_training since the tests take too long
+          TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
+            sed 's|tests/integration/||' |
+            grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
+            sort | jq -R -s -c 'split("\n")[:-1]')
+          echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
+
       - name: Run tests
         uses: ./.github/actions/run-and-record-tests
         with:
-          test-types: ${{ needs.discover-tests.outputs.test-types }}
+          test-types: ${{ steps.generate-test-types.outputs.test-types }}
           stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
           provider: ${{ matrix.provider }}
           inference-mode: 'replay'
diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml
index b31709a4f..f74e8deb3 100644
--- a/.github/workflows/record-integration-tests.yml
+++ b/.github/workflows/record-integration-tests.yml
@@ -15,12 +15,6 @@ on:
       - '.github/actions/setup-ollama/action.yml'
       - '.github/actions/setup-test-environment/action.yml'
       - '.github/actions/run-and-record-tests/action.yml'
-  workflow_dispatch:
-    inputs:
-      test-provider:
-        description: 'Test against a specific provider'
-        type: string
-        default: 'ollama'

 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
@@ -42,12 +36,6 @@ jobs:
       - name: Generate test types
         id: generate-test-types
         run: |
-          # Get test directories dynamically, excluding non-test directories
-          TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
-            grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" |
-            sort | jq -R -s -c 'split("\n")[:-1]')
-          echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
-
           labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name')
           echo "labels=$labels"

@@ -82,6 +70,10 @@ jobs:
       fail-fast: false
       matrix:
         mode: ${{ fromJSON(needs.discover-tests.outputs.matrix-modes) }}
+        provider: [ollama, vllm]
+        exclude:
+          - mode: vision
+            provider: vllm

     steps:
       - name: Checkout repository
@@ -90,20 +82,33 @@ jobs:
         with:
           ref: ${{ github.event.pull_request.head.ref }}
           fetch-depth: 0
+      - name: Generate test types
+        id: generate-test-types
+        run: |
+          if [ "${{ matrix.provider }}" == "vllm" ]; then
+            echo "test-types=[\"inference\"]" >> $GITHUB_OUTPUT
+          else
+            # Get test directories dynamically, excluding non-test directories
+            TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
+              grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
+              sort | jq -R -s -c 'split("\n")[:-1]')
+            echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
+          fi
+
       - name: Setup test environment
         uses: ./.github/actions/setup-test-environment
         with:
           python-version: "3.12"  # Use single Python version for recording
           client-version: "latest"
-          provider: ${{ inputs.test-provider || 'ollama' }}
+          provider: ${{ matrix.provider }}
           run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
           inference-mode: 'record'

       - name: Run and record tests
         uses: ./.github/actions/run-and-record-tests
         with:
-          test-types: ${{ needs.discover-tests.outputs.test-types }}
+          test-types: ${{ steps.generate-test-types.outputs.test-types }}
           stack-config: 'server:ci-tests'  # recording must be done with server since more tests are run
-          provider: ${{ inputs.test-provider || 'ollama' }}
+          provider: ${{ matrix.provider }}
           inference-mode: 'record'
           run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index 478f77773..0b2c01a1b 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -10,12 +10,15 @@ import hashlib
 import json
 import os
 import sqlite3
+import uuid
 from collections.abc import Generator
 from contextlib import contextmanager
 from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast

+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+
 from llama_stack.log import get_logger

 logger = get_logger(__name__, category="testing")
@@ -105,13 +108,29 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
     try:
         # Import the original class and reconstruct the object
         module_path, class_name = data["__type__"].rsplit(".", 1)
+
+        # Handle generic types (e.g. AsyncPage[Model]) by removing the generic part
+        if "[" in class_name and "]" in class_name:
+            class_name = class_name.split("[")[0]
+
         module = __import__(module_path, fromlist=[class_name])
         cls = getattr(module, class_name)
         if not hasattr(cls, "model_validate"):
             raise ValueError(f"Pydantic class {cls} does not support model_validate?")
-        return cls.model_validate(data["__data__"])
+        # Special handling for AsyncPage - convert nested model dicts to proper model objects
+        validate_data = data["__data__"]
+        if class_name == "AsyncPage" and isinstance(validate_data, dict) and "data" in validate_data:
+            # Convert model dictionaries to objects with attributes so they work with .id access
+            from types import SimpleNamespace
+
+            validate_data = dict(validate_data)
+            validate_data["data"] = [
+                SimpleNamespace(**item) if isinstance(item, dict) else item for item in validate_data["data"]
+            ]
+
+        return cls.model_validate(validate_data)
     except (ImportError, AttributeError, TypeError, ValueError) as e:
         logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
         return data["__data__"]
@@ -248,6 +267,20 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         recording = _current_storage.find_recording(request_hash)
         if recording:
             response_body = recording["response"]["body"]
+            if (
+                isinstance(response_body, list)
+                and len(response_body) > 0
+                and isinstance(response_body[0], ChatCompletionChunk)
+            ):
+                # Chat completions are stored in a SQLite database with a unique constraint on the id,
+                # so a replayed response can't reuse the recorded id; generate a fresh one instead.
+                newid = uuid.uuid4().hex
+                response_body[0].id = "chatcmpl-" + newid
+            elif isinstance(response_body, ChatCompletion):
+                # Chat completions are stored in a SQLite database with a unique constraint on the id,
+                # so a replayed response can't reuse the recorded id; generate a fresh one instead.
+                newid = uuid.uuid4().hex
+                response_body.id = "chatcmpl-" + newid

             if recording["response"].get("is_streaming", False):
@@ -315,9 +348,11 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
+        "model_list": AsyncModels.list,
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
@@ -330,6 +365,55 @@ def patch_inference_clients():
     }

     # Create patched methods for OpenAI client
+    def patched_model_list(self, *args, **kwargs):
+        # The original models.list() returns an AsyncPaginator that can be used with async for
+        # We need to create a wrapper that preserves this behavior
+        class PatchedAsyncPaginator:
+            def __init__(self, original_method, instance, client_type, endpoint, args, kwargs):
+                self.original_method = original_method
+                self.instance = instance
+                self.client_type = client_type
+                self.endpoint = endpoint
+                self.args = args
+                self.kwargs = kwargs
+                self._result = None
+
+            def __await__(self):
+                # Make it awaitable like the original AsyncPaginator
+                async def _await():
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+                    return self._result
+
+                return _await().__await__()
+
+            def __aiter__(self):
+                # Make it async iterable like the original AsyncPaginator
+                return self
+
+            async def __anext__(self):
+                # Get the result if we haven't already
+                if self._result is None:
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+
+                # Initialize iteration on first call
+                if not hasattr(self, "_iter_index"):
+                    # Extract the data list from the result
+                    self._data_list = self._result.data
+                    self._iter_index = 0
+
+                # Return next item from the list
+                if self._iter_index >= len(self._data_list):
+                    raise StopAsyncIteration
+                item = self._data_list[self._iter_index]
+                self._iter_index += 1
+                return item
+
+        return PatchedAsyncPaginator(_original_methods["model_list"], self, "openai", "/v1/models", args, kwargs)
+
     async def patched_chat_completions_create(self, *args, **kwargs):
         return await _patched_inference_method(
             _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
@@ -346,6 +430,7 @@ def patch_inference_clients():
         )

     # Apply OpenAI patches
+    AsyncModels.list = patched_model_list
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
@@ -402,8 +487,10 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Restore OpenAI client methods
+    AsyncModels.list = _original_methods["model_list"]
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
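
Notes on the recorder changes above, with small standalone sketches in Python. Any name not
taken from the diff is illustrative only.

The _deserialize_response change handles recorded __type__ strings for generic types such as
"openai.pagination.AsyncPage[Model]": the generic suffix is stripped before the import, and the
nested model dicts are rehydrated as attribute-style objects. A minimal sketch of just that
handling, assuming a recording shaped like the example below:

    from types import SimpleNamespace

    # Example of what a recording might store for a models.list() response.
    recorded = {
        "__type__": "openai.pagination.AsyncPage[Model]",
        "__data__": {"data": [{"id": "llama3.2:3b", "object": "model"}], "object": "list"},
    }

    module_path, class_name = recorded["__type__"].rsplit(".", 1)
    if "[" in class_name and "]" in class_name:
        class_name = class_name.split("[")[0]  # "AsyncPage[Model]" -> "AsyncPage"

    payload = dict(recorded["__data__"])
    # Rehydrate nested dicts so downstream code can use attribute access like model.id.
    payload["data"] = [SimpleNamespace(**d) if isinstance(d, dict) else d for d in payload["data"]]
    print(module_path, class_name)   # openai.pagination AsyncPage
    print(payload["data"][0].id)     # llama3.2:3b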
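On the fresh "chatcmpl-" ids: the comments in the diff note that replayed chat completions land
in a SQLite database with a unique constraint on the id (the schema itself is not part of this
diff). A toy illustration of the collision the uuid rewrite avoids, assuming such a constraint:

    import sqlite3
    import uuid

    conn = sqlite3.connect(":memory:")
    # Hypothetical stand-in for the real table; only the unique id column matters here.
    conn.execute("CREATE TABLE completions (id TEXT UNIQUE)")

    recorded_id = "chatcmpl-abc123"
    conn.execute("INSERT INTO completions VALUES (?)", (recorded_id,))
    try:
        # Replaying a second time with the recorded id would violate the constraint.
        conn.execute("INSERT INTO completions VALUES (?)", (recorded_id,))
    except sqlite3.IntegrityError as exc:
        print("duplicate id rejected:", exc)

    fresh_id = "chatcmpl-" + uuid.uuid4().hex  # what the patch substitutes on replay
    conn.execute("INSERT INTO completions VALUES (?)", (fresh_id,))
    print("fresh id accepted")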
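PatchedAsyncPaginator exists because the object returned by the OpenAI client's models.list()
can either be awaited for a whole page or consumed directly with async for, and the patch has to
preserve both entry points. A minimal sketch of that dual protocol, with a fetch coroutine
standing in for the patched inference call:

    import asyncio
    from types import SimpleNamespace

    class AwaitOrIterate:
        """Awaitable for the whole page, async-iterable for its items."""

        def __init__(self, fetch):
            self._fetch = fetch  # coroutine function returning a page-like object
            self._page = None
            self._index = 0

        def __await__(self):
            # Awaited form: fetch and return the page itself.
            async def _await():
                self._page = await self._fetch()
                return self._page

            return _await().__await__()

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self._page is None:  # lazily fetch on first iteration
                self._page = await self._fetch()
            if self._index >= len(self._page.data):
                raise StopAsyncIteration
            item = self._page.data[self._index]
            self._index += 1
            return item

    async def fetch_page():
        return SimpleNamespace(data=[SimpleNamespace(id="m-1"), SimpleNamespace(id="m-2")])

    async def main():
        page = await AwaitOrIterate(fetch_page)         # awaited form
        print([m.id for m in page.data])
        async for model in AwaitOrIterate(fetch_page):  # iterated form
            print(model.id)

    asyncio.run(main())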