mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
incorporating feedback
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
86c1e3b217
commit
5a4b291b3e
3 changed files with 130 additions and 12 deletions
|
|
@ -98,6 +98,30 @@ REGISTRY_REFRESH_TASK = None
|
||||||
TEST_RECORDING_CONTEXT = None
|
TEST_RECORDING_CONTEXT = None
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_default_embedding_model(impls: dict[Api, Any]):
|
||||||
|
"""Validate that at most one embedding model is marked as default."""
|
||||||
|
if Api.models not in impls:
|
||||||
|
return
|
||||||
|
|
||||||
|
models_impl = impls[Api.models]
|
||||||
|
response = await models_impl.list_models()
|
||||||
|
models_list = response.data if hasattr(response, "data") else response
|
||||||
|
|
||||||
|
default_embedding_models = []
|
||||||
|
for model in models_list:
|
||||||
|
if model.model_type == "embedding" and model.metadata.get("default_configured") is True:
|
||||||
|
default_embedding_models.append(model.identifier)
|
||||||
|
|
||||||
|
if len(default_embedding_models) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
|
||||||
|
"Only one embedding model can be marked as default."
|
||||||
|
)
|
||||||
|
|
||||||
|
if default_embedding_models:
|
||||||
|
logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
|
||||||
|
|
||||||
|
|
||||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||||
for rsrc, api, register_method, list_method in RESOURCES:
|
for rsrc, api, register_method, list_method in RESOURCES:
|
||||||
objects = getattr(run_config, rsrc)
|
objects = getattr(run_config, rsrc)
|
||||||
|
|
@ -128,6 +152,8 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await validate_default_embedding_model(impls)
|
||||||
|
|
||||||
|
|
||||||
class EnvVarError(Exception):
|
class EnvVarError(Exception):
|
||||||
def __init__(self, var_name: str, path: str = ""):
|
def __init__(self, var_name: str, path: str = ""):
|
||||||
|
|
|
||||||
|
|
@ -501,25 +501,24 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
"""
|
"""
|
||||||
embedding_models = await self._get_embedding_models()
|
embedding_models = await self._get_embedding_models()
|
||||||
|
|
||||||
default_model_info = []
|
default_models = []
|
||||||
for model in embedding_models:
|
for model in embedding_models:
|
||||||
if model.metadata.get("default_configured") is True:
|
if model.metadata.get("default_configured") is True:
|
||||||
embedding_dimension = model.metadata.get("embedding_dimension")
|
default_models.append(model.identifier)
|
||||||
if embedding_dimension is None:
|
|
||||||
raise ValueError(f"Embedding model '{model.identifier}' has no embedding_dimension in metadata")
|
|
||||||
default_model_info.append((model.identifier, int(embedding_dimension)))
|
|
||||||
|
|
||||||
if len(default_model_info) > 1:
|
if len(default_models) > 1:
|
||||||
model_ids = [info[0] for info in default_model_info]
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Multiple embedding models marked as default_configured=True: {model_ids}. "
|
f"Multiple embedding models marked as default_configured=True: {default_models}. "
|
||||||
"Only one embedding model can be marked as default."
|
"Only one embedding model can be marked as default."
|
||||||
)
|
)
|
||||||
|
|
||||||
if default_model_info:
|
if default_models:
|
||||||
model_id, dimension = default_model_info[0]
|
model_id = default_models[0]
|
||||||
logger.info(f"Using default embedding model: {model_id} with dimension {dimension}")
|
embedding_dimension = await self._get_embedding_dimension_for_model(model_id)
|
||||||
return model_id, dimension
|
if embedding_dimension is None:
|
||||||
|
raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata")
|
||||||
|
logger.info(f"Using default embedding model: {model_id} with dimension {embedding_dimension}")
|
||||||
|
return model_id, embedding_dimension
|
||||||
|
|
||||||
logger.info("DEBUG: No default embedding models found")
|
logger.info("DEBUG: No default embedding models found")
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
93
tests/unit/core/test_stack_validation.py
Normal file
93
tests/unit/core/test_stack_validation.py
Normal file
|
|
@ -0,0 +1,93 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for Stack validation functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.models import Model, ModelType
|
||||||
|
from llama_stack.core.stack import validate_default_embedding_model
|
||||||
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
class TestStackValidation:
|
||||||
|
"""Test Stack validation functions."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"models,should_raise",
|
||||||
|
[
|
||||||
|
([], False), # No models
|
||||||
|
(
|
||||||
|
[
|
||||||
|
Model(
|
||||||
|
identifier="emb1",
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
metadata={"default_configured": True},
|
||||||
|
provider_id="p",
|
||||||
|
provider_resource_id="emb1",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
False,
|
||||||
|
), # Single default
|
||||||
|
(
|
||||||
|
[
|
||||||
|
Model(
|
||||||
|
identifier="emb1",
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
metadata={"default_configured": True},
|
||||||
|
provider_id="p",
|
||||||
|
provider_resource_id="emb1",
|
||||||
|
),
|
||||||
|
Model(
|
||||||
|
identifier="emb2",
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
metadata={"default_configured": True},
|
||||||
|
provider_id="p",
|
||||||
|
provider_resource_id="emb2",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
True,
|
||||||
|
), # Multiple defaults
|
||||||
|
(
|
||||||
|
[
|
||||||
|
Model(
|
||||||
|
identifier="emb1",
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
metadata={"default_configured": True},
|
||||||
|
provider_id="p",
|
||||||
|
provider_resource_id="emb1",
|
||||||
|
),
|
||||||
|
Model(
|
||||||
|
identifier="llm1",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
metadata={"default_configured": True},
|
||||||
|
provider_id="p",
|
||||||
|
provider_resource_id="llm1",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
False,
|
||||||
|
), # Ignores non-embedding
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_validate_default_embedding_model(self, models, should_raise):
|
||||||
|
"""Test validation with various model configurations."""
|
||||||
|
mock_models_impl = AsyncMock()
|
||||||
|
mock_models_impl.list_models.return_value = models
|
||||||
|
impls = {Api.models: mock_models_impl}
|
||||||
|
|
||||||
|
if should_raise:
|
||||||
|
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
|
||||||
|
await validate_default_embedding_model(impls)
|
||||||
|
else:
|
||||||
|
await validate_default_embedding_model(impls)
|
||||||
|
|
||||||
|
async def test_validate_default_embedding_model_no_models_api(self):
|
||||||
|
"""Test validation when models API is not available."""
|
||||||
|
await validate_default_embedding_model({})
|
||||||
Loading…
Add table
Add a link
Reference in a new issue