mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
Merge 1f29aaa2e1
into 61582f327c
This commit is contained in:
commit
6a5dadc395
4 changed files with 134 additions and 43 deletions
|
@ -52,9 +52,9 @@ runs:
|
|||
git add tests/integration/recordings/
|
||||
|
||||
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then
|
||||
git commit -m "Recordings update from CI (vision)"
|
||||
git commit -m "Recordings update from CI (vision) (${{ inputs.provider }})"
|
||||
else
|
||||
git commit -m "Recordings update from CI"
|
||||
git commit -m "Recordings update from CI (${{ inputs.provider }})"
|
||||
fi
|
||||
|
||||
git fetch origin ${{ github.event.pull_request.head.ref }}
|
||||
|
@ -70,7 +70,8 @@ runs:
|
|||
if: ${{ always() }}
|
||||
shell: bash
|
||||
run: |
|
||||
sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log || true
|
||||
sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log 2>&1 || true
|
||||
sudo docker logs vllm > vllm-${{ inputs.inference-mode }}.log 2>&1 || true
|
||||
|
||||
- name: Upload logs
|
||||
if: ${{ always() }}
|
||||
|
|
46
.github/workflows/integration-tests.yml
vendored
46
.github/workflows/integration-tests.yml
vendored
|
@ -20,7 +20,6 @@ on:
|
|||
schedule:
|
||||
# If changing the cron schedule, update the provider in the test-matrix job
|
||||
- cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC
|
||||
- cron: '1 0 * * 0' # (test vllm) Weekly on Sunday at 1 AM UTC
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
test-all-client-versions:
|
||||
|
@ -38,28 +37,7 @@ concurrency:
|
|||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
discover-tests:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
test-types: ${{ steps.generate-test-types.outputs.test-types }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Generate test types
|
||||
id: generate-test-types
|
||||
run: |
|
||||
# Get test directories dynamically, excluding non-test directories
|
||||
# NOTE: we are excluding post_training since the tests take too long
|
||||
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
|
||||
sed 's|tests/integration/||' |
|
||||
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
|
||||
sort | jq -R -s -c 'split("\n")[:-1]')
|
||||
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
|
||||
|
||||
run-replay-mode-tests:
|
||||
needs: discover-tests
|
||||
runs-on: ubuntu-latest
|
||||
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }}
|
||||
|
||||
|
@ -68,11 +46,14 @@ jobs:
|
|||
matrix:
|
||||
client-type: [library, server]
|
||||
# Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama)
|
||||
provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }}
|
||||
provider: [ollama, vllm]
|
||||
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
|
||||
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
|
||||
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
|
||||
run-vision-tests: [true, false]
|
||||
exclude:
|
||||
- provider: vllm
|
||||
run-vision-tests: true
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
@ -87,10 +68,27 @@ jobs:
|
|||
run-vision-tests: ${{ matrix.run-vision-tests }}
|
||||
inference-mode: 'replay'
|
||||
|
||||
- name: Generate test types
|
||||
id: generate-test-types
|
||||
run: |
|
||||
# Only run inference tests for vllm as these are more likely to exercise the vllm provider
|
||||
# TODO: Add agent tests for vllm
|
||||
if [ ${{ matrix.provider }} == "vllm" ]; then
|
||||
echo "test-types=[\"inference\"]" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
fi
|
||||
# Get test directories dynamically, excluding non-test directories
|
||||
# NOTE: we are excluding post_training since the tests take too long
|
||||
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
|
||||
sed 's|tests/integration/||' |
|
||||
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
|
||||
sort | jq -R -s -c 'split("\n")[:-1]')
|
||||
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Run tests
|
||||
uses: ./.github/actions/run-and-record-tests
|
||||
with:
|
||||
test-types: ${{ needs.discover-tests.outputs.test-types }}
|
||||
test-types: ${{ steps.generate-test-types.outputs.test-types }}
|
||||
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
|
||||
provider: ${{ matrix.provider }}
|
||||
inference-mode: 'replay'
|
||||
|
|
35
.github/workflows/record-integration-tests.yml
vendored
35
.github/workflows/record-integration-tests.yml
vendored
|
@ -15,12 +15,6 @@ on:
|
|||
- '.github/actions/setup-ollama/action.yml'
|
||||
- '.github/actions/setup-test-environment/action.yml'
|
||||
- '.github/actions/run-and-record-tests/action.yml'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
test-provider:
|
||||
description: 'Test against a specific provider'
|
||||
type: string
|
||||
default: 'ollama'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
|
||||
|
@ -42,12 +36,6 @@ jobs:
|
|||
- name: Generate test types
|
||||
id: generate-test-types
|
||||
run: |
|
||||
# Get test directories dynamically, excluding non-test directories
|
||||
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
|
||||
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" |
|
||||
sort | jq -R -s -c 'split("\n")[:-1]')
|
||||
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
|
||||
|
||||
labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name')
|
||||
echo "labels=$labels"
|
||||
|
||||
|
@ -82,6 +70,10 @@ jobs:
|
|||
fail-fast: false
|
||||
matrix:
|
||||
mode: ${{ fromJSON(needs.discover-tests.outputs.matrix-modes) }}
|
||||
provider: [ollama, vllm]
|
||||
exclude:
|
||||
- mode: vision
|
||||
provider: vllm
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
@ -90,20 +82,33 @@ jobs:
|
|||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Generate test types
|
||||
id: generate-test-types
|
||||
run: |
|
||||
if [ ${{ matrix.provider }} == "vllm" ]; then
|
||||
echo "test-types=[\"inference\"]" >> $GITHUB_OUTPUT
|
||||
else
|
||||
# Get test directories dynamically, excluding non-test directories
|
||||
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
|
||||
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
|
||||
sort | jq -R -s -c 'split("\n")[:-1]')
|
||||
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Setup test environment
|
||||
uses: ./.github/actions/setup-test-environment
|
||||
with:
|
||||
python-version: "3.12" # Use single Python version for recording
|
||||
client-version: "latest"
|
||||
provider: ${{ inputs.test-provider || 'ollama' }}
|
||||
provider: ${{ matrix.provider }}
|
||||
run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
|
||||
inference-mode: 'record'
|
||||
|
||||
- name: Run and record tests
|
||||
uses: ./.github/actions/run-and-record-tests
|
||||
with:
|
||||
test-types: ${{ needs.discover-tests.outputs.test-types }}
|
||||
test-types: ${{ steps.generate-test-types.outputs.test-types }}
|
||||
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
|
||||
provider: ${{ inputs.test-provider || 'ollama' }}
|
||||
provider: ${{ matrix.provider }}
|
||||
inference-mode: 'record'
|
||||
run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
|
||||
|
|
|
@ -10,12 +10,15 @@ import hashlib
|
|||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="testing")
|
||||
|
@ -105,13 +108,29 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
|
|||
try:
|
||||
# Import the original class and reconstruct the object
|
||||
module_path, class_name = data["__type__"].rsplit(".", 1)
|
||||
|
||||
# Handle generic types (e.g. AsyncPage[Model]) by removing the generic part
|
||||
if "[" in class_name and "]" in class_name:
|
||||
class_name = class_name.split("[")[0]
|
||||
|
||||
module = __import__(module_path, fromlist=[class_name])
|
||||
cls = getattr(module, class_name)
|
||||
|
||||
if not hasattr(cls, "model_validate"):
|
||||
raise ValueError(f"Pydantic class {cls} does not support model_validate?")
|
||||
|
||||
return cls.model_validate(data["__data__"])
|
||||
# Special handling for AsyncPage - convert nested model dicts to proper model objects
|
||||
validate_data = data["__data__"]
|
||||
if class_name == "AsyncPage" and isinstance(validate_data, dict) and "data" in validate_data:
|
||||
# Convert model dictionaries to objects with attributes so they work with .id access
|
||||
from types import SimpleNamespace
|
||||
|
||||
validate_data = dict(validate_data)
|
||||
validate_data["data"] = [
|
||||
SimpleNamespace(**item) if isinstance(item, dict) else item for item in validate_data["data"]
|
||||
]
|
||||
|
||||
return cls.model_validate(validate_data)
|
||||
except (ImportError, AttributeError, TypeError, ValueError) as e:
|
||||
logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
|
||||
return data["__data__"]
|
||||
|
@ -248,6 +267,20 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
recording = _current_storage.find_recording(request_hash)
|
||||
if recording:
|
||||
response_body = recording["response"]["body"]
|
||||
if (
|
||||
isinstance(response_body, list)
|
||||
and len(response_body) > 0
|
||||
and isinstance(response_body[0], ChatCompletionChunk)
|
||||
):
|
||||
# We can't replay chatcompletions with the same id and we store them in a sqlite database with a unique constraint on the id.
|
||||
# So we generate a new id and replace the old one.
|
||||
newid = uuid.uuid4().hex
|
||||
response_body[0].id = "chatcmpl-" + newid
|
||||
elif isinstance(response_body, ChatCompletion):
|
||||
# We can't replay chatcompletions with the same id and we store them in a sqlite database with a unique constraint on the id.
|
||||
# So we generate a new id and replace the old one.
|
||||
newid = uuid.uuid4().hex
|
||||
response_body.id = "chatcmpl-" + newid
|
||||
|
||||
if recording["response"].get("is_streaming", False):
|
||||
|
||||
|
@ -315,9 +348,11 @@ def patch_inference_clients():
|
|||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||
from openai.resources.completions import AsyncCompletions
|
||||
from openai.resources.embeddings import AsyncEmbeddings
|
||||
from openai.resources.models import AsyncModels
|
||||
|
||||
# Store original methods for both OpenAI and Ollama clients
|
||||
_original_methods = {
|
||||
"model_list": AsyncModels.list,
|
||||
"chat_completions_create": AsyncChatCompletions.create,
|
||||
"completions_create": AsyncCompletions.create,
|
||||
"embeddings_create": AsyncEmbeddings.create,
|
||||
|
@ -330,6 +365,55 @@ def patch_inference_clients():
|
|||
}
|
||||
|
||||
# Create patched methods for OpenAI client
|
||||
def patched_model_list(self, *args, **kwargs):
|
||||
# The original models.list() returns an AsyncPaginator that can be used with async for
|
||||
# We need to create a wrapper that preserves this behavior
|
||||
class PatchedAsyncPaginator:
|
||||
def __init__(self, original_method, instance, client_type, endpoint, args, kwargs):
|
||||
self.original_method = original_method
|
||||
self.instance = instance
|
||||
self.client_type = client_type
|
||||
self.endpoint = endpoint
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self._result = None
|
||||
|
||||
def __await__(self):
|
||||
# Make it awaitable like the original AsyncPaginator
|
||||
async def _await():
|
||||
self._result = await _patched_inference_method(
|
||||
self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
|
||||
)
|
||||
return self._result
|
||||
|
||||
return _await().__await__()
|
||||
|
||||
def __aiter__(self):
|
||||
# Make it async iterable like the original AsyncPaginator
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
# Get the result if we haven't already
|
||||
if self._result is None:
|
||||
self._result = await _patched_inference_method(
|
||||
self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
|
||||
)
|
||||
|
||||
# Initialize iteration on first call
|
||||
if not hasattr(self, "_iter_index"):
|
||||
# Extract the data list from the result
|
||||
self._data_list = self._result.data
|
||||
self._iter_index = 0
|
||||
|
||||
# Return next item from the list
|
||||
if self._iter_index >= len(self._data_list):
|
||||
raise StopAsyncIteration
|
||||
item = self._data_list[self._iter_index]
|
||||
self._iter_index += 1
|
||||
return item
|
||||
|
||||
return PatchedAsyncPaginator(_original_methods["model_list"], self, "openai", "/v1/models", args, kwargs)
|
||||
|
||||
async def patched_chat_completions_create(self, *args, **kwargs):
|
||||
return await _patched_inference_method(
|
||||
_original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
|
||||
|
@ -346,6 +430,7 @@ def patch_inference_clients():
|
|||
)
|
||||
|
||||
# Apply OpenAI patches
|
||||
AsyncModels.list = patched_model_list
|
||||
AsyncChatCompletions.create = patched_chat_completions_create
|
||||
AsyncCompletions.create = patched_completions_create
|
||||
AsyncEmbeddings.create = patched_embeddings_create
|
||||
|
@ -402,8 +487,10 @@ def unpatch_inference_clients():
|
|||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||
from openai.resources.completions import AsyncCompletions
|
||||
from openai.resources.embeddings import AsyncEmbeddings
|
||||
from openai.resources.models import AsyncModels
|
||||
|
||||
# Restore OpenAI client methods
|
||||
AsyncModels.list = _original_methods["model_list"]
|
||||
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
|
||||
AsyncCompletions.create = _original_methods["completions_create"]
|
||||
AsyncEmbeddings.create = _original_methods["embeddings_create"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue