From 5a4b291b3eb1b488127a8a1d0432d43a16574fcb Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Tue, 14 Oct 2025 21:10:44 -0400 Subject: [PATCH] incorporating feedback Signed-off-by: Francisco Javier Arceo --- llama_stack/core/stack.py | 26 ++++++ .../utils/memory/openai_vector_store_mixin.py | 23 +++-- tests/unit/core/test_stack_validation.py | 93 +++++++++++++++++++ 3 files changed, 130 insertions(+), 12 deletions(-) create mode 100644 tests/unit/core/test_stack_validation.py diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index f161ac358..733b55262 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -98,6 +98,30 @@ REGISTRY_REFRESH_TASK = None TEST_RECORDING_CONTEXT = None +async def validate_default_embedding_model(impls: dict[Api, Any]): + """Validate that at most one embedding model is marked as default.""" + if Api.models not in impls: + return + + models_impl = impls[Api.models] + response = await models_impl.list_models() + models_list = response.data if hasattr(response, "data") else response + + default_embedding_models = [] + for model in models_list: + if model.model_type == "embedding" and model.metadata.get("default_configured") is True: + default_embedding_models.append(model.identifier) + + if len(default_embedding_models) > 1: + raise ValueError( + f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. " + "Only one embedding model can be marked as default." + ) + + if default_embedding_models: + logger.info(f"Default embedding model configured: {default_embedding_models[0]}") + + async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): for rsrc, api, register_method, list_method in RESOURCES: objects = getattr(run_config, rsrc) @@ -128,6 +152,8 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}", ) + await validate_default_embedding_model(impls) + class EnvVarError(Exception): def __init__(self, var_name: str, path: str = ""): diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index eb84f753e..37fd94db0 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -501,25 +501,24 @@ class OpenAIVectorStoreMixin(ABC): """ embedding_models = await self._get_embedding_models() - default_model_info = [] + default_models = [] for model in embedding_models: if model.metadata.get("default_configured") is True: - embedding_dimension = model.metadata.get("embedding_dimension") - if embedding_dimension is None: - raise ValueError(f"Embedding model '{model.identifier}' has no embedding_dimension in metadata") - default_model_info.append((model.identifier, int(embedding_dimension))) + default_models.append(model.identifier) - if len(default_model_info) > 1: - model_ids = [info[0] for info in default_model_info] + if len(default_models) > 1: raise ValueError( - f"Multiple embedding models marked as default_configured=True: {model_ids}. " + f"Multiple embedding models marked as default_configured=True: {default_models}. " "Only one embedding model can be marked as default." ) - if default_model_info: - model_id, dimension = default_model_info[0] - logger.info(f"Using default embedding model: {model_id} with dimension {dimension}") - return model_id, dimension + if default_models: + model_id = default_models[0] + embedding_dimension = await self._get_embedding_dimension_for_model(model_id) + if embedding_dimension is None: + raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata") + logger.info(f"Using default embedding model: {model_id} with dimension {embedding_dimension}") + return model_id, embedding_dimension logger.info("DEBUG: No default embedding models found") return None diff --git a/tests/unit/core/test_stack_validation.py b/tests/unit/core/test_stack_validation.py new file mode 100644 index 000000000..5fc27e199 --- /dev/null +++ b/tests/unit/core/test_stack_validation.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Unit tests for Stack validation functions. +""" + +from unittest.mock import AsyncMock + +import pytest + +from llama_stack.apis.models import Model, ModelType +from llama_stack.core.stack import validate_default_embedding_model +from llama_stack.providers.datatypes import Api + + +class TestStackValidation: + """Test Stack validation functions.""" + + @pytest.mark.parametrize( + "models,should_raise", + [ + ([], False), # No models + ( + [ + Model( + identifier="emb1", + model_type=ModelType.embedding, + metadata={"default_configured": True}, + provider_id="p", + provider_resource_id="emb1", + ) + ], + False, + ), # Single default + ( + [ + Model( + identifier="emb1", + model_type=ModelType.embedding, + metadata={"default_configured": True}, + provider_id="p", + provider_resource_id="emb1", + ), + Model( + identifier="emb2", + model_type=ModelType.embedding, + metadata={"default_configured": True}, + provider_id="p", + provider_resource_id="emb2", + ), + ], + True, + ), # Multiple defaults + ( + [ + Model( + identifier="emb1", + model_type=ModelType.embedding, + metadata={"default_configured": True}, + provider_id="p", + provider_resource_id="emb1", + ), + Model( + identifier="llm1", + model_type=ModelType.llm, + metadata={"default_configured": True}, + provider_id="p", + provider_resource_id="llm1", + ), + ], + False, + ), # Ignores non-embedding + ], + ) + async def test_validate_default_embedding_model(self, models, should_raise): + """Test validation with various model configurations.""" + mock_models_impl = AsyncMock() + mock_models_impl.list_models.return_value = models + impls = {Api.models: mock_models_impl} + + if should_raise: + with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"): + await validate_default_embedding_model(impls) + else: + await validate_default_embedding_model(impls) + + async def test_validate_default_embedding_model_no_models_api(self): + """Test validation when models API is not available.""" + await validate_default_embedding_model({})