From 47b2dc8ae3d5278ac06f3e8561b9d7976a085cd6 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Wed, 11 Dec 2024 10:17:54 -0800
Subject: [PATCH] Revert "add model type to APIs" (#605)

Reverts meta-llama/llama-stack#588
---
 llama_stack/apis/memory_banks/memory_banks.py |  1 -
 llama_stack/apis/models/models.py             | 10 -----
 llama_stack/distribution/routers/routers.py   | 24 +---------
 .../distribution/routers/routing_tables.py    | 44 +++++--------
 llama_stack/distribution/store/registry.py    |  2 +-
 .../utils/inference/model_registry.py         |  9 +---
 6 files changed, 13 insertions(+), 77 deletions(-)

diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index b037dfa66..a17e8e48d 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -89,7 +89,6 @@ class VectorMemoryBank(MemoryBankResourceMixin):
     memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
     embedding_model: str
     chunk_size_in_tokens: int
-    embedding_dimension: Optional[int] = 384  # default to minilm-l6-v2
     overlap_size_in_tokens: Optional[int] = None
 
 
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index ed9549d63..cb9cb1117 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
@@ -21,11 +20,6 @@ class CommonModelFields(BaseModel):
     )
 
 
-class ModelType(Enum):
-    llm = "llm"
-    embedding_model = "embedding"
-
-
 @json_schema_type
 class Model(CommonModelFields, Resource):
     type: Literal[ResourceType.model.value] = ResourceType.model.value
@@ -40,14 +34,11 @@
 
     model_config = ConfigDict(protected_namespaces=())
 
-    model_type: ModelType = Field(default=ModelType.llm)
-
 
 class ModelInput(CommonModelFields):
     model_id: str
     provider_id: Optional[str] = None
     provider_model_id: Optional[str] = None
-    model_type: Optional[ModelType] = ModelType.llm
 
     model_config = ConfigDict(protected_namespaces=())
 
@@ -68,7 +59,6 @@ class Models(Protocol):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
     ) -> Model: ...
 
     @webmethod(route="/models/unregister", method="POST")
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 51be318cb..5b75a525b 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -88,10 +88,9 @@ class InferenceRouter(Inference):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
     ) -> None:
         await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
+            model_id, provider_model_id, provider_id, metadata
         )
 
     async def chat_completion(
@@ -106,13 +105,6 @@
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.embedding_model:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
         params = dict(
             model_id=model_id,
             messages=messages,
@@ -139,13 +131,6 @@
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.embedding_model:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -165,13 +150,6 @@
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.llm:
-            raise ValueError(
-                f"Model '{model_id}' is an LLM model and does not support embeddings"
-            )
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index bc3de8be0..2fb5a5e1c 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -209,7 +209,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
     ) -> Model:
         if provider_model_id is None:
             provider_model_id = model_id
@@ -223,21 +222,11 @@
             )
         if metadata is None:
             metadata = {}
-        if model_type is None:
-            model_type = ModelType.llm
-        if (
-            "embedding_dimension" not in metadata
-            and model_type == ModelType.embedding_model
-        ):
-            raise ValueError(
-                "Embedding model must have an embedding dimension in its metadata"
-            )
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,
             metadata=metadata,
-            model_type=model_type,
         )
         registered_model = await self.register_object(model)
         return registered_model
@@ -309,29 +298,16 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
             raise ValueError(
                 "No provider specified and multiple providers available. Please specify a provider_id."
             )
-        model = await self.get_object_by_identifier("model", params.embedding_model)
-        if model is None:
-            raise ValueError(f"Model {params.embedding_model} not found")
-        if model.model_type != ModelType.embedding_model:
-            raise ValueError(
-                f"Model {params.embedding_model} is not an embedding model"
-            )
-        if "embedding_dimension" not in model.metadata:
-            raise ValueError(
-                f"Model {params.embedding_model} does not have an embedding dimension"
-            )
-        memory_bank_data = {
-            "identifier": memory_bank_id,
-            "type": ResourceType.memory_bank.value,
-            "provider_id": provider_id,
-            "provider_resource_id": provider_memory_bank_id,
-            **params.model_dump(),
-        }
-        if params.memory_bank_type == MemoryBankType.vector.value:
-            memory_bank_data["embedding_dimension"] = model.metadata[
-                "embedding_dimension"
-            ]
-        memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
+        memory_bank = parse_obj_as(
+            MemoryBank,
+            {
+                "identifier": memory_bank_id,
+                "type": ResourceType.memory_bank.value,
+                "provider_id": provider_id,
+                "provider_resource_id": provider_memory_bank_id,
+                **params.model_dump(),
+            },
+        )
         await self.register_object(memory_bank)
         return memory_bank
 
diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py
index 8f93c0c4b..041a5677c 100644
--- a/llama_stack/distribution/store/registry.py
+++ b/llama_stack/distribution/store/registry.py
@@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
 
 
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v3"
+KEY_VERSION = "v2"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
 
 
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index be2642cdb..8dbfab14a 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -9,7 +9,6 @@ from typing import List, Optional
 
 from llama_models.sku_list import all_registered_models
 
-from llama_stack.apis.models.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 
 from llama_stack.providers.utils.inference import (
@@ -78,13 +77,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return None
 
     async def register_model(self, model: Model) -> Model:
-        if model.model_type == ModelType.embedding_model:
-            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
-            provider_resource_id = model.provider_resource_id
-        else:
-            provider_resource_id = self.get_provider_model_id(
-                model.provider_resource_id
-            )
+        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
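
For reference, a minimal sketch (not part of the patch) of what the Model
resource looks like after this revert: the ModelType enum and the model_type
field are gone, so every registered model is implicitly an LLM and metadata is
free-form again. The field names come from the hunks above; the identifier and
provider values are made-up examples, and the constructor defaults are assumed
from the Resource/CommonModelFields definitions shown here.

    from llama_stack.apis.models.models import Model

    model = Model(
        identifier="Llama3.1-8B-Instruct",           # hypothetical model id
        provider_resource_id="Llama3.1-8B-Instruct",
        provider_id="meta-reference",                # hypothetical provider id
        metadata={},  # free-form; no embedding_dimension requirement anymore
    )
    print(model.model_dump())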
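
Note also the KEY_VERSION rollback from "v3" to "v2" in store/registry.py:
because KEY_FORMAT embeds the version, the registry stops consulting anything
written under the v3 prefix. A quick sketch of the resulting key layout, using
the REGISTER_PREFIX/KEY_FORMAT lines verbatim from the hunk above (the type
and identifier values are examples only):

    REGISTER_PREFIX = "distributions:registry"
    KEY_VERSION = "v2"
    KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

    # prints: distributions:registry:v2::model:Llama3.1-8B-Instruct
    print(KEY_FORMAT.format(type="model", identifier="Llama3.1-8B-Instruct"))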