Merge branch 'meta-llama:main' into add-unit-tests-and-fix-cli
This commit is contained in commit 696bcf6051.
459 changed files with 39114 additions and 10751 deletions
175  tests/unit/providers/agents/test_persistence_access_control.py  (new file)
@@ -0,0 +1,175 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import shutil
import tempfile
import uuid
from datetime import datetime
from unittest.mock import patch

import pytest

from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl


@pytest.fixture
async def test_setup():
    temp_dir = tempfile.mkdtemp()
    db_path = os.path.join(temp_dir, "test_persistence_access_control.db")
    kvstore_config = SqliteKVStoreConfig(db_path=db_path)
    kvstore = SqliteKVStoreImpl(kvstore_config)
    await kvstore.initialize()
    agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=kvstore)
    yield agent_persistence
    shutil.rmtree(temp_dir)


@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
    agent_persistence = test_setup

    # Set creator's attributes for the session
    creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
    mock_get_auth_attributes.return_value = creator_attributes

    # Create a session
    session_id = await agent_persistence.create_session("Test Session")

    # Get the session and verify access attributes were set
    session_info = await agent_persistence.get_session_info(session_id)
    assert session_info is not None
    assert session_info.access_attributes is not None
    assert session_info.access_attributes.roles == ["researcher"]
    assert session_info.access_attributes.teams == ["ai-team"]


@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_access_control(mock_get_auth_attributes, test_setup):
    agent_persistence = test_setup

    # Create a session with specific access attributes
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
    )

    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    # User with matching attributes can access
    mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
    retrieved_session = await agent_persistence.get_session_info(session_id)
    assert retrieved_session is not None
    assert retrieved_session.session_id == session_id

    # User without matching attributes cannot access
    mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
    retrieved_session = await agent_persistence.get_session_info(session_id)
    assert retrieved_session is None


@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
    agent_persistence = test_setup

    # Create a session with restricted access
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        access_attributes=AccessAttributes(roles=["admin"]),
    )

    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    # Create a turn for this session
    turn_id = str(uuid.uuid4())
    turn = Turn(
        session_id=session_id,
        turn_id=turn_id,
        steps=[],
        started_at=datetime.now(),
        input_messages=[],
        output_message=CompletionMessage(
            content="Hello",
            stop_reason=StopReason.end_of_turn,
        ),
    )

    # Admin can add turn
    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
    await agent_persistence.add_turn_to_session(session_id, turn)

    # Admin can get turn
    retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
    assert retrieved_turn is not None
    assert retrieved_turn.turn_id == turn_id

    # Regular user cannot get turn
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    with pytest.raises(ValueError):
        await agent_persistence.get_session_turn(session_id, turn_id)

    # Regular user cannot get turns for session
    with pytest.raises(ValueError):
        await agent_persistence.get_session_turns(session_id)


@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
    agent_persistence = test_setup

    # Create a session with restricted access
    session_id = str(uuid.uuid4())
    session_info = AgentSessionInfo(
        session_id=session_id,
        session_name="Restricted Session",
        started_at=datetime.now(),
        access_attributes=AccessAttributes(roles=["admin"]),
    )

    await agent_persistence.kvstore.set(
        key=f"session:{agent_persistence.agent_id}:{session_id}",
        value=session_info.model_dump_json(),
    )

    turn_id = str(uuid.uuid4())

    # Admin user can set inference iterations
    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
    await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)

    # Admin user can get inference iterations
    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
    assert infer_iters == 5

    # Regular user cannot get inference iterations
    mock_get_auth_attributes.return_value = {"roles": ["user"]}
    infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
    assert infer_iters is None

    # Regular user cannot set inference iterations (should raise ValueError)
    with pytest.raises(ValueError):
        await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)
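These tests are consistent with one access rule: a caller may read or modify a session only if, for every attribute category the session restricts on (roles, teams), the caller's own attributes overlap with the allowed values. The real check lives inside llama_stack's persistence/access-control code; the small helper below is a hypothetical sketch written only to make that rule explicit, mirroring the scenarios asserted above.

# Hypothetical helper, mirroring only the behaviour the tests assert (not the project's implementation).
def has_access(user_attrs: dict, session_attrs: dict) -> bool:
    # For every restricted category on the session, the user must share at least one value.
    for category, allowed in session_attrs.items():
        if allowed and not set(user_attrs.get(category, [])) & set(allowed):
            return False
    return True


# Matches test_session_access_control: the admin/security-team user is allowed,
# the plain user/other-team user is not.
assert has_access(
    {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]},
    {"roles": ["admin"], "teams": ["security-team"]},
)
assert not has_access(
    {"roles": ["user"], "teams": ["other-team"]},
    {"roles": ["admin"], "teams": ["security-team"]},
)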
@@ -187,8 +187,8 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
     loop.set_debug(True)
     caplog.set_level(logging.WARNING)

-    # Log when event loop is blocked for more than 100ms
-    loop.slow_callback_duration = 0.1
+    # Log when event loop is blocked for more than 200ms
+    loop.slow_callback_duration = 0.5
     # Sleep for 500ms in our delayed http response
     sleep_time = 0.5
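The hunk above relies on asyncio's debug mode: when the loop runs in debug mode, any callback or task step that blocks it for longer than loop.slow_callback_duration (in seconds) is logged as a warning, which is what the surrounding test captures through caplog. A self-contained, standard-library-only illustration of that mechanism (not part of this commit):

import asyncio
import time


async def main():
    loop = asyncio.get_running_loop()
    loop.slow_callback_duration = 0.1  # report anything blocking the loop for > 100 ms
    await asyncio.sleep(0)             # yield once so the next step is timed separately
    time.sleep(0.3)                    # deliberately block the event loop


# With debug mode on, asyncio logs a warning such as
# "Executing <Task ...> took 0.300 seconds" for the blocking step above.
asyncio.run(main(), debug=True)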
5  tests/unit/providers/nvidia/__init__.py  (new file)
@@ -0,0 +1,5 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
46  tests/unit/providers/nvidia/conftest.py  (new file)
@@ -0,0 +1,46 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()


@pytest.fixture(scope="session", autouse=True)
def patch_aiohttp_session():
    with patch("aiohttp.ClientSession", return_value=mock_session):
        yield


@pytest.fixture
def event_loop():
    """Create and provide a new event loop for each test."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield loop
    loop.close()


@pytest.fixture
def run_async():
    """Fixture to run async functions in tests."""

    def _run_async(coro):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    return _run_async
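The NVIDIA tests that follow are written as unittest.TestCase classes, so their test methods cannot be declared async; the run_async fixture above bridges that gap by creating a fresh event loop per call. A minimal sketch of how a test class consumes it, assuming it lives next to this conftest.py so the fixture is in scope (class and coroutine names here are placeholders):

import unittest

import pytest


class ExampleTest(unittest.TestCase):
    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        # pytest injects the run_async fixture; keep it for use inside test methods.
        self.run_async = run_async

    def test_something_async(self):
        async def coro():
            return 42

        # Drive the coroutine to completion on a throwaway event loop.
        assert self.run_async(coro()) == 42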
272  tests/unit/providers/nvidia/test_parameters.py  (new file)
@@ -0,0 +1,272 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import unittest
import warnings
from unittest.mock import patch

import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigEfficiencyConfig,
    TrainingConfigOptimizerConfig,
)

from llama_stack.providers.remote.post_training.nvidia.post_training import (
    NvidiaPostTrainingAdapter,
    NvidiaPostTrainingConfig,
)


class TestNvidiaParameters(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
        os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"

        config = NvidiaPostTrainingConfig(
            base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
        )
        self.adapter = NvidiaPostTrainingAdapter(config)

        self.make_request_patcher = patch(
            "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
        )
        self.mock_make_request = self.make_request_patcher.start()
        self.mock_make_request.return_value = {
            "id": "job-123",
            "status": "created",
            "created_at": "2025-03-04T13:07:47.543605",
            "updated_at": "2025-03-04T13:07:47.543605",
        }

    def tearDown(self):
        self.make_request_patcher.stop()

    def _assert_request_params(self, expected_json):
        """Helper method to verify parameters in the request JSON."""
        call_args = self.mock_make_request.call_args
        actual_json = call_args[1]["json"]

        for key, value in expected_json.items():
            if isinstance(value, dict):
                for nested_key, nested_value in value.items():
                    assert actual_json[key][nested_key] == nested_value
            else:
                assert actual_json[key] == value

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    def test_customizer_parameters_passed(self):
        """Test scenario 1: When an optional parameter is passed and value is correctly set."""
        custom_adapter_dim = 32  # Different from default of 8
        algorithm_config = LoraFinetuningConfig(
            type="LoRA",
            adapter_dim=custom_adapter_dim,
            adapter_dropout=0.2,
            apply_lora_to_mlp=True,
            apply_lora_to_output=True,
            alpha=16,
            rank=16,
            lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
        optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
        training_config = TrainingConfig(
            n_epochs=3,
            data_config=data_config,
            optimizer_config=optimizer_config,
        )

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            self.run_async(
                self.adapter.supervised_fine_tune(
                    job_uuid="test-job",
                    model="meta-llama/Llama-3.1-8B-Instruct",
                    checkpoint_dir="",
                    algorithm_config=algorithm_config,
                    training_config=training_config,
                    logger_config={},
                    hyperparam_search_config={},
                )
            )

            warning_texts = [str(warning.message) for warning in w]

            fields = [
                "apply_lora_to_output",
                "lora_attn_modules",
                "apply_lora_to_mlp",
            ]
            for field in fields:
                assert any(field in text for text in warning_texts)

        self._assert_request_params(
            {
                "hyperparameters": {
                    "lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
                    "epochs": 3,
                    "learning_rate": 0.0002,
                    "batch_size": 16,
                }
            }
        )

    def test_required_parameters_passed(self):
        """Test scenario 2: When required parameters are passed."""
        required_model = "meta-llama/Llama-3.1-8B-Instruct"
        required_dataset_id = "required-dataset"
        required_job_uuid = "required-job"

        algorithm_config = LoraFinetuningConfig(
            type="LoRA",
            adapter_dim=16,
            adapter_dropout=0.1,
            apply_lora_to_mlp=True,
            apply_lora_to_output=True,
            alpha=16,
            rank=16,
            lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        data_config = TrainingConfigDataConfig(
            dataset_id=required_dataset_id,  # Required parameter
            batch_size=8,
        )

        optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)

        training_config = TrainingConfig(
            n_epochs=1,
            data_config=data_config,
            optimizer_config=optimizer_config,
        )

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            self.run_async(
                self.adapter.supervised_fine_tune(
                    job_uuid=required_job_uuid,  # Required parameter
                    model=required_model,  # Required parameter
                    checkpoint_dir="",
                    algorithm_config=algorithm_config,
                    training_config=training_config,
                    logger_config={},
                    hyperparam_search_config={},
                )
            )

            warning_texts = [str(warning.message) for warning in w]

            fields = [
                "rank",
                "apply_lora_to_output",
                "lora_attn_modules",
                "apply_lora_to_mlp",
            ]
            for field in fields:
                assert any(field in text for text in warning_texts)

        self.mock_make_request.assert_called_once()
        call_args = self.mock_make_request.call_args

        assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
        assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id

    def test_unsupported_parameters_warning(self):
        """Test that warnings are raised for unsupported parameters."""
        data_config = TrainingConfigDataConfig(
            dataset_id="test-dataset",
            batch_size=8,
            # Unsupported parameters
            shuffle=True,
            data_format="instruct",
            validation_dataset_id="val-dataset",
        )

        optimizer_config = TrainingConfigOptimizerConfig(
            lr=0.0001,
            weight_decay=0.01,
            # Unsupported parameters
            optimizer_type="adam",
            num_warmup_steps=100,
        )

        efficiency_config = TrainingConfigEfficiencyConfig(
            enable_activation_checkpointing=True  # Unsupported parameter
        )

        training_config = TrainingConfig(
            n_epochs=1,
            data_config=data_config,
            optimizer_config=optimizer_config,
            # Unsupported parameters
            efficiency_config=efficiency_config,
            max_steps_per_epoch=1000,
            gradient_accumulation_steps=4,
            max_validation_steps=100,
            dtype="bf16",
        )

        # Capture warnings
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            self.run_async(
                self.adapter.supervised_fine_tune(
                    job_uuid="test-job",
                    model="meta-llama/Llama-3.1-8B-Instruct",
                    checkpoint_dir="test-dir",  # Unsupported parameter
                    algorithm_config=LoraFinetuningConfig(
                        type="LoRA",
                        adapter_dim=16,
                        adapter_dropout=0.1,
                        apply_lora_to_mlp=True,
                        apply_lora_to_output=True,
                        alpha=16,
                        rank=16,
                        lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
                    ),
                    training_config=training_config,
                    logger_config={"test": "value"},  # Unsupported parameter
                    hyperparam_search_config={"test": "value"},  # Unsupported parameter
                )
            )

            assert len(w) >= 4
            warning_texts = [str(warning.message) for warning in w]

            fields = [
                "checkpoint_dir",
                "hyperparam_search_config",
                "logger_config",
                "TrainingConfig",
                "DataConfig",
                "OptimizerConfig",
                "max_steps_per_epoch",
                "gradient_accumulation_steps",
                "max_validation_steps",
                "dtype",
                # required unsupported parameters
                "rank",
                "apply_lora_to_output",
                "lora_attn_modules",
                "apply_lora_to_mlp",
            ]
            for field in fields:
                assert any(field in text for text in warning_texts)


if __name__ == "__main__":
    unittest.main()
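The assertions in this file all lean on one pattern: record every warning emitted during the adapter call, then check that each unsupported field name appears in at least one warning message. Stripped to its essentials (standard library only; the warning text in this sketch is illustrative, not the adapter's actual wording):

import warnings


def emit(field: str) -> None:
    # Stand-in for the adapter warning about an ignored parameter.
    warnings.warn(f"Parameter '{field}' is not supported and will be ignored.", UserWarning, stacklevel=2)


with warnings.catch_warnings(record=True) as captured:
    warnings.simplefilter("always")  # keep duplicates instead of letting filters swallow them
    emit("apply_lora_to_mlp")
    emit("lora_attn_modules")

warning_texts = [str(w.message) for w in captured]
for field in ["apply_lora_to_mlp", "lora_attn_modules"]:
    assert any(field in text for text in warning_texts)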
295  tests/unit/providers/nvidia/test_supervised_fine_tuning.py  (new file)
@@ -0,0 +1,295 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import unittest
import warnings
from unittest.mock import patch

import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigOptimizerConfig,
)

from llama_stack.providers.remote.post_training.nvidia.post_training import (
    ListNvidiaPostTrainingJobs,
    NvidiaPostTrainingAdapter,
    NvidiaPostTrainingConfig,
    NvidiaPostTrainingJob,
    NvidiaPostTrainingJobStatusResponse,
)


class TestNvidiaPostTraining(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"  # needed for llm inference
        os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer

        config = NvidiaPostTrainingConfig(
            base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
        )
        self.adapter = NvidiaPostTrainingAdapter(config)
        self.make_request_patcher = patch(
            "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
        )
        self.mock_make_request = self.make_request_patcher.start()

    def tearDown(self):
        self.make_request_patcher.stop()

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
        """Helper method to verify request details in mock calls."""
        call_args = mock_call.call_args

        if expected_method and expected_path:
            if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
                assert call_args[0] == (expected_method, expected_path)
            else:
                assert call_args[1]["method"] == expected_method
                assert call_args[1]["path"] == expected_path

        if expected_params:
            assert call_args[1]["params"] == expected_params

        if expected_json:
            for key, value in expected_json.items():
                assert call_args[1]["json"][key] == value

    def test_supervised_fine_tune(self):
        """Test the supervised fine-tuning API call."""
        self.mock_make_request.return_value = {
            "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
            "created_at": "2024-12-09T04:06:28.542884",
            "updated_at": "2024-12-09T04:06:28.542884",
            "config": {
                "schema_version": "1.0",
                "id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
                "created_at": "2024-12-09T04:06:28.542657",
                "updated_at": "2024-12-09T04:06:28.569837",
                "custom_fields": {},
                "name": "meta-llama/Llama-3.1-8B-Instruct",
                "base_model": "meta-llama/Llama-3.1-8B-Instruct",
                "model_path": "llama-3_1-8b-instruct",
                "training_types": [],
                "finetuning_types": ["lora"],
                "precision": "bf16",
                "num_gpus": 4,
                "num_nodes": 1,
                "micro_batch_size": 1,
                "tensor_parallel_size": 1,
                "max_seq_length": 4096,
            },
            "dataset": {
                "schema_version": "1.0",
                "id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
                "created_at": "2024-12-09T04:06:28.542657",
                "updated_at": "2024-12-09T04:06:28.542660",
                "custom_fields": {},
                "name": "sample-basic-test",
                "version_id": "main",
                "version_tags": [],
            },
            "hyperparameters": {
                "finetuning_type": "lora",
                "training_type": "sft",
                "batch_size": 16,
                "epochs": 2,
                "learning_rate": 0.0001,
                "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
            },
            "output_model": "default/job-1234",
            "status": "created",
            "project": "default",
            "custom_fields": {},
            "ownership": {"created_by": "me", "access_policies": {}},
        }

        algorithm_config = LoraFinetuningConfig(
            type="LoRA",
            adapter_dim=16,
            adapter_dropout=0.1,
            apply_lora_to_mlp=True,
            apply_lora_to_output=True,
            alpha=16,
            rank=16,
            lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)

        optimizer_config = TrainingConfigOptimizerConfig(
            lr=0.0001,
        )

        training_config = TrainingConfig(
            n_epochs=2,
            data_config=data_config,
            optimizer_config=optimizer_config,
        )

        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            training_job = self.run_async(
                self.adapter.supervised_fine_tune(
                    job_uuid="1234",
                    model="meta-llama/Llama-3.1-8B-Instruct",
                    checkpoint_dir="",
                    algorithm_config=algorithm_config,
                    training_config=training_config,
                    logger_config={},
                    hyperparam_search_config={},
                )
            )

        # check the output is a PostTrainingJob
        assert isinstance(training_job, NvidiaPostTrainingJob)
        assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
            "POST",
            "/v1/customization/jobs",
            expected_json={
                "config": "meta/llama-3.1-8b-instruct",
                "dataset": {"name": "sample-basic-test", "namespace": "default"},
                "hyperparameters": {
                    "training_type": "sft",
                    "finetuning_type": "lora",
                    "epochs": 2,
                    "batch_size": 16,
                    "learning_rate": 0.0001,
                    "lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
                },
            },
        )

    def test_supervised_fine_tune_with_qat(self):
        algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
        data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
        optimizer_config = TrainingConfigOptimizerConfig(
            lr=0.0001,
        )
        training_config = TrainingConfig(
            n_epochs=2,
            data_config=data_config,
            optimizer_config=optimizer_config,
        )
        # This will raise NotImplementedError since QAT is not supported
        with self.assertRaises(NotImplementedError):
            self.run_async(
                self.adapter.supervised_fine_tune(
                    job_uuid="1234",
                    model="meta-llama/Llama-3.1-8B-Instruct",
                    checkpoint_dir="",
                    algorithm_config=algorithm_config,
                    training_config=training_config,
                    logger_config={},
                    hyperparam_search_config={},
                )
            )

    def test_get_training_job_status(self):
        self.mock_make_request.return_value = {
            "created_at": "2024-12-09T04:06:28.580220",
            "updated_at": "2024-12-09T04:21:19.852832",
            "status": "completed",
            "steps_completed": 1210,
            "epochs_completed": 2,
            "percentage_done": 100.0,
            "best_epoch": 2,
            "train_loss": 1.718016266822815,
            "val_loss": 1.8661999702453613,
        }

        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"

        status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))

        assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
        assert status.status.value == "completed"
        assert status.steps_completed == 1210
        assert status.epochs_completed == 2
        assert status.percentage_done == 100.0
        assert status.best_epoch == 2
        assert status.train_loss == 1.718016266822815
        assert status.val_loss == 1.8661999702453613

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
        )

    def test_get_training_jobs(self):
        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
        self.mock_make_request.return_value = {
            "data": [
                {
                    "id": job_id,
                    "created_at": "2024-12-09T04:06:28.542884",
                    "updated_at": "2024-12-09T04:21:19.852832",
                    "config": {
                        "name": "meta-llama/Llama-3.1-8B-Instruct",
                        "base_model": "meta-llama/Llama-3.1-8B-Instruct",
                    },
                    "dataset": {"name": "default/sample-basic-test"},
                    "hyperparameters": {
                        "finetuning_type": "lora",
                        "training_type": "sft",
                        "batch_size": 16,
                        "epochs": 2,
                        "learning_rate": 0.0001,
                        "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
                    },
                    "output_model": "default/job-1234",
                    "status": "completed",
                    "project": "default",
                }
            ]
        }

        jobs = self.run_async(self.adapter.get_training_jobs())

        assert isinstance(jobs, ListNvidiaPostTrainingJobs)
        assert len(jobs.data) == 1
        job = jobs.data[0]
        assert job.job_uuid == job_id
        assert job.status.value == "completed"

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
            "GET",
            "/v1/customization/jobs",
            expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
        )

    def test_cancel_training_job(self):
        self.mock_make_request.return_value = {}  # Empty response for successful cancellation
        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"

        result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))

        assert result is None

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
            "POST",
            f"/v1/customization/jobs/{job_id}/cancel",
            expected_params={"job_id": job_id},
        )


if __name__ == "__main__":
    unittest.main()
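Taken together, the request assertions in this file document the HTTP surface the adapter is expected to use against NeMo Customizer. A compact summary, derived only from the asserts above (not from external documentation):

# Endpoints pinned down by _assert_request calls in the tests above.
EXPECTED_CALLS = {
    "create job": ("POST", "/v1/customization/jobs"),
    "job status": ("GET", "/v1/customization/jobs/{job_id}/status"),
    "list jobs": ("GET", "/v1/customization/jobs"),
    "cancel job": ("POST", "/v1/customization/jobs/{job_id}/cancel"),
}

# Model identifier translation asserted in test_supervised_fine_tune.
EXPECTED_CONFIG_NAME = {"meta-llama/Llama-3.1-8B-Instruct": "meta/llama-3.1-8b-instruct"}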
43  tests/unit/providers/test_configs.py  (new file)
@@ -0,0 +1,43 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from pydantic import BaseModel

from llama_stack.distribution.distribution import get_provider_registry, providable_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type


class TestProviderConfigurations:
    """Test suite for testing provider configurations across all API types."""

    @pytest.mark.parametrize("api", providable_apis())
    def test_api_providers(self, api):
        provider_registry = get_provider_registry()
        providers = provider_registry.get(api, {})

        failures = []
        for provider_type, provider_spec in providers.items():
            try:
                self._verify_provider_config(provider_type, provider_spec)
            except Exception as e:
                failures.append(f"Failed to verify {provider_type} config: {str(e)}")

        if failures:
            pytest.fail("\n".join(failures))

    def _verify_provider_config(self, provider_type, provider_spec):
        """Helper method to verify a single provider configuration."""
        # Get the config class
        config_class_name = provider_spec.config_class
        config_type = instantiate_class_type(config_class_name)

        assert issubclass(config_type, BaseModel), f"{config_class_name} is not a subclass of BaseModel"

        assert hasattr(config_type, "sample_run_config"), f"{config_class_name} does not have sample_run_config method"

        sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
        assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"
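The contract this test enforces on every registered provider config is small: the config class must be a pydantic BaseModel and must expose sample_run_config(...) returning a plain dict. A minimal config class that would satisfy it; this is a hypothetical example for illustration, not one of the real providers:

from typing import Any

from pydantic import BaseModel


class ExampleProviderConfig(BaseModel):
    # Hypothetical field; real provider configs define their own settings.
    db_path: str = "example.db"

    @classmethod
    def sample_run_config(cls, __distro_dir__: str = "", **kwargs: Any) -> dict[str, Any]:
        # Must return a plain dict, as asserted by _verify_provider_config above.
        return {"db_path": f"{__distro_dir__}/example.db"}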
42  tests/unit/providers/vector_io/conftest.py  (new file)
@@ -0,0 +1,42 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import random

import numpy as np
import pytest

from llama_stack.apis.vector_io import Chunk

EMBEDDING_DIMENSION = 384


@pytest.fixture
def vector_db_id() -> str:
    return f"test-vector-db-{random.randint(1, 100)}"


@pytest.fixture(scope="session")
def embedding_dimension() -> int:
    return EMBEDDING_DIMENSION


@pytest.fixture(scope="session")
def sample_chunks():
    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
    n, k = 10, 3
    sample = [
        Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
        for j in range(k)
        for i in range(n)
    ]
    return sample


@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
    np.random.seed(42)
    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
135  tests/unit/providers/vector_io/test_qdrant.py  (new file)
@@ -0,0 +1,135 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio

from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_io import (
    QueryChunksResponse,
    VectorDB,
    VectorDBStore,
)
from llama_stack.providers.inline.vector_io.qdrant.config import (
    QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
    QdrantVectorIOAdapter,
)

# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_qdrant.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto


@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
    return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"))


@pytest.fixture(scope="session")
def loop():
    return asyncio.new_event_loop()


@pytest.fixture
def mock_vector_db(vector_db_id) -> MagicMock:
    mock_vector_db = MagicMock(spec=VectorDB)
    mock_vector_db.embedding_model = "embedding_model"
    mock_vector_db.identifier = vector_db_id
    return mock_vector_db


@pytest.fixture
def mock_vector_db_store(mock_vector_db) -> MagicMock:
    mock_store = MagicMock(spec=VectorDBStore)
    mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
    return mock_store


@pytest.fixture
def mock_api_service(sample_embeddings):
    mock_api_service = MagicMock(spec=Inference)
    mock_api_service.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings))
    return mock_api_service


@pytest_asyncio.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
    adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
    adapter.vector_db_store = mock_vector_db_store
    await adapter.initialize()
    yield adapter
    await adapter.shutdown()


__QUERY = "Sample query"


@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 30)])
async def test_qdrant_adapter_returns_expected_chunks(
    qdrant_adapter: QdrantVectorIOAdapter,
    vector_db_id,
    sample_chunks,
    sample_embeddings,
    max_query_chunks,
    expected_chunks,
) -> None:
    assert qdrant_adapter is not None
    await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)

    index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
    assert index is not None

    response = await qdrant_adapter.query_chunks(
        query=__QUERY,
        vector_db_id=vector_db_id,
        params={"max_chunks": max_query_chunks},
    )
    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == expected_chunks


# To by-pass attempt to convert a Mock to JSON
def _prepare_for_json(value: Any) -> str:
    return str(value)


@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db(
    qdrant_adapter: QdrantVectorIOAdapter,
    mock_vector_db,
    sample_chunks,
) -> None:
    # Initially, no collections
    vector_db_id = mock_vector_db.identifier
    assert len((await qdrant_adapter.client.get_collections()).collections) == 0

    # Register does not create a collection
    assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
    await qdrant_adapter.register_vector_db(mock_vector_db)
    assert not (await qdrant_adapter.client.collection_exists(vector_db_id))

    # First insert creates the collection
    await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
    assert await qdrant_adapter.client.collection_exists(vector_db_id)

    # Unregister deletes the collection
    await qdrant_adapter.unregister_vector_db(vector_db_id)
    assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
    assert len((await qdrant_adapter.client.get_collections()).collections) == 0
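The (100, 30) case in the parametrized test above is not arbitrary: the shared sample_chunks fixture from conftest.py builds n * k = 10 * 3 = 30 chunks, so a query with max_chunks=100 should return everything that was inserted, while max_chunks=2 truncates the result to 2. The same arithmetic, spelled out:

# 10 sentences per document across 3 documents, as generated by sample_chunks.
n_sentences, n_documents = 10, 3
total_chunks = n_sentences * n_documents
assert total_chunks == 30            # expected_chunks when max_chunks=100
assert min(2, total_chunks) == 2     # expected_chunks when max_chunks=2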
@@ -5,17 +5,16 @@
# the root directory of this source tree.

import asyncio
import sqlite3

import numpy as np
import pytest
import pytest_asyncio
import sqlite_vec

from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
    SQLiteVecIndex,
    SQLiteVecVectorIOAdapter,
    _create_sqlite_connection,
    generate_chunk_id,
)
@@ -29,8 +28,6 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
# -v -s --tb=short --disable-warnings --asyncio-mode=auto

SQLITE_VEC_PROVIDER = "sqlite_vec"
EMBEDDING_DIMENSION = 384
EMBEDDING_MODEL = "all-MiniLM-L6-v2"


@pytest.fixture(scope="session")
@@ -38,74 +35,53 @@ def loop():
    return asyncio.new_event_loop()


@pytest.fixture(scope="session", autouse=True)
def sqlite_connection(loop):
    conn = sqlite3.connect(":memory:")
    try:
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        yield conn
    finally:
        conn.close()


@pytest_asyncio.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection):
    return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")


@pytest.fixture(scope="session")
def sample_chunks():
    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
    n, k = 10, 3
    sample = [
        Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
        for j in range(k)
        for i in range(n)
    ]
    return sample


@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
    np.random.seed(42)
    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
    temp_dir = tmp_path_factory.getbasetemp()
    db_path = str(temp_dir / "test_sqlite.db")
    index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
    yield index
    await index.delete()


@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
    cur = sqlite_vec_index.connection.cursor()
    connection = _create_sqlite_connection(sqlite_vec_index.db_path)
    cur = connection.cursor()
    cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
    count = cur.fetchone()[0]
    assert count == len(sample_chunks)
    cur.close()
    connection.close()


@pytest.mark.asyncio
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
    query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
    response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2


@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
    """Test that chunk IDs do not conflict across batches when inserting chunks."""
    # Reduce batch size to force multiple batches for same document
    # since there are 10 chunks per document and batch size is 2
    batch_size = 2
    sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
    sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)

    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)

    cur = sqlite_vec_index.connection.cursor()
    connection = _create_sqlite_connection(sqlite_vec_index.db_path)
    cur = connection.cursor()

    # Retrieve all chunk IDs to check for duplicates
    cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
    chunk_ids = [row[0] for row in cur.fetchall()]
    cur.close()
    connection.close()

    # Ensure all chunk IDs are unique
    assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"