incorporating feedback

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-10-14 21:10:44 -04:00
parent 86c1e3b217
commit 5a4b291b3e
3 changed files with 130 additions and 12 deletions

View file

@ -98,6 +98,30 @@ REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None
async def validate_default_embedding_model(impls: dict[Api, Any]):
    """Ensure no more than one embedding model is flagged as the default.

    Lists the models exposed by the ``Api.models`` implementation and collects
    the identifiers of those whose ``model_type`` equals ``"embedding"`` and
    whose metadata has ``default_configured`` set to exactly ``True``.

    Args:
        impls: Mapping from API to its implementation. When ``Api.models`` is
            not a key, the function returns without doing anything.

    Raises:
        ValueError: If more than one embedding model is flagged as default.
    """
    if Api.models not in impls:
        return

    listing = await impls[Api.models].list_models()
    # The listing may be a wrapper exposing ``.data`` or the iterable itself.
    candidates = getattr(listing, "data", listing)

    default_embedding_models = [
        entry.identifier
        for entry in candidates
        if entry.model_type == "embedding" and entry.metadata.get("default_configured") is True
    ]

    if len(default_embedding_models) > 1:
        raise ValueError(
            f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
            "Only one embedding model can be marked as default."
        )

    if default_embedding_models:
        logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
@ -128,6 +152,8 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
await validate_default_embedding_model(impls)
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):

View file

@ -501,25 +501,24 @@ class OpenAIVectorStoreMixin(ABC):
"""
embedding_models = await self._get_embedding_models()
default_model_info = []
default_models = []
for model in embedding_models:
if model.metadata.get("default_configured") is True:
embedding_dimension = model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{model.identifier}' has no embedding_dimension in metadata")
default_model_info.append((model.identifier, int(embedding_dimension)))
default_models.append(model.identifier)
if len(default_model_info) > 1:
model_ids = [info[0] for info in default_model_info]
if len(default_models) > 1:
raise ValueError(
f"Multiple embedding models marked as default_configured=True: {model_ids}. "
f"Multiple embedding models marked as default_configured=True: {default_models}. "
"Only one embedding model can be marked as default."
)
if default_model_info:
model_id, dimension = default_model_info[0]
logger.info(f"Using default embedding model: {model_id} with dimension {dimension}")
return model_id, dimension
if default_models:
model_id = default_models[0]
embedding_dimension = await self._get_embedding_dimension_for_model(model_id)
if embedding_dimension is None:
raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata")
logger.info(f"Using default embedding model: {model_id} with dimension {embedding_dimension}")
return model_id, embedding_dimension
logger.info("DEBUG: No default embedding models found")
return None