mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-23 00:27:26 +00:00
chore: Updating how default embedding model is set in stack (#3818)
# What does this PR do? Refactor setting default vector store provider and embedding model to use an optional `vector_stores` config in the `StackRunConfig` and clean up code to do so (had to add back in some pieces of VectorDB). Also added remote Qdrant and Weaviate to starter distro (based on other PR where inference providers were added for UX). New config is simply (default for Starter distro): ```yaml vector_stores: default_provider_id: faiss default_embedding_model: provider_id: sentence-transformers model_id: nomic-ai/nomic-embed-text-v1.5 ``` ## Test Plan CI and Unit tests. --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
2c43285e22
commit
48581bf651
48 changed files with 973 additions and 818 deletions
|
@ -4,90 +4,64 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Unit tests for Stack validation functions.
|
||||
"""
|
||||
"""Unit tests for Stack validation functions."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.core.stack import validate_default_embedding_model
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
|
||||
from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
|
||||
from llama_stack.core.stack import validate_vector_stores_config
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
class TestStackValidation:
|
||||
"""Test Stack validation functions."""
|
||||
class TestVectorStoresValidation:
|
||||
async def test_validate_missing_model(self):
|
||||
"""Test validation fails when model not found."""
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
storage=StorageConfig(backends={}, stores={}),
|
||||
vector_stores=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="p",
|
||||
model_id="missing",
|
||||
),
|
||||
),
|
||||
)
|
||||
mock_models = AsyncMock()
|
||||
mock_models.list_models.return_value = ListModelsResponse(data=[])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"models,should_raise",
|
||||
[
|
||||
([], False), # No models
|
||||
(
|
||||
[
|
||||
Model(
|
||||
identifier="emb1",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="emb1",
|
||||
)
|
||||
],
|
||||
False,
|
||||
), # Single default
|
||||
(
|
||||
[
|
||||
Model(
|
||||
identifier="emb1",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="emb1",
|
||||
),
|
||||
Model(
|
||||
identifier="emb2",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="emb2",
|
||||
),
|
||||
],
|
||||
True,
|
||||
), # Multiple defaults
|
||||
(
|
||||
[
|
||||
Model(
|
||||
identifier="emb1",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="emb1",
|
||||
),
|
||||
Model(
|
||||
identifier="llm1",
|
||||
model_type=ModelType.llm,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="llm1",
|
||||
),
|
||||
],
|
||||
False,
|
||||
), # Ignores non-embedding
|
||||
],
|
||||
)
|
||||
async def test_validate_default_embedding_model(self, models, should_raise):
|
||||
"""Test validation with various model configurations."""
|
||||
mock_models_impl = AsyncMock()
|
||||
mock_models_impl.list_models.return_value = models
|
||||
impls = {Api.models: mock_models_impl}
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
|
||||
await validate_default_embedding_model(impls)
|
||||
else:
|
||||
await validate_default_embedding_model(impls)
|
||||
async def test_validate_success(self):
|
||||
"""Test validation passes with valid model."""
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
storage=StorageConfig(backends={}, stores={}),
|
||||
vector_stores=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="p",
|
||||
model_id="valid",
|
||||
),
|
||||
),
|
||||
)
|
||||
mock_models = AsyncMock()
|
||||
mock_models.list_models.return_value = ListModelsResponse(
|
||||
data=[
|
||||
Model(
|
||||
identifier="p/valid", # Must match provider_id/model_id format
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768},
|
||||
provider_id="p",
|
||||
provider_resource_id="valid",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def test_validate_default_embedding_model_no_models_api(self):
|
||||
"""Test validation when models API is not available."""
|
||||
await validate_default_embedding_model({})
|
||||
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
|
||||
|
|
|
@ -146,7 +146,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
|
|||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
models_api=None,
|
||||
)
|
||||
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
|
||||
await adapter.initialize()
|
||||
|
@ -185,7 +184,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
|
|||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
models_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
|
|
|
@ -11,7 +11,6 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
@ -76,12 +75,6 @@ def mock_files_api():
|
|||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models_api():
|
||||
mock_api = MagicMock(spec=Models)
|
||||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_config():
|
||||
config = MagicMock(spec=FaissVectorIOConfig)
|
||||
|
@ -117,7 +110,7 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
|
|||
assert response.chunks[1] == sample_chunks[1]
|
||||
|
||||
|
||||
async def test_health_success(mock_models_api):
|
||||
async def test_health_success():
|
||||
"""Test that the health check returns OK status when faiss is working correctly."""
|
||||
# Create a fresh instance of FaissVectorIOAdapter for testing
|
||||
config = MagicMock()
|
||||
|
@ -126,9 +119,7 @@ async def test_health_success(mock_models_api):
|
|||
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
|
||||
mock_index_flat.return_value = MagicMock()
|
||||
adapter = FaissVectorIOAdapter(
|
||||
config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
|
||||
)
|
||||
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
|
||||
|
||||
# Calling the health method directly
|
||||
response = await adapter.health()
|
||||
|
@ -142,7 +133,7 @@ async def test_health_success(mock_models_api):
|
|||
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
|
||||
|
||||
|
||||
async def test_health_failure(mock_models_api):
|
||||
async def test_health_failure():
|
||||
"""Test that the health check returns ERROR status when faiss encounters an error."""
|
||||
# Create a fresh instance of FaissVectorIOAdapter for testing
|
||||
config = MagicMock()
|
||||
|
@ -152,9 +143,7 @@ async def test_health_failure(mock_models_api):
|
|||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
|
||||
mock_index_flat.side_effect = Exception("Test error")
|
||||
|
||||
adapter = FaissVectorIOAdapter(
|
||||
config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
|
||||
)
|
||||
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
|
||||
|
||||
# Calling the health method directly
|
||||
response = await adapter.health()
|
||||
|
|
|
@ -6,13 +6,12 @@
|
|||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
@ -996,96 +995,6 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
|
|||
assert batch.file_counts.in_progress == 8
|
||||
|
||||
|
||||
async def test_get_default_embedding_model_success(vector_io_adapter):
|
||||
"""Test successful default embedding model detection."""
|
||||
# Mock models API with a default model
|
||||
mock_models_api = Mock()
|
||||
mock_models_api.list_models = AsyncMock(
|
||||
return_value=Mock(
|
||||
data=[
|
||||
Model(
|
||||
identifier="nomic-embed-text-v1.5",
|
||||
model_type=ModelType.embedding,
|
||||
provider_id="test-provider",
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"default_configured": True,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
vector_io_adapter.models_api = mock_models_api
|
||||
result = await vector_io_adapter._get_default_embedding_model_and_dimension()
|
||||
|
||||
assert result is not None
|
||||
model_id, dimension = result
|
||||
assert model_id == "nomic-embed-text-v1.5"
|
||||
assert dimension == 768
|
||||
|
||||
|
||||
async def test_get_default_embedding_model_multiple_defaults_error(vector_io_adapter):
|
||||
"""Test error when multiple models are marked as default."""
|
||||
mock_models_api = Mock()
|
||||
mock_models_api.list_models = AsyncMock(
|
||||
return_value=Mock(
|
||||
data=[
|
||||
Model(
|
||||
identifier="model1",
|
||||
model_type=ModelType.embedding,
|
||||
provider_id="test-provider",
|
||||
metadata={"embedding_dimension": 768, "default_configured": True},
|
||||
),
|
||||
Model(
|
||||
identifier="model2",
|
||||
model_type=ModelType.embedding,
|
||||
provider_id="test-provider",
|
||||
metadata={"embedding_dimension": 512, "default_configured": True},
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
vector_io_adapter.models_api = mock_models_api
|
||||
|
||||
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
|
||||
await vector_io_adapter._get_default_embedding_model_and_dimension()
|
||||
|
||||
|
||||
async def test_openai_create_vector_store_uses_default_model(vector_io_adapter):
|
||||
"""Test that vector store creation uses default embedding model when none specified."""
|
||||
# Mock models API and dependencies
|
||||
mock_models_api = Mock()
|
||||
mock_models_api.list_models = AsyncMock(
|
||||
return_value=Mock(
|
||||
data=[
|
||||
Model(
|
||||
identifier="default-model",
|
||||
model_type=ModelType.embedding,
|
||||
provider_id="test-provider",
|
||||
metadata={"embedding_dimension": 512, "default_configured": True},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
vector_io_adapter.models_api = mock_models_api
|
||||
vector_io_adapter.register_vector_db = AsyncMock()
|
||||
vector_io_adapter.__provider_id__ = "test-provider"
|
||||
|
||||
# Create vector store without specifying embedding model
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test-store")
|
||||
result = await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Verify the vector store was created with default model
|
||||
assert result.name == "test-store"
|
||||
vector_io_adapter.register_vector_db.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
||||
assert call_args.embedding_model == "default-model"
|
||||
assert call_args.embedding_dimension == 512
|
||||
|
||||
|
||||
async def test_embedding_config_from_metadata(vector_io_adapter):
|
||||
"""Test that embedding configuration is correctly extracted from metadata."""
|
||||
|
||||
|
@ -1253,5 +1162,5 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
|
|||
# Test with no embedding model provided
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})
|
||||
|
||||
with pytest.raises(ValueError, match="embedding_model is required in extra_body when creating a vector store"):
|
||||
with pytest.raises(ValueError, match="embedding_model is required"):
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue