This commit is contained in:
Ashwin Bharambe 2025-10-20 14:05:48 -07:00
parent 1dd9f92a15
commit 00c6bbffb7
4 changed files with 36 additions and 60 deletions

View file

@ -354,17 +354,11 @@ class AuthenticationRequiredError(Exception):
pass pass
class QualifiedModel(BaseModel):
    """A qualified model identifier, consisting of a provider ID and a model ID."""

    # Both fields are required; the descriptions are kept so that the generated
    # configuration schema still documents what each part of the identifier means.
    provider_id: str = Field(
        ...,
        description="ID of the provider that serves the model (e.g., 'sentence-transformers').",
    )
    model_id: str = Field(
        ...,
        description="ID of the model (e.g., 'nomic-ai/nomic-embed-text-v1.5').",
    )
class VectorStoresConfig(BaseModel): class VectorStoresConfig(BaseModel):
@ -374,7 +368,7 @@ class VectorStoresConfig(BaseModel):
default=None, default=None,
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.", description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
) )
default_embedding_model: DefaultEmbeddingModel | None = Field( default_embedding_model: QualifiedModel | None = Field(
default=None, default=None,
description="Default embedding model configuration for vector stores.", description="Default embedding model configuration for vector stores.",
) )

View file

@ -35,7 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@ -139,58 +139,40 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
) )
async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
    """Validate the default embedding model referenced by the vector stores configuration.

    Validation is a no-op when there is no vector stores configuration or when
    it does not name a default embedding model.

    Args:
        vector_stores_config: The vector stores configuration to validate, or None.
        impls: Mapping from API to its implementation. The ``Api.models`` entry is
            used to list the registered models.

    Raises:
        ValueError: If the Models API is unavailable, the configured model is not
            registered, the model is not an embedding model, or its
            ``embedding_dimension`` metadata is missing or not convertible to an integer.
    """
    if vector_stores_config is None:
        return

    default_embedding_model = vector_stores_config.default_embedding_model
    if default_embedding_model is None:
        # No default embedding model configured: nothing to validate.
        return

    provider_id = default_embedding_model.provider_id
    model_id = default_embedding_model.model_id
    default_model_id = f"{provider_id}/{model_id}"

    if Api.models not in impls:
        raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")

    models_impl = impls[Api.models]
    response = await models_impl.list_models()
    # Accept both a response object exposing `.data` and a bare list of models.
    all_models = getattr(response, "data", response)
    # Index every model (not only embeddings) so that a wrongly-typed model is
    # reported as such instead of as "not found".
    models_by_id = {m.identifier: m for m in all_models}

    default_model = models_by_id.get(default_model_id)
    if default_model is None:
        # Report identifiers only; dumping full model objects makes the message unreadable.
        available_models = [ident for ident, m in models_by_id.items() if m.model_type == "embedding"]
        raise ValueError(
            f"Embedding model '{default_model_id}' not found. Available embedding models: {available_models}"
        )

    if default_model.model_type != "embedding":
        raise ValueError(f"Model '{default_model_id}' is type '{default_model.model_type}', not 'embedding'")

    embedding_dimension = default_model.metadata.get("embedding_dimension")
    if embedding_dimension is None:
        raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")

    try:
        int(embedding_dimension)
    except (TypeError, ValueError) as err:
        # int() raises TypeError for values such as lists or dicts and ValueError
        # for non-numeric strings; both mean the metadata is unusable.
        raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err

    logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
class EnvVarError(Exception): class EnvVarError(Exception):
@ -429,7 +411,7 @@ class Stack:
await register_resources(self.run_config, impls) await register_resources(self.run_config, impls)
await refresh_registry_once(impls) await refresh_registry_once(impls)
await validate_vector_stores_config(self.run_config, impls) await validate_vector_stores_config(self.run_config.vector_stores, impls)
self.impls = impls self.impls = impls
def create_registry_refresh_task(self): def create_registry_refresh_task(self):

View file

@ -9,9 +9,9 @@ from typing import Any
from llama_stack.core.datatypes import ( from llama_stack.core.datatypes import (
BuildProvider, BuildProvider,
DefaultEmbeddingModel,
Provider, Provider,
ProviderSpec, ProviderSpec,
QualifiedModel,
ShieldInput, ShieldInput,
ToolGroupInput, ToolGroupInput,
VectorStoresConfig, VectorStoresConfig,
@ -251,7 +251,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
default_shields=default_shields, default_shields=default_shields,
vector_stores_config=VectorStoresConfig( vector_stores_config=VectorStoresConfig(
default_provider_id="faiss", default_provider_id="faiss",
default_embedding_model=DefaultEmbeddingModel( default_embedding_model=QualifiedModel(
provider_id="sentence-transformers", provider_id="sentence-transformers",
model_id="nomic-ai/nomic-embed-text-v1.5", model_id="nomic-ai/nomic-embed-text-v1.5",
), ),

View file

@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
import pytest import pytest
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.core.datatypes import DefaultEmbeddingModel, StackRunConfig, StorageConfig, VectorStoresConfig from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
from llama_stack.core.stack import validate_vector_stores_config from llama_stack.core.stack import validate_vector_stores_config
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -25,7 +25,7 @@ class TestVectorStoresValidation:
storage=StorageConfig(backends={}, stores={}), storage=StorageConfig(backends={}, stores={}),
vector_stores=VectorStoresConfig( vector_stores=VectorStoresConfig(
default_provider_id="faiss", default_provider_id="faiss",
default_embedding_model=DefaultEmbeddingModel( default_embedding_model=QualifiedModel(
provider_id="p", provider_id="p",
model_id="missing", model_id="missing",
), ),
@ -45,7 +45,7 @@ class TestVectorStoresValidation:
storage=StorageConfig(backends={}, stores={}), storage=StorageConfig(backends={}, stores={}),
vector_stores=VectorStoresConfig( vector_stores=VectorStoresConfig(
default_provider_id="faiss", default_provider_id="faiss",
default_embedding_model=DefaultEmbeddingModel( default_embedding_model=QualifiedModel(
provider_id="p", provider_id="p",
model_id="valid", model_id="valid",
), ),