diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index a17e8e48d..b037dfa66 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
     memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
     embedding_model: str
     chunk_size_in_tokens: int
+    embedding_dimension: Optional[int] = 384  # default to minilm-l6-v2
     overlap_size_in_tokens: Optional[int] = None
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index cb9cb1117..ed9549d63 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
@@ -20,6 +21,11 @@ class CommonModelFields(BaseModel):
     )
 
 
+class ModelType(Enum):
+    llm = "llm"
+    embedding_model = "embedding"
+
+
 @json_schema_type
 class Model(CommonModelFields, Resource):
     type: Literal[ResourceType.model.value] = ResourceType.model.value
@@ -34,11 +40,14 @@ class Model(CommonModelFields, Resource):
 
     model_config = ConfigDict(protected_namespaces=())
 
+    model_type: ModelType = Field(default=ModelType.llm)
+
 
 class ModelInput(CommonModelFields):
     model_id: str
     provider_id: Optional[str] = None
     provider_model_id: Optional[str] = None
+    model_type: Optional[ModelType] = ModelType.llm
 
     model_config = ConfigDict(protected_namespaces=())
 
@@ -59,6 +68,7 @@ class Models(Protocol):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        model_type: Optional[ModelType] = None,
     ) -> Model: ...
@webmethod(route="/models/unregister", method="POST") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5b75a525b..51be318cb 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -88,9 +88,10 @@ class InferenceRouter(Inference): provider_model_id: Optional[str] = None, provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + model_type: Optional[ModelType] = None, ) -> None: await self.routing_table.register_model( - model_id, provider_model_id, provider_id, metadata + model_id, provider_model_id, provider_id, metadata, model_type ) async def chat_completion( @@ -105,6 +106,13 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.routing_table.get_model(model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + if model.model_type == ModelType.embedding_model: + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) params = dict( model_id=model_id, messages=messages, @@ -131,6 +139,13 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.routing_table.get_model(model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + if model.model_type == ModelType.embedding_model: + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) provider = self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -150,6 +165,13 @@ class InferenceRouter(Inference): model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: + model = await self.routing_table.get_model(model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + if model.model_type == ModelType.llm: + raise ValueError( + f"Model '{model_id}' is an LLM model and does not support embeddings" + ) return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 2fb5a5e1c..bc3de8be0 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_model_id: Optional[str] = None, provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + model_type: Optional[ModelType] = None, ) -> Model: if provider_model_id is None: provider_model_id = model_id @@ -222,11 +223,21 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): ) if metadata is None: metadata = {} + if model_type is None: + model_type = ModelType.llm + if ( + "embedding_dimension" not in metadata + and model_type == ModelType.embedding_model + ): + raise ValueError( + "Embedding model must have an embedding dimension in its metadata" + ) model = Model( identifier=model_id, provider_resource_id=provider_model_id, provider_id=provider_id, metadata=metadata, + model_type=model_type, ) registered_model = await self.register_object(model) return registered_model @@ -298,16 +309,29 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): raise ValueError( "No provider specified and multiple providers available. 
                "No provider specified and multiple providers available. Please specify a provider_id."
             )
-        memory_bank = parse_obj_as(
-            MemoryBank,
-            {
-                "identifier": memory_bank_id,
-                "type": ResourceType.memory_bank.value,
-                "provider_id": provider_id,
-                "provider_resource_id": provider_memory_bank_id,
-                **params.model_dump(),
-            },
-        )
+        model = await self.get_object_by_identifier("model", params.embedding_model)
+        if model is None:
+            raise ValueError(f"Model {params.embedding_model} not found")
+        if model.model_type != ModelType.embedding_model:
+            raise ValueError(
+                f"Model {params.embedding_model} is not an embedding model"
+            )
+        if "embedding_dimension" not in model.metadata:
+            raise ValueError(
+                f"Model {params.embedding_model} does not have an embedding dimension"
+            )
+        memory_bank_data = {
+            "identifier": memory_bank_id,
+            "type": ResourceType.memory_bank.value,
+            "provider_id": provider_id,
+            "provider_resource_id": provider_memory_bank_id,
+            **params.model_dump(),
+        }
+        if params.memory_bank_type == MemoryBankType.vector.value:
+            memory_bank_data["embedding_dimension"] = model.metadata[
+                "embedding_dimension"
+            ]
+        memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
         await self.register_object(memory_bank)
         return memory_bank
diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py
index 041a5677c..8f93c0c4b 100644
--- a/llama_stack/distribution/store/registry.py
+++ b/llama_stack/distribution/store/registry.py
@@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
 
 
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v2"
+KEY_VERSION = "v3"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 8dbfab14a..be2642cdb 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -9,6 +9,7 @@ from typing import List, Optional
 
 from llama_models.sku_list import all_registered_models
 
+from llama_stack.apis.models.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 
 from llama_stack.providers.utils.inference import (
@@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return None
 
     async def register_model(self, model: Model) -> Model:
-        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+        if model.model_type == ModelType.embedding_model:
+            # embedding models are always registered by their provider model id
+            # and do not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(
+                model.provider_resource_id
+            )
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
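
The pieces above fit together as a single flow: an embedding model must be registered with an "embedding_dimension" in its metadata (enforced in ModelsRoutingTable.register_model), a vector memory bank then inherits that dimension from the model it names (MemoryBanksRoutingTable), and the InferenceRouter rejects calls that send an embedding model to a text endpoint or vice versa. A minimal sketch of that flow against the routing tables, assuming instances `models_table` and `banks_table` and a params class shaped like VectorMemoryBankParams; the model id and provider id are illustrative, not part of this patch:

    from llama_stack.apis.models.models import ModelType

    # Register an embedding model. embedding_dimension in metadata is now
    # mandatory for ModelType.embedding_model and is validated at this point.
    await models_table.register_model(
        model_id="all-MiniLM-L6-v2",          # hypothetical id
        provider_id="sentence-transformers",  # hypothetical provider
        model_type=ModelType.embedding_model,
        metadata={"embedding_dimension": 384},
    )

    # Register a vector bank that names the model. The routing table looks the
    # model up, checks its type, and copies embedding_dimension onto the bank.
    await banks_table.register_memory_bank(
        memory_bank_id="docs",
        params=VectorMemoryBankParams(
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
        ),
    )

    # Sending the embedding model to a chat endpoint now fails fast in the
    # router instead of inside the provider:
    #   ValueError: Model 'all-MiniLM-L6-v2' is an embedding model and
    #   does not support chat completions

Note that KEY_VERSION is bumped to "v3" because registered Model objects gain a new field; since KEY_FORMAT embeds the version, entries persisted under the v2 prefix are no longer read back, so previously registered resources will not be visible until re-registered.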