mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
add support for embedding models and keep provider models separate
This commit is contained in:
parent
cf629f81fe
commit
8fb4feeba1
6 changed files with 264 additions and 18 deletions
|
@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
|
||||||
RoutingKey = str | list[str]
|
RoutingKey = str | list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class RegistryEntrySource(StrEnum):
|
||||||
|
default = "default"
|
||||||
|
provider = "provider"
|
||||||
|
|
||||||
|
|
||||||
class User(BaseModel):
|
class User(BaseModel):
|
||||||
principal: str
|
principal: str
|
||||||
# further attributes that may be used for access control decisions
|
# further attributes that may be used for access control decisions
|
||||||
|
@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
|
||||||
resource. This can be used to constrain access to the resource."""
|
resource. This can be used to constrain access to the resource."""
|
||||||
|
|
||||||
owner: User | None = None
|
owner: User | None = None
|
||||||
|
source: RegistryEntrySource = RegistryEntrySource.default
|
||||||
|
|
||||||
|
|
||||||
# Use the extended Resource for all routable objects
|
# Use the extended Resource for all routable objects
|
||||||
|
|
|
@ -206,7 +206,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
if obj.type == ResourceType.model.value:
|
if obj.type == ResourceType.model.value:
|
||||||
await self.dist_registry.register(registered_obj)
|
await self.dist_registry.register(registered_obj)
|
||||||
return registered_obj
|
return registered_obj
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await self.dist_registry.register(obj)
|
await self.dist_registry.register(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
|
@ -11,6 +11,7 @@ from typing import Any
|
||||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
ModelWithOwner,
|
ModelWithOwner,
|
||||||
|
RegistryEntrySource,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -65,7 +66,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
if models is None:
|
if models is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self.update_registered_llm_models(provider_id, models)
|
await self.update_registered_models(provider_id, models)
|
||||||
|
|
||||||
await asyncio.sleep(self.model_refresh_interval_seconds)
|
await asyncio.sleep(self.model_refresh_interval_seconds)
|
||||||
|
|
||||||
|
@ -131,6 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
|
source=RegistryEntrySource.default,
|
||||||
)
|
)
|
||||||
registered_model = await self.register_object(model)
|
registered_model = await self.register_object(model)
|
||||||
return registered_model
|
return registered_model
|
||||||
|
@ -141,7 +143,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
raise ValueError(f"Model {model_id} not found")
|
raise ValueError(f"Model {model_id} not found")
|
||||||
await self.unregister_object(existing_model)
|
await self.unregister_object(existing_model)
|
||||||
|
|
||||||
async def update_registered_llm_models(
|
async def update_registered_models(
|
||||||
self,
|
self,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
models: list[Model],
|
models: list[Model],
|
||||||
|
@ -152,18 +154,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
# from run.yaml) that we need to keep track of
|
# from run.yaml) that we need to keep track of
|
||||||
model_ids = {}
|
model_ids = {}
|
||||||
for model in existing_models:
|
for model in existing_models:
|
||||||
# we leave embeddings models alone because often we don't get metadata
|
if model.provider_id != provider_id:
|
||||||
# (embedding dimension, etc.) from the provider
|
continue
|
||||||
if model.provider_id == provider_id and model.model_type == ModelType.llm:
|
if model.source == RegistryEntrySource.default:
|
||||||
model_ids[model.provider_resource_id] = model.identifier
|
model_ids[model.provider_resource_id] = model.identifier
|
||||||
logger.debug(f"unregistering model {model.identifier}")
|
continue
|
||||||
await self.unregister_object(model)
|
|
||||||
|
logger.debug(f"unregistering model {model.identifier}")
|
||||||
|
await self.unregister_object(model)
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
if model.model_type != ModelType.llm:
|
|
||||||
continue
|
|
||||||
if model.provider_resource_id in model_ids:
|
if model.provider_resource_id in model_ids:
|
||||||
model.identifier = model_ids[model.provider_resource_id]
|
# avoid overwriting a non-provider-registered model entry
|
||||||
|
continue
|
||||||
|
|
||||||
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||||
await self.register_object(
|
await self.register_object(
|
||||||
|
@ -173,5 +176,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata=model.metadata,
|
metadata=model.metadata,
|
||||||
model_type=model.model_type,
|
model_type=model.model_type,
|
||||||
|
source=RegistryEntrySource.provider,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,7 +20,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate, ModelType
|
||||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||||
SentenceTransformerEmbeddingMixin,
|
SentenceTransformerEmbeddingMixin,
|
||||||
)
|
)
|
||||||
|
@ -41,6 +41,8 @@ class SentenceTransformersInferenceImpl(
|
||||||
InferenceProvider,
|
InferenceProvider,
|
||||||
ModelsProtocolPrivate,
|
ModelsProtocolPrivate,
|
||||||
):
|
):
|
||||||
|
__provider_id__: str
|
||||||
|
|
||||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
@ -54,8 +56,17 @@ class SentenceTransformersInferenceImpl(
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def list_models(self) -> list[Model] | None:
|
async def list_models(self) -> list[Model] | None:
|
||||||
# TODO: add all-mini-lm models
|
return [
|
||||||
return None
|
Model(
|
||||||
|
identifier="all-MiniLM-L6-v2",
|
||||||
|
provider_resource_id="all-MiniLM-L6-v2",
|
||||||
|
provider_id=self.__provider_id__,
|
||||||
|
metadata={
|
||||||
|
"embedding_dimension": 384,
|
||||||
|
},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -126,10 +126,44 @@ class OllamaInferenceAdapter(
|
||||||
async def list_models(self) -> list[Model] | None:
|
async def list_models(self) -> list[Model] | None:
|
||||||
provider_id = self.__provider_id__
|
provider_id = self.__provider_id__
|
||||||
response = await self.client.list()
|
response = await self.client.list()
|
||||||
models = []
|
|
||||||
|
# always add the two embedding models which can be pulled on demand
|
||||||
|
models = [
|
||||||
|
Model(
|
||||||
|
identifier="all-minilm:l6-v2",
|
||||||
|
provider_resource_id="all-minilm:l6-v2",
|
||||||
|
provider_id=provider_id,
|
||||||
|
metadata={
|
||||||
|
"embedding_dimension": 384,
|
||||||
|
"context_length": 512,
|
||||||
|
},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
),
|
||||||
|
# add all-minilm alias
|
||||||
|
Model(
|
||||||
|
identifier="all-minilm",
|
||||||
|
provider_resource_id="all-minilm:l6-v2",
|
||||||
|
provider_id=provider_id,
|
||||||
|
metadata={
|
||||||
|
"embedding_dimension": 384,
|
||||||
|
"context_length": 512,
|
||||||
|
},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
),
|
||||||
|
Model(
|
||||||
|
identifier="nomic-embed-text",
|
||||||
|
provider_resource_id="nomic-embed-text",
|
||||||
|
provider_id=provider_id,
|
||||||
|
metadata={
|
||||||
|
"embedding_dimension": 768,
|
||||||
|
"context_length": 8192,
|
||||||
|
},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
),
|
||||||
|
]
|
||||||
for m in response.models:
|
for m in response.models:
|
||||||
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
|
# kill embedding models since we don't know dimensions for them
|
||||||
if model_type == ModelType.embedding:
|
if m.details.family in ["bert"]:
|
||||||
continue
|
continue
|
||||||
models.append(
|
models.append(
|
||||||
Model(
|
Model(
|
||||||
|
@ -137,7 +171,7 @@ class OllamaInferenceAdapter(
|
||||||
provider_resource_id=m.model,
|
provider_resource_id=m.model,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata={},
|
metadata={},
|
||||||
model_type=model_type,
|
model_type=ModelType.llm,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return models
|
return models
|
||||||
|
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.apis.shields.shields import Shield
|
from llama_stack.apis.shields.shields import Shield
|
||||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
|
from llama_stack.distribution.datatypes import RegistryEntrySource
|
||||||
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||||
|
@ -45,6 +46,30 @@ class InferenceImpl(Impl):
|
||||||
async def unregister_model(self, model_id: str):
|
async def unregister_model(self, model_id: str):
|
||||||
return model_id
|
return model_id
|
||||||
|
|
||||||
|
async def should_refresh_models(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def list_models(self):
|
||||||
|
return [
|
||||||
|
Model(
|
||||||
|
identifier="provider-model-1",
|
||||||
|
provider_resource_id="provider-model-1",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={},
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
),
|
||||||
|
Model(
|
||||||
|
identifier="provider-model-2",
|
||||||
|
provider_resource_id="provider-model-2",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={"embedding_dimension": 512},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SafetyImpl(Impl):
|
class SafetyImpl(Impl):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
||||||
raise AssertionError("Should have raised ValueError for non-existent model")
|
raise AssertionError("Should have raised ValueError for non-existent model")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
assert "not found" in str(e)
|
assert "not found" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_source_tracking_default(cached_disk_dist_registry):
|
||||||
|
"""Test that models registered via register_model get default source."""
|
||||||
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Register model via register_model (should get default source)
|
||||||
|
await table.register_model(model_id="user-model", provider_id="test_provider")
|
||||||
|
|
||||||
|
models = await table.list_models()
|
||||||
|
assert len(models.data) == 1
|
||||||
|
model = models.data[0]
|
||||||
|
assert model.source == RegistryEntrySource.default
|
||||||
|
assert model.identifier == "test_provider/user-model"
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_source_tracking_provider(cached_disk_dist_registry):
|
||||||
|
"""Test that models registered via update_registered_models get provider source."""
|
||||||
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Simulate provider refresh by calling update_registered_models
|
||||||
|
provider_models = [
|
||||||
|
Model(
|
||||||
|
identifier="provider-model-1",
|
||||||
|
provider_resource_id="provider-model-1",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={},
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
),
|
||||||
|
Model(
|
||||||
|
identifier="provider-model-2",
|
||||||
|
provider_resource_id="provider-model-2",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={"embedding_dimension": 512},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
await table.update_registered_models("test_provider", provider_models)
|
||||||
|
|
||||||
|
models = await table.list_models()
|
||||||
|
assert len(models.data) == 2
|
||||||
|
|
||||||
|
# All models should have provider source
|
||||||
|
for model in models.data:
|
||||||
|
assert model.source == RegistryEntrySource.provider
|
||||||
|
assert model.provider_id == "test_provider"
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
|
||||||
|
"""Test that provider refresh preserves user-registered models with default source."""
|
||||||
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# First register a user model with same provider_resource_id as provider will later provide
|
||||||
|
await table.register_model(
|
||||||
|
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify user model is registered with default source
|
||||||
|
models = await table.list_models()
|
||||||
|
assert len(models.data) == 1
|
||||||
|
user_model = models.data[0]
|
||||||
|
assert user_model.source == RegistryEntrySource.default
|
||||||
|
assert user_model.identifier == "my-custom-alias"
|
||||||
|
assert user_model.provider_resource_id == "provider-model-1"
|
||||||
|
|
||||||
|
# Now simulate provider refresh
|
||||||
|
provider_models = [
|
||||||
|
Model(
|
||||||
|
identifier="provider-model-1",
|
||||||
|
provider_resource_id="provider-model-1",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={},
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
),
|
||||||
|
Model(
|
||||||
|
identifier="different-model",
|
||||||
|
provider_resource_id="different-model",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={},
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
await table.update_registered_models("test_provider", provider_models)
|
||||||
|
|
||||||
|
# Verify user model with alias is preserved, but provider added new model
|
||||||
|
models = await table.list_models()
|
||||||
|
assert len(models.data) == 2
|
||||||
|
|
||||||
|
# Find the user model and provider model
|
||||||
|
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
||||||
|
provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
|
||||||
|
|
||||||
|
assert user_model is not None
|
||||||
|
assert user_model.source == RegistryEntrySource.default
|
||||||
|
assert user_model.provider_resource_id == "provider-model-1"
|
||||||
|
|
||||||
|
assert provider_model is not None
|
||||||
|
assert provider_model.source == RegistryEntrySource.provider
|
||||||
|
assert provider_model.provider_resource_id == "different-model"
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
|
||||||
|
"""Test that provider refresh removes old provider models but keeps default ones."""
|
||||||
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Register a user model
|
||||||
|
await table.register_model(model_id="user-model", provider_id="test_provider")
|
||||||
|
|
||||||
|
# Add some provider models
|
||||||
|
provider_models_v1 = [
|
||||||
|
Model(
|
||||||
|
identifier="provider-model-old",
|
||||||
|
provider_resource_id="provider-model-old",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={},
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
await table.update_registered_models("test_provider", provider_models_v1)
|
||||||
|
|
||||||
|
# Verify we have both user and provider models
|
||||||
|
models = await table.list_models()
|
||||||
|
assert len(models.data) == 2
|
||||||
|
|
||||||
|
# Now update with new provider models (should remove old provider models)
|
||||||
|
provider_models_v2 = [
|
||||||
|
Model(
|
||||||
|
identifier="provider-model-new",
|
||||||
|
provider_resource_id="provider-model-new",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={},
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
await table.update_registered_models("test_provider", provider_models_v2)
|
||||||
|
|
||||||
|
# Should have user model + new provider model, old provider model gone
|
||||||
|
models = await table.list_models()
|
||||||
|
assert len(models.data) == 2
|
||||||
|
|
||||||
|
identifiers = {m.identifier for m in models.data}
|
||||||
|
assert "test_provider/user-model" in identifiers # User model preserved
|
||||||
|
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
||||||
|
assert "provider-model-old" not in identifiers # Old provider model removed
|
||||||
|
|
||||||
|
# Verify sources are correct
|
||||||
|
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
|
||||||
|
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
|
||||||
|
|
||||||
|
assert user_model.source == RegistryEntrySource.default
|
||||||
|
assert provider_model.source == RegistryEntrySource.provider
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue