add support for embedding models and keep provider models separate

Ashwin Bharambe 2025-07-23 16:13:47 -07:00
parent cf629f81fe
commit 8fb4feeba1
6 changed files with 264 additions and 18 deletions

View file

@@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
 RoutingKey = str | list[str]
 
 
+class RegistryEntrySource(StrEnum):
+    default = "default"
+    provider = "provider"
+
+
 class User(BaseModel):
     principal: str
     # further attributes that may be used for access control decisions
@@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
     resource. This can be used to constrain access to the resource."""
 
     owner: User | None = None
+    source: RegistryEntrySource = RegistryEntrySource.default
 
 
 # Use the extended Resource for all routable objects
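
The new source field is what lets the stack tell user-registered entries apart from provider-discovered ones. A minimal sketch of the idea, using a stand-in Entry class rather than the real Resource hierarchy:

from enum import StrEnum

from pydantic import BaseModel


class RegistryEntrySource(StrEnum):
    default = "default"      # registered explicitly, e.g. from run.yaml or the register API
    provider = "provider"    # discovered by polling the provider's model listing


class Entry(BaseModel):
    identifier: str
    source: RegistryEntrySource = RegistryEntrySource.default


# User-created entries keep the default source; entries created during a
# provider refresh are tagged explicitly.
user_entry = Entry(identifier="my-model")
listed_entry = Entry(identifier="llama3.2:3b", source=RegistryEntrySource.provider)
assert user_entry.source == RegistryEntrySource.default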

View file

@@ -206,7 +206,6 @@ class CommonRoutingTableImpl(RoutingTable):
 
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
             return registered_obj
-        else:
-            await self.dist_registry.register(obj)
-            return obj
+        await self.dist_registry.register(obj)
+        return obj

View file

@@ -11,6 +11,7 @@ from typing import Any
 
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
 from llama_stack.distribution.datatypes import (
     ModelWithOwner,
+    RegistryEntrySource,
 )
 from llama_stack.log import get_logger
@@ -65,7 +66,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             if models is None:
                 continue
 
-            await self.update_registered_llm_models(provider_id, models)
+            await self.update_registered_models(provider_id, models)
 
             await asyncio.sleep(self.model_refresh_interval_seconds)
@@ -131,6 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             provider_id=provider_id,
             metadata=metadata,
             model_type=model_type,
+            source=RegistryEntrySource.default,
         )
         registered_model = await self.register_object(model)
         return registered_model
@@ -141,7 +143,7 @@
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)
 
-    async def update_registered_llm_models(
+    async def update_registered_models(
         self,
         provider_id: str,
         models: list[Model],
@@ -152,18 +154,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         # from run.yaml) that we need to keep track of
         model_ids = {}
         for model in existing_models:
-            # we leave embeddings models alone because often we don't get metadata
-            # (embedding dimension, etc.) from the provider
-            if model.provider_id == provider_id and model.model_type == ModelType.llm:
+            if model.provider_id != provider_id:
+                continue
+
+            if model.source == RegistryEntrySource.default:
                 model_ids[model.provider_resource_id] = model.identifier
-                logger.debug(f"unregistering model {model.identifier}")
-                await self.unregister_object(model)
+                continue
+            logger.debug(f"unregistering model {model.identifier}")
+            await self.unregister_object(model)
 
         for model in models:
-            if model.model_type != ModelType.llm:
-                continue
             if model.provider_resource_id in model_ids:
-                model.identifier = model_ids[model.provider_resource_id]
+                # avoid overwriting a non-provider-registered model entry
+                continue
 
             logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
             await self.register_object(
@@ -173,5 +176,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                     provider_id=provider_id,
                     metadata=model.metadata,
                     model_type=model.model_type,
+                    source=RegistryEntrySource.provider,
                 )
             )
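
Net effect of the routing-table changes: each provider refresh is a small reconciliation pass. User-registered entries (source == default) survive the refresh and are never overwritten; previously auto-discovered entries are dropped and rebuilt from the fresh listing. A standalone sketch of that logic, reusing the Entry/RegistryEntrySource stand-ins from above with provider fields added (simplified, not the actual routing-table code):

class ProviderEntry(Entry):
    provider_id: str
    provider_resource_id: str


def reconcile(
    existing: list[ProviderEntry], listed: list[ProviderEntry], provider_id: str
) -> list[ProviderEntry]:
    keep: list[ProviderEntry] = []
    claimed: set[str] = set()
    for e in existing:
        if e.provider_id != provider_id:
            keep.append(e)  # another provider's entry; not our concern
            continue
        if e.source == RegistryEntrySource.default:
            claimed.add(e.provider_resource_id)  # user entry wins; never overwrite
            keep.append(e)
            continue
        # provider-sourced entries are dropped here and rebuilt below
    for m in listed:
        if m.provider_resource_id in claimed:
            continue  # a user-registered entry already covers this provider model
        keep.append(m.model_copy(update={"source": RegistryEntrySource.provider}))
    return keep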

View file

@@ -20,7 +20,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate, ModelType
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
@@ -41,6 +41,8 @@
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
+    __provider_id__: str
+
     def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
         self.config = config
@@ -54,8 +56,17 @@
         return False
 
     async def list_models(self) -> list[Model] | None:
-        # TODO: add all-mini-lm models
-        return None
+        return [
+            Model(
+                identifier="all-MiniLM-L6-v2",
+                provider_resource_id="all-MiniLM-L6-v2",
+                provider_id=self.__provider_id__,
+                metadata={
+                    "embedding_dimension": 384,
+                },
+                model_type=ModelType.embedding,
+            ),
+        ]
 
     async def register_model(self, model: Model) -> Model:
         return model
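
Worth noting: the bare __provider_id__: str is only a type annotation. The attribute is expected to be assigned from outside, presumably by the stack stamping the configured provider id onto the implementation after instantiating it, which is what lets the hardcoded all-MiniLM-L6-v2 entry carry the right provider_id. A toy illustration of that pattern (the Impl name and the wiring step are made up):

class Impl:
    __provider_id__: str  # declared but not assigned here; the framework sets it

    def list_models(self) -> list[str]:
        return [f"{self.__provider_id__}/all-MiniLM-L6-v2"]


impl = Impl()
impl.__provider_id__ = "sentence-transformers"  # wiring step, done by the stack
assert impl.list_models() == ["sentence-transformers/all-MiniLM-L6-v2"]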

View file

@@ -126,10 +126,44 @@
     async def list_models(self) -> list[Model] | None:
         provider_id = self.__provider_id__
         response = await self.client.list()
-        models = []
+
+        # always add the two embedding models which can be pulled on demand
+        models = [
+            Model(
+                identifier="all-minilm:l6-v2",
+                provider_resource_id="all-minilm:l6-v2",
+                provider_id=provider_id,
+                metadata={
+                    "embedding_dimension": 384,
+                    "context_length": 512,
+                },
+                model_type=ModelType.embedding,
+            ),
+            # add all-minilm alias
+            Model(
+                identifier="all-minilm",
+                provider_resource_id="all-minilm:l6-v2",
+                provider_id=provider_id,
+                metadata={
+                    "embedding_dimension": 384,
+                    "context_length": 512,
+                },
+                model_type=ModelType.embedding,
+            ),
+            Model(
+                identifier="nomic-embed-text",
+                provider_resource_id="nomic-embed-text",
+                provider_id=provider_id,
+                metadata={
+                    "embedding_dimension": 768,
+                    "context_length": 8192,
+                },
+                model_type=ModelType.embedding,
+            ),
+        ]
         for m in response.models:
-            model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
-            if model_type == ModelType.embedding:
+            # kill embedding models since we don't know dimensions for them
+            if m.details.family in ["bert"]:
                 continue
             models.append(
                 Model(
@@ -137,7 +171,7 @@
                     provider_resource_id=m.model,
                     provider_id=provider_id,
                     metadata={},
-                    model_type=model_type,
+                    model_type=ModelType.llm,
                 )
             )
         return models
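
One detail worth calling out in the Ollama listing: the all-minilm alias differs from the canonical all-minilm:l6-v2 entry only in its identifier; both share a provider_resource_id, so either name resolves to the same underlying Ollama model. A minimal sketch of that resolution, with a hypothetical resolve helper:

# identifier -> provider_resource_id, mirroring the three hardcoded entries
entries = {
    "all-minilm": "all-minilm:l6-v2",
    "all-minilm:l6-v2": "all-minilm:l6-v2",
    "nomic-embed-text": "nomic-embed-text",
}


def resolve(identifier: str) -> str:
    """Hypothetical helper: map a registered identifier to the model Ollama serves."""
    return entries[identifier]


assert resolve("all-minilm") == resolve("all-minilm:l6-v2")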