commit 6a5dadc395
Author: Derek Higgins
Date: 2025-08-14 13:57:16 -04:00 (committed by GitHub)
Signature: GPG key ID B5690EEEBB952194 (no known key found for this signature in database)
4 changed files with 134 additions and 43 deletions

File 1 of 4

@@ -52,9 +52,9 @@ runs:
       git add tests/integration/recordings/
       if [ "${{ inputs.run-vision-tests }}" == "true" ]; then
-        git commit -m "Recordings update from CI (vision)"
+        git commit -m "Recordings update from CI (vision) (${{ inputs.provider }})"
       else
-        git commit -m "Recordings update from CI"
+        git commit -m "Recordings update from CI (${{ inputs.provider }})"
       fi

       git fetch origin ${{ github.event.pull_request.head.ref }}
@@ -70,7 +70,8 @@ runs:
     if: ${{ always() }}
     shell: bash
     run: |
-      sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log || true
+      sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log 2>&1 || true
+      sudo docker logs vllm > vllm-${{ inputs.inference-mode }}.log 2>&1 || true

   - name: Upload logs
     if: ${{ always() }}
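The `2>&1` added above matters because `docker logs` replays the container's stdout on its own stdout and the container's stderr on its own stderr; redirecting only stdout silently drops the error output that is usually the interesting part of a failing run. A minimal Python sketch of the same capture, with a hypothetical log file name:

    # Sketch only: mirrors `docker logs ollama > ollama-replay.log 2>&1 || true`.
    import subprocess

    with open("ollama-replay.log", "wb") as f:  # hypothetical file name
        subprocess.run(
            ["docker", "logs", "ollama"],
            stdout=f,
            stderr=subprocess.STDOUT,  # merge stderr into the file, like `2>&1`
            check=False,               # never fail the CI step, like `|| true`
        )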

File 2 of 4

@@ -20,7 +20,6 @@ on:
   schedule:
     # If changing the cron schedule, update the provider in the test-matrix job
     - cron: '0 0 * * *'  # (test latest client) Daily at 12 AM UTC
-    - cron: '1 0 * * 0'  # (test vllm) Weekly on Sunday at 1 AM UTC
   workflow_dispatch:
     inputs:
       test-all-client-versions:
@ -38,28 +37,7 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
discover-tests:
runs-on: ubuntu-latest
outputs:
test-types: ${{ steps.generate-test-types.outputs.test-types }}
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Generate test types
id: generate-test-types
run: |
# Get test directories dynamically, excluding non-test directories
# NOTE: we are excluding post_training since the tests take too long
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
sed 's|tests/integration/||' |
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
sort | jq -R -s -c 'split("\n")[:-1]')
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
run-replay-mode-tests: run-replay-mode-tests:
needs: discover-tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }}
@@ -68,11 +46,14 @@ jobs:
       matrix:
         client-type: [library, server]
         # Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama)
-        provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }}
+        provider: [ollama, vllm]
         # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
         python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
         client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
         run-vision-tests: [true, false]
+        exclude:
+          - provider: vllm
+            run-vision-tests: true

     steps:
       - name: Checkout repository
@@ -87,10 +68,27 @@ jobs:
           run-vision-tests: ${{ matrix.run-vision-tests }}
           inference-mode: 'replay'

+      - name: Generate test types
+        id: generate-test-types
+        run: |
+          # Only run inference tests for vllm as these are more likely to exercise the vllm provider
+          # TODO: Add agent tests for vllm
+          if [ ${{ matrix.provider }} == "vllm" ]; then
+            echo "test-types=[\"inference\"]" >> $GITHUB_OUTPUT
+            exit 0
+          fi
+          # Get test directories dynamically, excluding non-test directories
+          # NOTE: we are excluding post_training since the tests take too long
+          TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
+            sed 's|tests/integration/||' |
+            grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
+            sort | jq -R -s -c 'split("\n")[:-1]')
+          echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
+
       - name: Run tests
         uses: ./.github/actions/run-and-record-tests
         with:
-          test-types: ${{ needs.discover-tests.outputs.test-types }}
+          test-types: ${{ steps.generate-test-types.outputs.test-types }}
           stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
           provider: ${{ matrix.provider }}
           inference-mode: 'replay'
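For readers less familiar with the `jq` incantation, the new `Generate test types` step simply emits the integration-test subdirectories as a compact JSON array, narrowed to `["inference"]` for vllm. A rough Python equivalent (an assumption: it runs from the repository root, and the example output names are illustrative):

    import json
    from pathlib import Path

    # Directories under tests/integration that are not test suites.
    EXCLUDED = {"__pycache__", "fixtures", "test_cases", "recordings", "non_ci", "post_training"}

    def discover_test_types(provider: str, root: str = "tests/integration") -> str:
        if provider == "vllm":
            # Only inference tests exercise the vllm provider today.
            return json.dumps(["inference"])
        dirs = sorted(p.name for p in Path(root).iterdir() if p.is_dir() and p.name not in EXCLUDED)
        return json.dumps(dirs, separators=(",", ":"))  # compact array, like `jq -c`

    # e.g. discover_test_types("ollama") -> '["agents","inference",...]'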

File 3 of 4

@@ -15,12 +15,6 @@ on:
       - '.github/actions/setup-ollama/action.yml'
       - '.github/actions/setup-test-environment/action.yml'
       - '.github/actions/run-and-record-tests/action.yml'
-  workflow_dispatch:
-    inputs:
-      test-provider:
-        description: 'Test against a specific provider'
-        type: string
-        default: 'ollama'

 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
@@ -42,12 +36,6 @@ jobs:
       - name: Generate test types
         id: generate-test-types
         run: |
-          # Get test directories dynamically, excluding non-test directories
-          TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
-            grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" |
-            sort | jq -R -s -c 'split("\n")[:-1]')
-          echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
-
           labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name')
           echo "labels=$labels"
@@ -82,6 +70,10 @@ jobs:
       fail-fast: false
       matrix:
         mode: ${{ fromJSON(needs.discover-tests.outputs.matrix-modes) }}
+        provider: [ollama, vllm]
+        exclude:
+          - mode: vision
+            provider: vllm

     steps:
       - name: Checkout repository
@@ -90,20 +82,33 @@ jobs:
           ref: ${{ github.event.pull_request.head.ref }}
           fetch-depth: 0

+      - name: Generate test types
+        id: generate-test-types
+        run: |
+          if [ ${{ matrix.provider }} == "vllm" ]; then
+            echo "test-types=[\"inference\"]" >> $GITHUB_OUTPUT
+          else
+            # Get test directories dynamically, excluding non-test directories
+            TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
+              grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
+              sort | jq -R -s -c 'split("\n")[:-1]')
+            echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
+          fi
+
       - name: Setup test environment
         uses: ./.github/actions/setup-test-environment
         with:
           python-version: "3.12"  # Use single Python version for recording
           client-version: "latest"
-          provider: ${{ inputs.test-provider || 'ollama' }}
+          provider: ${{ matrix.provider }}
           run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
           inference-mode: 'record'

       - name: Run and record tests
         uses: ./.github/actions/run-and-record-tests
         with:
-          test-types: ${{ needs.discover-tests.outputs.test-types }}
+          test-types: ${{ steps.generate-test-types.outputs.test-types }}
           stack-config: 'server:ci-tests'  # recording must be done with server since more tests are run
-          provider: ${{ inputs.test-provider || 'ollama' }}
+          provider: ${{ matrix.provider }}
           inference-mode: 'record'
           run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
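The `exclude` entry is how GitHub Actions prunes a cross-product matrix. Its effect, modeled in Python (the `modes` values here are an assumption, since `matrix-modes` is computed by the `discover-tests` job):

    from itertools import product

    modes = ["default", "vision"]    # assumed output of discover-tests
    providers = ["ollama", "vllm"]
    excluded = {("vision", "vllm")}  # mirrors the `exclude:` entry above

    jobs = [(m, p) for m, p in product(modes, providers) if (m, p) not in excluded]
    # -> [('default', 'ollama'), ('default', 'vllm'), ('vision', 'ollama')]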

File 4 of 4

@@ -10,12 +10,15 @@ import hashlib
 import json
 import os
 import sqlite3
+import uuid
 from collections.abc import Generator
 from contextlib import contextmanager
 from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast

+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+
 from llama_stack.log import get_logger

 logger = get_logger(__name__, category="testing")
@@ -105,13 +108,29 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
     try:
         # Import the original class and reconstruct the object
         module_path, class_name = data["__type__"].rsplit(".", 1)
+        # Handle generic types (e.g. AsyncPage[Model]) by removing the generic part
+        if "[" in class_name and "]" in class_name:
+            class_name = class_name.split("[")[0]
         module = __import__(module_path, fromlist=[class_name])
         cls = getattr(module, class_name)
         if not hasattr(cls, "model_validate"):
             raise ValueError(f"Pydantic class {cls} does not support model_validate?")
-        return cls.model_validate(data["__data__"])
+
+        # Special handling for AsyncPage - convert nested model dicts to proper model objects
+        validate_data = data["__data__"]
+        if class_name == "AsyncPage" and isinstance(validate_data, dict) and "data" in validate_data:
+            # Convert model dictionaries to objects with attributes so they work with .id access
+            from types import SimpleNamespace
+
+            validate_data = dict(validate_data)
+            validate_data["data"] = [
+                SimpleNamespace(**item) if isinstance(item, dict) else item for item in validate_data["data"]
+            ]
+
+        return cls.model_validate(validate_data)
     except (ImportError, AttributeError, TypeError, ValueError) as e:
         logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
         return data["__data__"]
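Two details in this deserialization change are easy to miss: the recorded type name can be a parameterized generic such as `openai.pagination.AsyncPage[Model]`, which is not importable as written, and replayed page entries come back as plain dicts even though callers use attribute access (`model.id`). A small sketch of both (the model id is illustrative):

    from types import SimpleNamespace

    # Stripping the generic parameter so the class can be imported:
    module_path, class_name = "openai.pagination.AsyncPage[Model]".rsplit(".", 1)
    class_name = class_name.split("[")[0]  # "AsyncPage[Model]" -> "AsyncPage"

    # Making dict entries attribute-addressable:
    raw = {"object": "list", "data": [{"id": "llama3.2:3b", "object": "model"}]}
    page = dict(raw)
    page["data"] = [SimpleNamespace(**e) if isinstance(e, dict) else e for e in page["data"]]
    assert page["data"][0].id == "llama3.2:3b"  # attribute access now works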
@@ -248,6 +267,20 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         recording = _current_storage.find_recording(request_hash)
         if recording:
             response_body = recording["response"]["body"]
+            if (
+                isinstance(response_body, list)
+                and len(response_body) > 0
+                and isinstance(response_body[0], ChatCompletionChunk)
+            ):
+                # We can't replay chat completions with their recorded id because they are stored
+                # in a sqlite database with a unique constraint on the id, so generate a new id
+                # and replace the old one.
+                newid = uuid.uuid4().hex
+                response_body[0].id = "chatcmpl-" + newid
+            elif isinstance(response_body, ChatCompletion):
+                # Same reasoning as above: give the replayed completion a fresh unique id.
+                newid = uuid.uuid4().hex
+                response_body.id = "chatcmpl-" + newid

             if recording["response"].get("is_streaming", False):
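Reduced to its essence, the rewrite just mints a fresh id per replay in the same `chatcmpl-<hex>` shape OpenAI uses, so the sqlite UNIQUE constraint never collides on a second replay. A minimal sketch:

    import uuid

    def fresh_chat_completion_id() -> str:
        # One new id per replayed completion, e.g. 'chatcmpl-3f2c9...'.
        return "chatcmpl-" + uuid.uuid4().hex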
@@ -315,9 +348,11 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
+        "model_list": AsyncModels.list,
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
@ -330,6 +365,55 @@ def patch_inference_clients():
} }
# Create patched methods for OpenAI client # Create patched methods for OpenAI client
def patched_model_list(self, *args, **kwargs):
# The original models.list() returns an AsyncPaginator that can be used with async for
# We need to create a wrapper that preserves this behavior
class PatchedAsyncPaginator:
def __init__(self, original_method, instance, client_type, endpoint, args, kwargs):
self.original_method = original_method
self.instance = instance
self.client_type = client_type
self.endpoint = endpoint
self.args = args
self.kwargs = kwargs
self._result = None
def __await__(self):
# Make it awaitable like the original AsyncPaginator
async def _await():
self._result = await _patched_inference_method(
self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
)
return self._result
return _await().__await__()
def __aiter__(self):
# Make it async iterable like the original AsyncPaginator
return self
async def __anext__(self):
# Get the result if we haven't already
if self._result is None:
self._result = await _patched_inference_method(
self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
)
# Initialize iteration on first call
if not hasattr(self, "_iter_index"):
# Extract the data list from the result
self._data_list = self._result.data
self._iter_index = 0
# Return next item from the list
if self._iter_index >= len(self._data_list):
raise StopAsyncIteration
item = self._data_list[self._iter_index]
self._iter_index += 1
return item
return PatchedAsyncPaginator(_original_methods["model_list"], self, "openai", "/v1/models", args, kwargs)
async def patched_chat_completions_create(self, *args, **kwargs): async def patched_chat_completions_create(self, *args, **kwargs):
return await _patched_inference_method( return await _patched_inference_method(
_original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
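The wrapper has to support both ways `models.list()` is consumed, since the real `AsyncPaginator` is awaitable and async-iterable at once. A usage sketch (assumes a configured `AsyncOpenAI` client, with credentials and base URL set elsewhere):

    import asyncio
    from openai import AsyncOpenAI

    async def main() -> None:
        client = AsyncOpenAI()
        page = await client.models.list()          # awaited form: returns the full page
        print([m.id for m in page.data])
        async for model in client.models.list():   # iterator form: yields items one by one
            print(model.id)

    asyncio.run(main())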
@@ -346,6 +430,7 @@ def patch_inference_clients():
         )

     # Apply OpenAI patches
+    AsyncModels.list = patched_model_list
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
@@ -402,8 +487,10 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Restore OpenAI client methods
+    AsyncModels.list = _original_methods["model_list"]
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]