add support for embedding models and keep provider models separate

This commit is contained in:
Ashwin Bharambe 2025-07-23 16:13:47 -07:00
parent cf629f81fe
commit 8fb4feeba1
6 changed files with 264 additions and 18 deletions

View file

@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
RoutingKey = str | list[str] RoutingKey = str | list[str]
class RegistryEntrySource(StrEnum):
default = "default"
provider = "provider"
class User(BaseModel): class User(BaseModel):
principal: str principal: str
# further attributes that may be used for access control decisions # further attributes that may be used for access control decisions
@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
resource. This can be used to constrain access to the resource.""" resource. This can be used to constrain access to the resource."""
owner: User | None = None owner: User | None = None
source: RegistryEntrySource = RegistryEntrySource.default
# Use the extended Resource for all routable objects # Use the extended Resource for all routable objects

View file

@ -206,7 +206,6 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.type == ResourceType.model.value: if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj) await self.dist_registry.register(registered_obj)
return registered_obj return registered_obj
else: else:
await self.dist_registry.register(obj) await self.dist_registry.register(obj)
return obj return obj

View file

@ -11,6 +11,7 @@ from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
ModelWithOwner, ModelWithOwner,
RegistryEntrySource,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -65,7 +66,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if models is None: if models is None:
continue continue
await self.update_registered_llm_models(provider_id, models) await self.update_registered_models(provider_id, models)
await asyncio.sleep(self.model_refresh_interval_seconds) await asyncio.sleep(self.model_refresh_interval_seconds)
@ -131,6 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id, provider_id=provider_id,
metadata=metadata, metadata=metadata,
model_type=model_type, model_type=model_type,
source=RegistryEntrySource.default,
) )
registered_model = await self.register_object(model) registered_model = await self.register_object(model)
return registered_model return registered_model
@ -141,7 +143,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Model {model_id} not found") raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model) await self.unregister_object(existing_model)
async def update_registered_llm_models( async def update_registered_models(
self, self,
provider_id: str, provider_id: str,
models: list[Model], models: list[Model],
@ -152,18 +154,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# from run.yaml) that we need to keep track of # from run.yaml) that we need to keep track of
model_ids = {} model_ids = {}
for model in existing_models: for model in existing_models:
# we leave embeddings models alone because often we don't get metadata if model.provider_id != provider_id:
# (embedding dimension, etc.) from the provider continue
if model.provider_id == provider_id and model.model_type == ModelType.llm: if model.source == RegistryEntrySource.default:
model_ids[model.provider_resource_id] = model.identifier model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}") continue
await self.unregister_object(model)
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models: for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids: if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id] # avoid overwriting a non-provider-registered model entry
continue
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})") logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object( await self.register_object(
@ -173,5 +176,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id, provider_id=provider_id,
metadata=model.metadata, metadata=model.metadata,
model_type=model.model_type, model_type=model.model_type,
source=RegistryEntrySource.provider,
) )
) )

View file

@ -20,7 +20,7 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate, ModelType
from llama_stack.providers.utils.inference.embedding_mixin import ( from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
) )
@ -41,6 +41,8 @@ class SentenceTransformersInferenceImpl(
InferenceProvider, InferenceProvider,
ModelsProtocolPrivate, ModelsProtocolPrivate,
): ):
__provider_id__: str
def __init__(self, config: SentenceTransformersInferenceConfig) -> None: def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config self.config = config
@ -54,8 +56,17 @@ class SentenceTransformersInferenceImpl(
return False return False
async def list_models(self) -> list[Model] | None: async def list_models(self) -> list[Model] | None:
# TODO: add all-mini-lm models return [
return None Model(
identifier="all-MiniLM-L6-v2",
provider_resource_id="all-MiniLM-L6-v2",
provider_id=self.__provider_id__,
metadata={
"embedding_dimension": 384,
},
model_type=ModelType.embedding,
),
]
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
return model return model

View file

@ -126,10 +126,44 @@ class OllamaInferenceAdapter(
async def list_models(self) -> list[Model] | None: async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__ provider_id = self.__provider_id__
response = await self.client.list() response = await self.client.list()
models = []
# always add the two embedding models which can be pulled on demand
models = [
Model(
identifier="all-minilm:l6-v2",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
# add all-minilm alias
Model(
identifier="all-minilm",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
Model(
identifier="nomic-embed-text",
provider_resource_id="nomic-embed-text",
provider_id=provider_id,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
model_type=ModelType.embedding,
),
]
for m in response.models: for m in response.models:
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm # kill embedding models since we don't know dimensions for them
if model_type == ModelType.embedding: if m.details.family in ["bert"]:
continue continue
models.append( models.append(
Model( Model(
@ -137,7 +171,7 @@ class OllamaInferenceAdapter(
provider_resource_id=m.model, provider_resource_id=m.model,
provider_id=provider_id, provider_id=provider_id,
metadata={}, metadata={},
model_type=model_type, model_type=ModelType.llm,
) )
) )
return models return models

View file

@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.datatypes import RegistryEntrySource
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@ -45,6 +46,30 @@ class InferenceImpl(Impl):
async def unregister_model(self, model_id: str): async def unregister_model(self, model_id: str):
return model_id return model_id
async def should_refresh_models(self):
return False
async def list_models(self):
return [
Model(
identifier="provider-model-1",
provider_resource_id="provider-model-1",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
Model(
identifier="provider-model-2",
provider_resource_id="provider-model-2",
provider_id="test_provider",
metadata={"embedding_dimension": 512},
model_type=ModelType.embedding,
),
]
async def shutdown(self):
pass
class SafetyImpl(Impl): class SafetyImpl(Impl):
def __init__(self): def __init__(self):
@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
raise AssertionError("Should have raised ValueError for non-existent model") raise AssertionError("Should have raised ValueError for non-existent model")
except ValueError as e: except ValueError as e:
assert "not found" in str(e) assert "not found" in str(e)
async def test_models_source_tracking_default(cached_disk_dist_registry):
"""Test that models registered via register_model get default source."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register model via register_model (should get default source)
await table.register_model(model_id="user-model", provider_id="test_provider")
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.source == RegistryEntrySource.default
assert model.identifier == "test_provider/user-model"
# Cleanup
await table.shutdown()
async def test_models_source_tracking_provider(cached_disk_dist_registry):
"""Test that models registered via update_registered_models get provider source."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Simulate provider refresh by calling update_registered_models
provider_models = [
Model(
identifier="provider-model-1",
provider_resource_id="provider-model-1",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
Model(
identifier="provider-model-2",
provider_resource_id="provider-model-2",
provider_id="test_provider",
metadata={"embedding_dimension": 512},
model_type=ModelType.embedding,
),
]
await table.update_registered_models("test_provider", provider_models)
models = await table.list_models()
assert len(models.data) == 2
# All models should have provider source
for model in models.data:
assert model.source == RegistryEntrySource.provider
assert model.provider_id == "test_provider"
# Cleanup
await table.shutdown()
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
"""Test that provider refresh preserves user-registered models with default source."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# First register a user model with same provider_resource_id as provider will later provide
await table.register_model(
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
)
# Verify user model is registered with default source
models = await table.list_models()
assert len(models.data) == 1
user_model = models.data[0]
assert user_model.source == RegistryEntrySource.default
assert user_model.identifier == "my-custom-alias"
assert user_model.provider_resource_id == "provider-model-1"
# Now simulate provider refresh
provider_models = [
Model(
identifier="provider-model-1",
provider_resource_id="provider-model-1",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
Model(
identifier="different-model",
provider_resource_id="different-model",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
]
await table.update_registered_models("test_provider", provider_models)
# Verify user model with alias is preserved, but provider added new model
models = await table.list_models()
assert len(models.data) == 2
# Find the user model and provider model
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
assert user_model is not None
assert user_model.source == RegistryEntrySource.default
assert user_model.provider_resource_id == "provider-model-1"
assert provider_model is not None
assert provider_model.source == RegistryEntrySource.provider
assert provider_model.provider_resource_id == "different-model"
# Cleanup
await table.shutdown()
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
"""Test that provider refresh removes old provider models but keeps default ones."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register a user model
await table.register_model(model_id="user-model", provider_id="test_provider")
# Add some provider models
provider_models_v1 = [
Model(
identifier="provider-model-old",
provider_resource_id="provider-model-old",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
]
await table.update_registered_models("test_provider", provider_models_v1)
# Verify we have both user and provider models
models = await table.list_models()
assert len(models.data) == 2
# Now update with new provider models (should remove old provider models)
provider_models_v2 = [
Model(
identifier="provider-model-new",
provider_resource_id="provider-model-new",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
]
await table.update_registered_models("test_provider", provider_models_v2)
# Should have user model + new provider model, old provider model gone
models = await table.list_models()
assert len(models.data) == 2
identifiers = {m.identifier for m in models.data}
assert "test_provider/user-model" in identifiers # User model preserved
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier)
assert "provider-model-old" not in identifiers # Old provider model removed
# Verify sources are correct
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
assert user_model.source == RegistryEntrySource.default
assert provider_model.source == RegistryEntrySource.provider
# Cleanup
await table.shutdown()