mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 05:38:38 +00:00
Merge branch 'main' into prompt-api
This commit is contained in:
commit
5d610de5db
34 changed files with 679 additions and 240 deletions
|
@ -77,7 +77,7 @@ You must be careful when re-recording. CI workflows assume a specific setup for
|
|||
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference"
|
||||
|
||||
# Record with vision tests enabled
|
||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests
|
||||
./scripts/github/schedule-record-workflow.sh --test-suite vision
|
||||
|
||||
# Record with specific provider
|
||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm
|
||||
|
|
|
@ -42,6 +42,27 @@ Model parameters can be influenced by the following options:
|
|||
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
|
||||
if no model is specified.
|
||||
|
||||
### Suites (fast selection + sane defaults)
|
||||
|
||||
- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly).
|
||||
- Available suites:
|
||||
- `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model.
|
||||
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`.
|
||||
- Explicit flags always win. For example, `--suite=responses --text-model=<X>` overrides the suite’s text model.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
# Fast responses run with defaults
|
||||
pytest -s -v tests/integration --stack-config=server:starter --suite=responses
|
||||
|
||||
# Fast single-file vision run with defaults
|
||||
pytest -s -v tests/integration --stack-config=server:starter --suite=vision
|
||||
|
||||
# Combine suites and override a default
|
||||
pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Testing against a Server
|
||||
|
|
|
@ -268,3 +268,58 @@ class TestBatchesIntegration:
|
|||
|
||||
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
|
||||
|
||||
def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id):
|
||||
"""Run an end-to-end batch with a single successful text completion request."""
|
||||
request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20}
|
||||
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "success-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": request_body,
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/completions",
|
||||
completion_window="24h",
|
||||
metadata={"test": "e2e_completions_success"},
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(
|
||||
batch.id,
|
||||
max_wait_time=3 * 60,
|
||||
expected_statuses={"completed"},
|
||||
timeout_action="skip",
|
||||
)
|
||||
|
||||
assert final_batch.status == "completed"
|
||||
assert final_batch.request_counts is not None
|
||||
assert final_batch.request_counts.total == 1
|
||||
assert final_batch.request_counts.completed == 1
|
||||
assert final_batch.output_file_id is not None
|
||||
|
||||
output_content = openai_client.files.content(final_batch.output_file_id)
|
||||
if isinstance(output_content, str):
|
||||
output_text = output_content
|
||||
else:
|
||||
output_text = output_content.content.decode("utf-8")
|
||||
|
||||
output_lines = output_text.strip().split("\n")
|
||||
assert len(output_lines) == 1
|
||||
|
||||
result = json.loads(output_lines[0])
|
||||
assert result["custom_id"] == "success-1"
|
||||
assert "response" in result
|
||||
assert result["response"]["status_code"] == 200
|
||||
|
||||
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
|
||||
assert deleted_output_file.deleted
|
||||
|
||||
if final_batch.error_file_id is not None:
|
||||
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||
assert deleted_error_file.deleted
|
||||
|
|
|
@ -6,15 +6,17 @@
|
|||
import inspect
|
||||
import itertools
|
||||
import os
|
||||
import platform
|
||||
import textwrap
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .suites import SUITE_DEFINITIONS
|
||||
|
||||
logger = get_logger(__name__, category="tests")
|
||||
|
||||
|
||||
|
@ -61,9 +63,22 @@ def pytest_configure(config):
|
|||
key, value = env_var.split("=", 1)
|
||||
os.environ[key] = value
|
||||
|
||||
if platform.system() == "Darwin": # Darwin is the system name for macOS
|
||||
os.environ["DISABLE_CODE_SANDBOX"] = "1"
|
||||
logger.info("Setting DISABLE_CODE_SANDBOX=1 for macOS")
|
||||
suites_raw = config.getoption("--suite")
|
||||
suites: list[str] = []
|
||||
if suites_raw:
|
||||
suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
|
||||
unknown = [p for p in suites if p not in SUITE_DEFINITIONS]
|
||||
if unknown:
|
||||
raise pytest.UsageError(
|
||||
f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
|
||||
)
|
||||
for suite in suites:
|
||||
suite_def = SUITE_DEFINITIONS.get(suite, {})
|
||||
defaults: dict = suite_def.get("defaults", {})
|
||||
for dest, value in defaults.items():
|
||||
current = getattr(config.option, dest, None)
|
||||
if not current:
|
||||
setattr(config.option, dest, value)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
@ -105,16 +120,21 @@ def pytest_addoption(parser):
|
|||
default=384,
|
||||
help="Output dimensionality of the embedding model to use for testing. Default: 384",
|
||||
)
|
||||
parser.addoption(
|
||||
"--record-responses",
|
||||
action="store_true",
|
||||
help="Record new API responses instead of using cached ones.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--report",
|
||||
help="Path where the test report should be written, e.g. --report=/path/to/report.md",
|
||||
)
|
||||
|
||||
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
|
||||
suite_help = (
|
||||
"Comma-separated integration test suites to narrow collection and prefill defaults. "
|
||||
"Available: "
|
||||
f"{available_suites}. "
|
||||
"Explicit CLI flags (e.g., --text-model) override suite defaults. "
|
||||
"Examples: --suite=responses or --suite=responses,vision."
|
||||
)
|
||||
parser.addoption("--suite", help=suite_help)
|
||||
|
||||
|
||||
MODEL_SHORT_IDS = {
|
||||
"meta-llama/Llama-3.2-3B-Instruct": "3B",
|
||||
|
@ -197,3 +217,40 @@ def pytest_generate_tests(metafunc):
|
|||
|
||||
|
||||
pytest_plugins = ["tests.integration.fixtures.common"]
|
||||
|
||||
|
||||
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
|
||||
"""Skip collecting paths outside the selected suite roots for speed."""
|
||||
suites_raw = config.getoption("--suite")
|
||||
if not suites_raw:
|
||||
return False
|
||||
|
||||
names = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
|
||||
roots: list[str] = []
|
||||
for name in names:
|
||||
suite_def = SUITE_DEFINITIONS.get(name)
|
||||
if suite_def:
|
||||
roots.extend(suite_def.get("roots", []))
|
||||
if not roots:
|
||||
return False
|
||||
|
||||
p = Path(str(path)).resolve()
|
||||
|
||||
# Only constrain within tests/integration to avoid ignoring unrelated tests
|
||||
integration_root = (Path(str(config.rootpath)) / "tests" / "integration").resolve()
|
||||
if not p.is_relative_to(integration_root):
|
||||
return False
|
||||
|
||||
for r in roots:
|
||||
rp = (Path(str(config.rootpath)) / r).resolve()
|
||||
if rp.is_file():
|
||||
# Allow the exact file and any ancestor directories so pytest can walk into it.
|
||||
if p == rp:
|
||||
return False
|
||||
if p.is_dir() and rp.is_relative_to(p):
|
||||
return False
|
||||
else:
|
||||
# Allow anything inside an allowed directory
|
||||
if p.is_relative_to(rp):
|
||||
return False
|
||||
return True
|
||||
|
|
|
@ -58,6 +58,15 @@ def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
|
|||
pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
|
||||
|
||||
|
||||
def skip_if_doesnt_support_n(client_with_models, model_id):
|
||||
provider = provider_from_model(client_with_models, model_id)
|
||||
if provider.provider_type in (
|
||||
"remote::sambanova",
|
||||
"remote::ollama",
|
||||
):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
|
||||
|
||||
|
||||
def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
|
||||
provider = provider_from_model(client_with_models, model_id)
|
||||
if provider.provider_type in (
|
||||
|
@ -262,10 +271,7 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
|
|||
)
|
||||
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
|
||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||
|
||||
provider = provider_from_model(client_with_models, text_model_id)
|
||||
if provider.provider_type == "remote::ollama":
|
||||
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
|
||||
skip_if_doesnt_support_n(client_with_models, text_model_id)
|
||||
|
||||
tc = TestCase(test_case)
|
||||
question = tc["question"]
|
||||
|
|
42
tests/integration/recordings/responses/41e27b9b5d09.json
Normal file
42
tests/integration/recordings/responses/41e27b9b5d09.json
Normal file
|
@ -0,0 +1,42 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://0.0.0.0:11434/v1/v1/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"prompt": "Say completions",
|
||||
"max_tokens": 20
|
||||
},
|
||||
"endpoint": "/v1/completions",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.completion.Completion",
|
||||
"__data__": {
|
||||
"id": "cmpl-271",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "You want me to respond with a completion, but you didn't specify what I should complete. Could"
|
||||
}
|
||||
],
|
||||
"created": 1756846620,
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 28,
|
||||
"total_tokens": 48,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
Before Width: | Height: | Size: 108 KiB After Width: | Height: | Size: 108 KiB |
Before Width: | Height: | Size: 148 KiB After Width: | Height: | Size: 148 KiB |
Before Width: | Height: | Size: 139 KiB After Width: | Height: | Size: 139 KiB |
53
tests/integration/suites.py
Normal file
53
tests/integration/suites.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Central definition of integration test suites. You can use these suites by passing --suite=name to pytest.
|
||||
# For example:
|
||||
#
|
||||
# ```bash
|
||||
# pytest tests/integration/ --suite=vision
|
||||
# ```
|
||||
#
|
||||
# Each suite can:
|
||||
# - restrict collection to specific roots (dirs or files)
|
||||
# - provide default CLI option values (e.g. text_model, embedding_model, etc.)
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
this_dir = Path(__file__).parent
|
||||
default_roots = [
|
||||
str(p)
|
||||
for p in this_dir.glob("*")
|
||||
if p.is_dir()
|
||||
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
|
||||
]
|
||||
|
||||
SUITE_DEFINITIONS: dict[str, dict] = {
|
||||
"base": {
|
||||
"description": "Base suite that includes most tests but runs them with a text Ollama model",
|
||||
"roots": default_roots,
|
||||
"defaults": {
|
||||
"text_model": "ollama/llama3.2:3b-instruct-fp16",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
},
|
||||
"responses": {
|
||||
"description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model",
|
||||
"roots": ["tests/integration/responses"],
|
||||
"defaults": {
|
||||
"text_model": "openai/gpt-4o",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
},
|
||||
"vision": {
|
||||
"description": "Suite that includes only the vision tests",
|
||||
"roots": ["tests/integration/inference/test_vision_inference.py"],
|
||||
"defaults": {
|
||||
"vision_model": "ollama/llama3.2-vision:11b",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
},
|
||||
}
|
|
@ -47,34 +47,45 @@ def client_with_empty_registry(client_with_models):
|
|||
|
||||
|
||||
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||
# Register a memory bank first
|
||||
vector_db_id = "test_vector_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_name = "test_vector_db"
|
||||
register_response = client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_name,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
actual_vector_db_id = register_response.identifier
|
||||
|
||||
# Retrieve the memory bank and validate its properties
|
||||
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id)
|
||||
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=actual_vector_db_id)
|
||||
assert response is not None
|
||||
assert response.identifier == vector_db_id
|
||||
assert response.identifier == actual_vector_db_id
|
||||
assert response.embedding_model == embedding_model_id
|
||||
assert response.provider_resource_id == vector_db_id
|
||||
assert response.identifier.startswith("vs_")
|
||||
|
||||
|
||||
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||
vector_db_id = "test_vector_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_name = "test_vector_db"
|
||||
response = client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_name,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
assert vector_dbs_after_register == [vector_db_id]
|
||||
actual_vector_db_id = response.identifier
|
||||
assert actual_vector_db_id.startswith("vs_")
|
||||
assert actual_vector_db_id != vector_db_name
|
||||
|
||||
client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
assert vector_dbs_after_register == [actual_vector_db_id]
|
||||
|
||||
vector_stores = client_with_empty_registry.vector_stores.list()
|
||||
assert len(vector_stores.data) == 1
|
||||
vector_store = vector_stores.data[0]
|
||||
assert vector_store.id == actual_vector_db_id
|
||||
assert vector_store.name == vector_db_name
|
||||
|
||||
client_with_empty_registry.vector_dbs.unregister(vector_db_id=actual_vector_db_id)
|
||||
|
||||
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
assert len(vector_dbs) == 0
|
||||
|
@ -91,20 +102,22 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
|
|||
],
|
||||
)
|
||||
def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
|
||||
vector_db_id = "test_vector_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_name = "test_vector_db"
|
||||
register_response = client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_name,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
actual_vector_db_id = register_response.identifier
|
||||
|
||||
client_with_empty_registry.vector_io.insert(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunks=sample_chunks,
|
||||
)
|
||||
|
||||
response = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
query="What is the capital of France?",
|
||||
)
|
||||
assert response is not None
|
||||
|
@ -113,7 +126,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
|
|||
|
||||
query, expected_doc_id = test_case
|
||||
response = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
query=query,
|
||||
)
|
||||
assert response is not None
|
||||
|
@ -128,13 +141,15 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
|
|||
"remote::qdrant": {"score_threshold": -1.0},
|
||||
"inline::qdrant": {"score_threshold": -1.0},
|
||||
}
|
||||
vector_db_id = "test_precomputed_embeddings_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_name = "test_precomputed_embeddings_db"
|
||||
register_response = client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_name,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
actual_vector_db_id = register_response.identifier
|
||||
|
||||
chunks_with_embeddings = [
|
||||
Chunk(
|
||||
content="This is a test chunk with precomputed embedding.",
|
||||
|
@ -144,13 +159,13 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
|
|||
]
|
||||
|
||||
client_with_empty_registry.vector_io.insert(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunks=chunks_with_embeddings,
|
||||
)
|
||||
|
||||
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
||||
response = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
query="precomputed embedding test",
|
||||
params=vector_io_provider_params_dict.get(provider, None),
|
||||
)
|
||||
|
@ -173,13 +188,15 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
|||
"remote::qdrant": {"score_threshold": 0.0},
|
||||
"inline::qdrant": {"score_threshold": 0.0},
|
||||
}
|
||||
vector_db_id = "test_precomputed_embeddings_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_name = "test_precomputed_embeddings_db"
|
||||
register_response = client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_name,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
actual_vector_db_id = register_response.identifier
|
||||
|
||||
chunks_with_embeddings = [
|
||||
Chunk(
|
||||
content="duplicate",
|
||||
|
@ -189,13 +206,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
|||
]
|
||||
|
||||
client_with_empty_registry.vector_io.insert(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunks=chunks_with_embeddings,
|
||||
)
|
||||
|
||||
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
||||
response = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
query="duplicate",
|
||||
params=vector_io_provider_params_dict.get(provider, None),
|
||||
)
|
||||
|
|
|
@ -146,6 +146,20 @@ class VectorDBImpl(Impl):
|
|||
async def unregister_vector_db(self, vector_db_id: str):
|
||||
return vector_db_id
|
||||
|
||||
async def openai_create_vector_store(self, **kwargs):
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
|
||||
|
||||
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
|
||||
return VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=kwargs.get("name", vector_store_id),
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
|
||||
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
|
@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == 2
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
assert "test-vectordb" in vector_db_ids
|
||||
assert "test-vectordb-2" in vector_db_ids
|
||||
assert vdb1.identifier in vector_db_ids
|
||||
assert vdb2.identifier in vector_db_ids
|
||||
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
||||
# Verify they have UUID-based identifiers
|
||||
assert vdb1.identifier.startswith("vs_")
|
||||
assert vdb2.identifier.startswith("vs_")
|
||||
|
||||
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
|
||||
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
# Unit tests for the routing tables vector_dbs
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI
|
|||
class VectorDBImpl(Impl):
|
||||
def __init__(self):
|
||||
super().__init__(Api.vector_io)
|
||||
self.vector_stores = {}
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB):
|
||||
return vector_db
|
||||
|
@ -114,8 +116,35 @@ class VectorDBImpl(Impl):
|
|||
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
||||
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name=None,
|
||||
embedding_model=None,
|
||||
embedding_dimension=None,
|
||||
provider_id=None,
|
||||
provider_vector_db_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
|
||||
vector_store = VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=name or vector_store_id,
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
self.vector_stores[vector_store_id] = vector_store
|
||||
return vector_store
|
||||
|
||||
async def openai_list_vector_stores(self, **kwargs):
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
|
||||
|
||||
return VectorStoreListResponse(
|
||||
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
|
||||
)
|
||||
|
||||
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
n = 10
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
|
@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
|
||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == 2
|
||||
assert len(vector_dbs.data) == len(vdb_dict)
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
assert "test-vectordb" in vector_db_ids
|
||||
assert "test-vectordb-2" in vector_db_ids
|
||||
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
||||
for k in vdb_dict:
|
||||
assert vdb_dict[k].identifier in vector_db_ids
|
||||
for k in vdb_dict:
|
||||
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
|
||||
n = 10
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
vector_store_ids = {v.id for v in vector_stores.data}
|
||||
|
||||
assert vector_db_ids == vector_store_ids, (
|
||||
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
|
||||
)
|
||||
|
||||
for vector_store in vector_stores.data:
|
||||
vector_db = await table.get_vector_db(vector_store.id)
|
||||
assert vector_store.name == vector_db.vector_db_name, (
|
||||
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
|
||||
)
|
||||
|
||||
for vector_db_id in vector_db_ids:
|
||||
await table.unregister_vector_db(vector_db_id)
|
||||
|
||||
assert len((await table.list_vector_dbs()).data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
user_provided_id = "my-custom-vector-db"
|
||||
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
assert len(vector_stores.data) == 1
|
||||
|
||||
vector_store = vector_stores.data[0]
|
||||
|
||||
assert vector_store.name == user_provided_id
|
||||
|
||||
assert vector_store.id.startswith("vs_")
|
||||
assert vector_store.id != user_provided_id
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 1
|
||||
assert vector_dbs.data[0].identifier == vector_store.id
|
||||
|
||||
await table.unregister_vector_db(vector_store.id)
|
||||
|
||||
|
||||
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
||||
|
@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr
|
|||
|
||||
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||
authorized_table = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
# Authorized reader
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
|
@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis
|
|||
)
|
||||
|
||||
with request_provider_data_context({}, admin_user):
|
||||
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||
vector_db_id = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
read_methods = [
|
||||
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
|
||||
|
|
|
@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
|
|||
* test_validate_input_url_mismatch (negative)
|
||||
* test_validate_input_multiple_errors_per_request (negative)
|
||||
* test_validate_input_invalid_request_format (negative)
|
||||
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
||||
* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
|
||||
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
|
||||
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||
|
||||
The tests use temporary SQLite databases for isolation and mock external
|
||||
|
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
|
|||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/completions",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
|
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
|
|||
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
|
||||
"""Test _validate_input when file contains request with missing required parameters."""
|
||||
async def test_validate_input_missing_parameters_chat_completions(
|
||||
self, provider, param_name, param_path, error_code, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
|
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
|
|||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,error_code,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
|
||||
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
|
||||
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
|
||||
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
|
||||
("model", "body.model", "invalid_request", "Model parameter is required"),
|
||||
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters_completions(
|
||||
self, provider, param_name, param_path, error_code, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with missing required parameters for text completions."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {"model": "test-model", "prompt": "Hello"},
|
||||
}
|
||||
|
||||
# Remove the specific parameter being tested
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
del base_request[top_level][nested_param]
|
||||
else:
|
||||
del base_request[param_name]
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/completions",
|
||||
input_file_id=f"missing_{param_name}_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == error_code
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_validate_input_url_mismatch(self, provider):
|
||||
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue