mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
cleanup
This commit is contained in:
parent
1dd9f92a15
commit
00c6bbffb7
4 changed files with 36 additions and 60 deletions
|
|
@ -354,17 +354,11 @@ class AuthenticationRequiredError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DefaultEmbeddingModel(BaseModel):
|
class QualifiedModel(BaseModel):
|
||||||
"""Configuration for default embedding model."""
|
"""A qualified model identifier, consisting of a provider ID and a model ID."""
|
||||||
|
|
||||||
provider_id: str = Field(
|
provider_id: str
|
||||||
...,
|
model_id: str
|
||||||
description="ID of the inference provider that serves the embedding model (e.g., 'sentence-transformers').",
|
|
||||||
)
|
|
||||||
model_id: str = Field(
|
|
||||||
...,
|
|
||||||
description="ID of the embedding model (e.g., 'nomic-ai/nomic-embed-text-v1.5').",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VectorStoresConfig(BaseModel):
|
class VectorStoresConfig(BaseModel):
|
||||||
|
|
@ -374,7 +368,7 @@ class VectorStoresConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
|
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
|
||||||
)
|
)
|
||||||
default_embedding_model: DefaultEmbeddingModel | None = Field(
|
default_embedding_model: QualifiedModel | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Default embedding model configuration for vector stores.",
|
description="Default embedding model configuration for vector stores.",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
|
||||||
from llama_stack.core.distribution import get_provider_registry
|
from llama_stack.core.distribution import get_provider_registry
|
||||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||||
|
|
@ -139,58 +139,40 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def validate_vector_stores_config(run_config: StackRunConfig, impls: dict[Api, Any]):
|
async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
|
||||||
"""Validate vector stores configuration."""
|
"""Validate vector stores configuration."""
|
||||||
if not run_config.vector_stores:
|
if vector_stores_config is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
vector_stores_config = run_config.vector_stores
|
default_embedding_model = vector_stores_config.default_embedding_model
|
||||||
|
if default_embedding_model is None:
|
||||||
|
return
|
||||||
|
|
||||||
# Validate default embedding model if configured
|
provider_id = default_embedding_model.provider_id
|
||||||
if vector_stores_config.default_embedding_model:
|
model_id = default_embedding_model.model_id
|
||||||
default_embedding_model = vector_stores_config.default_embedding_model
|
default_model_id = f"{provider_id}/{model_id}"
|
||||||
provider_id = default_embedding_model.provider_id
|
|
||||||
model_id = default_embedding_model.model_id
|
|
||||||
# Construct the full model identifier
|
|
||||||
default_model_id = f"{provider_id}/{model_id}"
|
|
||||||
|
|
||||||
if Api.models not in impls:
|
if Api.models not in impls:
|
||||||
raise ValueError(
|
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
|
||||||
f"Models API is not available but vector_stores config requires model '{default_model_id}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
models_impl = impls[Api.models]
|
models_impl = impls[Api.models]
|
||||||
response = await models_impl.list_models()
|
response = await models_impl.list_models()
|
||||||
models_list = response.data if hasattr(response, "data") else response
|
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
|
||||||
|
|
||||||
# find default embedding model
|
default_model = models_list.get(default_model_id)
|
||||||
default_model = None
|
if default_model is None:
|
||||||
for model in models_list:
|
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
|
||||||
if model.identifier == default_model_id:
|
|
||||||
default_model = model
|
|
||||||
break
|
|
||||||
|
|
||||||
if not default_model:
|
embedding_dimension = default_model.metadata.get("embedding_dimension")
|
||||||
available_models = [m.identifier for m in models_list if m.model_type == "embedding"]
|
if embedding_dimension is None:
|
||||||
raise ValueError(
|
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
|
||||||
f"Embedding model '{default_model_id}' not found. Available embedding models: {available_models}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if default_model.model_type != "embedding":
|
try:
|
||||||
raise ValueError(f"Model '{default_model_id}' is type '{default_model.model_type}', not 'embedding'")
|
int(embedding_dimension)
|
||||||
|
except ValueError as err:
|
||||||
|
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
|
||||||
|
|
||||||
embedding_dimension = default_model.metadata.get("embedding_dimension")
|
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
|
||||||
if embedding_dimension is None:
|
|
||||||
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
|
|
||||||
|
|
||||||
try:
|
|
||||||
int(embedding_dimension)
|
|
||||||
except ValueError as err:
|
|
||||||
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
|
|
||||||
|
|
||||||
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
|
|
||||||
|
|
||||||
# If no default embedding model is configured, that's fine - validation passes
|
|
||||||
|
|
||||||
|
|
||||||
class EnvVarError(Exception):
|
class EnvVarError(Exception):
|
||||||
|
|
@ -429,7 +411,7 @@ class Stack:
|
||||||
|
|
||||||
await register_resources(self.run_config, impls)
|
await register_resources(self.run_config, impls)
|
||||||
await refresh_registry_once(impls)
|
await refresh_registry_once(impls)
|
||||||
await validate_vector_stores_config(self.run_config, impls)
|
await validate_vector_stores_config(self.run_config.vector_stores, impls)
|
||||||
self.impls = impls
|
self.impls = impls
|
||||||
|
|
||||||
def create_registry_refresh_task(self):
|
def create_registry_refresh_task(self):
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,9 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.core.datatypes import (
|
from llama_stack.core.datatypes import (
|
||||||
BuildProvider,
|
BuildProvider,
|
||||||
DefaultEmbeddingModel,
|
|
||||||
Provider,
|
Provider,
|
||||||
ProviderSpec,
|
ProviderSpec,
|
||||||
|
QualifiedModel,
|
||||||
ShieldInput,
|
ShieldInput,
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
VectorStoresConfig,
|
VectorStoresConfig,
|
||||||
|
|
@ -251,7 +251,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
||||||
default_shields=default_shields,
|
default_shields=default_shields,
|
||||||
vector_stores_config=VectorStoresConfig(
|
vector_stores_config=VectorStoresConfig(
|
||||||
default_provider_id="faiss",
|
default_provider_id="faiss",
|
||||||
default_embedding_model=DefaultEmbeddingModel(
|
default_embedding_model=QualifiedModel(
|
||||||
provider_id="sentence-transformers",
|
provider_id="sentence-transformers",
|
||||||
model_id="nomic-ai/nomic-embed-text-v1.5",
|
model_id="nomic-ai/nomic-embed-text-v1.5",
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.core.datatypes import DefaultEmbeddingModel, StackRunConfig, StorageConfig, VectorStoresConfig
|
from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
|
||||||
from llama_stack.core.stack import validate_vector_stores_config
|
from llama_stack.core.stack import validate_vector_stores_config
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
@ -25,7 +25,7 @@ class TestVectorStoresValidation:
|
||||||
storage=StorageConfig(backends={}, stores={}),
|
storage=StorageConfig(backends={}, stores={}),
|
||||||
vector_stores=VectorStoresConfig(
|
vector_stores=VectorStoresConfig(
|
||||||
default_provider_id="faiss",
|
default_provider_id="faiss",
|
||||||
default_embedding_model=DefaultEmbeddingModel(
|
default_embedding_model=QualifiedModel(
|
||||||
provider_id="p",
|
provider_id="p",
|
||||||
model_id="missing",
|
model_id="missing",
|
||||||
),
|
),
|
||||||
|
|
@ -45,7 +45,7 @@ class TestVectorStoresValidation:
|
||||||
storage=StorageConfig(backends={}, stores={}),
|
storage=StorageConfig(backends={}, stores={}),
|
||||||
vector_stores=VectorStoresConfig(
|
vector_stores=VectorStoresConfig(
|
||||||
default_provider_id="faiss",
|
default_provider_id="faiss",
|
||||||
default_embedding_model=DefaultEmbeddingModel(
|
default_embedding_model=QualifiedModel(
|
||||||
provider_id="p",
|
provider_id="p",
|
||||||
model_id="valid",
|
model_id="valid",
|
||||||
),
|
),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue