incorporating feedback

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-10-14 21:10:44 -04:00
parent 86c1e3b217
commit 5a4b291b3e
3 changed files with 130 additions and 12 deletions

View file

@ -98,6 +98,30 @@ REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None
async def validate_default_embedding_model(impls: dict[Api, Any]):
"""Validate that at most one embedding model is marked as default."""
if Api.models not in impls:
return
models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = response.data if hasattr(response, "data") else response
default_embedding_models = []
for model in models_list:
if model.model_type == "embedding" and model.metadata.get("default_configured") is True:
default_embedding_models.append(model.identifier)
if len(default_embedding_models) > 1:
raise ValueError(
f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
"Only one embedding model can be marked as default."
)
if default_embedding_models:
logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
@ -128,6 +152,8 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
await validate_default_embedding_model(impls)
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):

View file

@ -501,25 +501,24 @@ class OpenAIVectorStoreMixin(ABC):
"""
embedding_models = await self._get_embedding_models()
default_model_info = []
default_models = []
for model in embedding_models:
if model.metadata.get("default_configured") is True:
embedding_dimension = model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{model.identifier}' has no embedding_dimension in metadata")
default_model_info.append((model.identifier, int(embedding_dimension)))
default_models.append(model.identifier)
if len(default_model_info) > 1:
model_ids = [info[0] for info in default_model_info]
if len(default_models) > 1:
raise ValueError(
f"Multiple embedding models marked as default_configured=True: {model_ids}. "
f"Multiple embedding models marked as default_configured=True: {default_models}. "
"Only one embedding model can be marked as default."
)
if default_model_info:
model_id, dimension = default_model_info[0]
logger.info(f"Using default embedding model: {model_id} with dimension {dimension}")
return model_id, dimension
if default_models:
model_id = default_models[0]
embedding_dimension = await self._get_embedding_dimension_for_model(model_id)
if embedding_dimension is None:
raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata")
logger.info(f"Using default embedding model: {model_id} with dimension {embedding_dimension}")
return model_id, embedding_dimension
logger.info("DEBUG: No default embedding models found")
return None

View file

@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for Stack validation functions.
"""
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.models import Model, ModelType
from llama_stack.core.stack import validate_default_embedding_model
from llama_stack.providers.datatypes import Api
class TestStackValidation:
"""Test Stack validation functions."""
@pytest.mark.parametrize(
"models,should_raise",
[
([], False), # No models
(
[
Model(
identifier="emb1",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb1",
)
],
False,
), # Single default
(
[
Model(
identifier="emb1",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb1",
),
Model(
identifier="emb2",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb2",
),
],
True,
), # Multiple defaults
(
[
Model(
identifier="emb1",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb1",
),
Model(
identifier="llm1",
model_type=ModelType.llm,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="llm1",
),
],
False,
), # Ignores non-embedding
],
)
async def test_validate_default_embedding_model(self, models, should_raise):
"""Test validation with various model configurations."""
mock_models_impl = AsyncMock()
mock_models_impl.list_models.return_value = models
impls = {Api.models: mock_models_impl}
if should_raise:
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
await validate_default_embedding_model(impls)
else:
await validate_default_embedding_model(impls)
async def test_validate_default_embedding_model_no_models_api(self):
"""Test validation when models API is not available."""
await validate_default_embedding_model({})