Merge branch 'main' into eval_api_final

This commit is contained in:
Xi Yan 2025-03-26 12:29:45 -07:00
commit bc0cd07008
79 changed files with 3257 additions and 2358 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
POST_TRAINING_PROVIDER_TYPES = ["remote::nvidia"]
@pytest.mark.integration
@pytest.fixture(scope="session")
def post_training_provider_available(llama_stack_client):
providers = llama_stack_client.providers.list()
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
return len(post_training_providers) > 0
@pytest.mark.integration
def test_post_training_provider_registration(llama_stack_client, post_training_provider_available):
"""Check if post_training is in the api list.
This is a sanity check to ensure the provider is registered."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
providers = llama_stack_client.providers.list()
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
assert len(post_training_providers) > 0
@pytest.mark.integration
def test_get_training_jobs(llama_stack_client, post_training_provider_available):
"""Test listing all training jobs."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
jobs = llama_stack_client.post_training.get_training_jobs()
assert isinstance(jobs, dict)
assert "data" in jobs
assert isinstance(jobs["data"], list)
@pytest.mark.integration
def test_get_training_job_status(llama_stack_client, post_training_provider_available):
"""Test getting status of a specific training job."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
jobs = llama_stack_client.post_training.get_training_jobs()
if not jobs["data"]:
pytest.skip("No training jobs available to check status")
job_uuid = jobs["data"][0]["job_uuid"]
job_status = llama_stack_client.post_training.get_training_job_status(job_uuid=job_uuid)
assert job_status is not None
assert "job_uuid" in job_status
assert "status" in job_status
assert job_status["job_uuid"] == job_uuid

View file

@ -99,6 +99,33 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
assert len(content_str) > 10
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:stop_sequence",
],
)
def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
# This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
if inference_provider_type != "remote::vllm":
pytest.xfail(f"{inference_provider_type} doesn't support 'stop' parameter yet")
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content=tc["content"],
stream=True,
model_id=text_model_id,
sampling_params={
"max_tokens": 50,
"stop": ["1963"],
},
)
streamed_content = [chunk.delta for chunk in response]
content_str = "".join(streamed_content).lower().strip()
assert "1963" not in content_str
@pytest.mark.parametrize(
"test_case",
[

View file

@ -10,6 +10,11 @@
"expected": "1963"
}
},
"stop_sequence": {
"data": {
"content": "Return the exact same sentence and don't add additional words): Michael Jordan was born in the year of 1963"
}
},
"streaming": {
"data": {
"content": "Roses are red,"

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()
@pytest.fixture(scope="session", autouse=True)
def patch_aiohttp_session():
with patch("aiohttp.ClientSession", return_value=mock_session):
yield
@pytest.fixture
def event_loop():
"""Create and provide a new event loop for each test."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()
@pytest.fixture
def run_async():
"""Fixture to run async functions in tests."""
def _run_async(coro):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coro)
finally:
loop.close()
return _run_async

View file

@ -0,0 +1,271 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
TrainingConfigEfficiencyConfig,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
)
class TestNvidiaParameters(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
self.mock_make_request.return_value = {
"id": "job-123",
"status": "created",
"created_at": "2025-03-04T13:07:47.543605",
"updated_at": "2025-03-04T13:07:47.543605",
}
def tearDown(self):
self.make_request_patcher.stop()
def _assert_request_params(self, expected_json):
"""Helper method to verify parameters in the request JSON."""
call_args = self.mock_make_request.call_args
actual_json = call_args[1]["json"]
for key, value in expected_json.items():
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert actual_json[key][nested_key] == nested_value
else:
assert actual_json[key] == value
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def test_customizer_parameters_passed(self):
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
custom_adapter_dim = 32 # Different from default of 8
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=custom_adapter_dim,
adapter_dropout=0.2,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
training_config = TrainingConfig(
n_epochs=3,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self._assert_request_params(
{
"hyperparameters": {
"lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
"epochs": 3,
"learning_rate": 0.0002,
"batch_size": 16,
}
}
)
def test_required_parameters_passed(self):
"""Test scenario 2: When required parameters are passed."""
required_model = "meta-llama/Llama-3.1-8B-Instruct"
required_dataset_id = "required-dataset"
required_job_uuid = "required-job"
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(
dataset_id=required_dataset_id, # Required parameter
batch_size=8,
)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"rank",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self.mock_make_request.assert_called_once()
call_args = self.mock_make_request.call_args
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
def test_unsupported_parameters_warning(self):
"""Test that warnings are raised for unsupported parameters."""
data_config = TrainingConfigDataConfig(
dataset_id="test-dataset",
batch_size=8,
# Unsupported parameters
shuffle=True,
data_format="instruct",
validation_dataset_id="val-dataset",
)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
weight_decay=0.01,
# Unsupported parameters
optimizer_type="adam",
num_warmup_steps=100,
)
efficiency_config = TrainingConfigEfficiencyConfig(
enable_activation_checkpointing=True # Unsupported parameter
)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
# Unsupported parameters
efficiency_config=efficiency_config,
max_steps_per_epoch=1000,
gradient_accumulation_steps=4,
max_validation_steps=100,
dtype="bf16",
)
# Capture warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="test-dir", # Unsupported parameter
algorithm_config=LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
),
training_config=training_config,
logger_config={"test": "value"}, # Unsupported parameter
hyperparam_search_config={"test": "value"}, # Unsupported parameter
)
)
assert len(w) >= 4
warning_texts = [str(warning.message) for warning in w]
fields = [
"checkpoint_dir",
"hyperparam_search_config",
"logger_config",
"TrainingConfig",
"DataConfig",
"OptimizerConfig",
"max_steps_per_epoch",
"gradient_accumulation_steps",
"max_validation_steps",
"dtype",
# required unsupported parameters
"rank",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,295 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
from unittest.mock import patch
import warnings
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
NvidiaPostTrainingJobStatusResponse,
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingJob,
)
class TestNvidiaPostTraining(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
def tearDown(self):
self.make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
def test_supervised_fine_tune(self):
"""Test the supervised fine-tuning API call."""
self.mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "sample-basic-test",
"version_id": "main",
"version_tags": [],
},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "created",
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}},
}
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
training_job = self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
# check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.1-8b-instruct",
"dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
"epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001,
"lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
},
},
)
def test_supervised_fine_tune_with_qat(self):
algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with self.assertRaises(NotImplementedError):
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
)
def test_get_training_job_status(self):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": "completed",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == "completed"
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
)
def test_get_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
jobs = self.run_async(self.adapter.get_training_jobs())
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert len(jobs.data) == 1
job = jobs.data[0]
assert job.job_uuid == job_id
assert job.status.value == "completed"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"GET",
"/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
assert result is None
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
f"/v1/customization/jobs/{job_id}/cancel",
expected_params={"job_id": job_id},
)
if __name__ == "__main__":
unittest.main()

View file

@ -14,17 +14,10 @@ from llama_stack.distribution.utils.dynamic import instantiate_class_type
class TestProviderConfigurations:
"""Test suite for testing provider configurations across all API types."""
def test_all_api_providers_exist(self):
provider_registry = get_provider_registry()
for api in providable_apis():
providers = provider_registry.get(api, {})
assert providers, f"No providers found for API type: {api}"
@pytest.mark.parametrize("api", providable_apis())
def test_api_providers(self, api):
provider_registry = get_provider_registry()
providers = provider_registry.get(api, {})
assert providers, f"No providers found for API type: {api}"
failures = []
for provider_type, provider_spec in providers.items():

View file

@ -5,17 +5,16 @@
# the root directory of this source tree.
import asyncio
import sqlite3
import numpy as np
import pytest
import pytest_asyncio
import sqlite_vec
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
SQLiteVecIndex,
SQLiteVecVectorIOAdapter,
_create_sqlite_connection,
generate_chunk_id,
)
@ -36,29 +35,25 @@ def loop():
return asyncio.new_event_loop()
@pytest.fixture(scope="session", autouse=True)
def sqlite_connection(loop):
conn = sqlite3.connect(":memory:")
try:
conn.enable_load_extension(True)
sqlite_vec.load(conn)
yield conn
finally:
conn.close()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection, embedding_dimension):
return await SQLiteVecIndex.create(dimension=embedding_dimension, connection=sqlite_connection, bank_id="test_bank")
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
yield index
await index.delete()
@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
cur = sqlite_vec_index.connection.cursor()
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
cur.close()
connection.close()
@pytest.mark.asyncio
@ -79,13 +74,14 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
cur = sqlite_vec_index.connection.cursor()
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
connection.close()
# Ensure all chunk IDs are unique
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"