Revert "add model type to APIs" (#605)

Reverts meta-llama/llama-stack#588
Dinesh Yeduguru · 2024-12-11 10:17:54 -08:00 · committed by GitHub
parent 8e33db6015
commit 47b2dc8ae3
6 changed files with 13 additions and 77 deletions


@@ -89,7 +89,6 @@ class VectorMemoryBank(MemoryBankResourceMixin):
     memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
     embedding_model: str
     chunk_size_in_tokens: int
-    embedding_dimension: Optional[int] = 384  # default to minilm-l6-v2
     overlap_size_in_tokens: Optional[int] = None


@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod
@@ -21,11 +20,6 @@ class CommonModelFields(BaseModel):
     )

-class ModelType(Enum):
-    llm = "llm"
-    embedding_model = "embedding"

 @json_schema_type
 class Model(CommonModelFields, Resource):
     type: Literal[ResourceType.model.value] = ResourceType.model.value
@@ -40,14 +34,11 @@ class Model(CommonModelFields, Resource):
     model_config = ConfigDict(protected_namespaces=())
-    model_type: ModelType = Field(default=ModelType.llm)

 class ModelInput(CommonModelFields):
     model_id: str
     provider_id: Optional[str] = None
     provider_model_id: Optional[str] = None
-    model_type: Optional[ModelType] = ModelType.llm

     model_config = ConfigDict(protected_namespaces=())
@@ -68,7 +59,6 @@ class Models(Protocol):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
     ) -> Model: ...

     @webmethod(route="/models/unregister", method="POST")
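For reference, after this revert register_model on the Models protocol accepts only the model identifier, provider fields, and metadata. A minimal usage sketch, assuming a models_impl object that implements this protocol (the variable name and example identifier are illustrative, not part of the diff):

    async def register_example(models_impl) -> None:
        # model_type is no longer an accepted argument after this revert;
        # provider_model_id falls back to model_id in the routing table.
        model = await models_impl.register_model(
            model_id="example-model",   # hypothetical identifier
            provider_model_id=None,
            provider_id=None,
            metadata={},
        )
        print(model.identifier, model.provider_resource_id)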


@@ -88,10 +88,9 @@ class InferenceRouter(Inference):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
     ) -> None:
         await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
+            model_id, provider_model_id, provider_id, metadata
         )

     async def chat_completion(
@@ -106,13 +105,6 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.embedding_model:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
         params = dict(
             model_id=model_id,
             messages=messages,
@@ -139,13 +131,6 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.embedding_model:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -165,13 +150,6 @@ class InferenceRouter(Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.llm:
-            raise ValueError(
-                f"Model '{model_id}' is an LLM model and does not support embeddings"
-            )
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,


@@ -209,7 +209,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
     ) -> Model:
         if provider_model_id is None:
             provider_model_id = model_id
@@ -223,21 +222,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             )
         if metadata is None:
             metadata = {}
-        if model_type is None:
-            model_type = ModelType.llm
-        if (
-            "embedding_dimension" not in metadata
-            and model_type == ModelType.embedding_model
-        ):
-            raise ValueError(
-                "Embedding model must have an embedding dimension in its metadata"
-            )
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,
             metadata=metadata,
-            model_type=model_type,
         )
         registered_model = await self.register_object(model)
         return registered_model
@@ -309,29 +298,16 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
             raise ValueError(
                 "No provider specified and multiple providers available. Please specify a provider_id."
             )
-        model = await self.get_object_by_identifier("model", params.embedding_model)
-        if model is None:
-            raise ValueError(f"Model {params.embedding_model} not found")
-        if model.model_type != ModelType.embedding_model:
-            raise ValueError(
-                f"Model {params.embedding_model} is not an embedding model"
-            )
-        if "embedding_dimension" not in model.metadata:
-            raise ValueError(
-                f"Model {params.embedding_model} does not have an embedding dimension"
-            )
-        memory_bank_data = {
-            "identifier": memory_bank_id,
-            "type": ResourceType.memory_bank.value,
-            "provider_id": provider_id,
-            "provider_resource_id": provider_memory_bank_id,
-            **params.model_dump(),
-        }
-        if params.memory_bank_type == MemoryBankType.vector.value:
-            memory_bank_data["embedding_dimension"] = model.metadata[
-                "embedding_dimension"
-            ]
-        memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
+        memory_bank = parse_obj_as(
+            MemoryBank,
+            {
+                "identifier": memory_bank_id,
+                "type": ResourceType.memory_bank.value,
+                "provider_id": provider_id,
+                "provider_resource_id": provider_memory_bank_id,
+                **params.model_dump(),
+            },
+        )
         await self.register_object(memory_bank)
         return memory_bank


@@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v3"
+KEY_VERSION = "v2"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@@ -9,7 +9,6 @@ from typing import List, Optional
 from llama_models.sku_list import all_registered_models

-from llama_stack.apis.models.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference import (
@@ -78,13 +77,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return None

     async def register_model(self, model: Model) -> Model:
-        if model.model_type == ModelType.embedding_model:
-            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
-            provider_resource_id = model.provider_resource_id
-        else:
-            provider_resource_id = self.get_provider_model_id(
-                model.provider_resource_id
-            )
+        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else: