add model type to APIs (#588)

# What does this PR do?

This PR adds a new model type field to support embedding models to be
registered. Summary of changes:
1) Each registered model by default is an llm model. 
2) User can specify an embedding model type, while registering.If
specified, the model bypass the llama model checks since embedding
models can by of any type and based on llama.
3) User needs to include the required embedding dimension in metadata.
This will be used by embedding generation to generate the requried size
of embeddings.


## Test Plan

This PR will go together will need to be merged with two follow up PRs
that will include test plans.
This commit is contained in:
Dinesh Yeduguru 2024-12-11 10:16:53 -08:00 committed by GitHub
parent 7e1d628864
commit 8e33db6015
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 77 additions and 13 deletions

View file

@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str embedding_model: str
chunk_size_in_tokens: int chunk_size_in_tokens: int
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
overlap_size_in_tokens: Optional[int] = None overlap_size_in_tokens: Optional[int] = None

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -20,6 +21,11 @@ class CommonModelFields(BaseModel):
) )
class ModelType(Enum):
llm = "llm"
embedding_model = "embedding"
@json_schema_type @json_schema_type
class Model(CommonModelFields, Resource): class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value type: Literal[ResourceType.model.value] = ResourceType.model.value
@ -34,11 +40,14 @@ class Model(CommonModelFields, Resource):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm)
class ModelInput(CommonModelFields): class ModelInput(CommonModelFields):
model_id: str model_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None
provider_model_id: Optional[str] = None provider_model_id: Optional[str] = None
model_type: Optional[ModelType] = ModelType.llm
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@ -59,6 +68,7 @@ class Models(Protocol):
provider_model_id: Optional[str] = None, provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model: ... ) -> Model: ...
@webmethod(route="/models/unregister", method="POST") @webmethod(route="/models/unregister", method="POST")

View file

@ -88,9 +88,10 @@ class InferenceRouter(Inference):
provider_model_id: Optional[str] = None, provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None: ) -> None:
await self.routing_table.register_model( await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata model_id, provider_model_id, provider_id, metadata, model_type
) )
async def chat_completion( async def chat_completion(
@ -105,6 +106,13 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
messages=messages, messages=messages,
@ -131,6 +139,13 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
provider = self.routing_table.get_provider_impl(model_id) provider = self.routing_table.get_provider_impl(model_id)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
@ -150,6 +165,13 @@ class InferenceRouter(Inference):
model_id: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
return await self.routing_table.get_provider_impl(model_id).embeddings( return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id, model_id=model_id,
contents=contents, contents=contents,

View file

@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_model_id: Optional[str] = None, provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model: ) -> Model:
if provider_model_id is None: if provider_model_id is None:
provider_model_id = model_id provider_model_id = model_id
@ -222,11 +223,21 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
) )
if metadata is None: if metadata is None:
metadata = {} metadata = {}
if model_type is None:
model_type = ModelType.llm
if (
"embedding_dimension" not in metadata
and model_type == ModelType.embedding_model
):
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
model = Model( model = Model(
identifier=model_id, identifier=model_id,
provider_resource_id=provider_model_id, provider_resource_id=provider_model_id,
provider_id=provider_id, provider_id=provider_id,
metadata=metadata, metadata=metadata,
model_type=model_type,
) )
registered_model = await self.register_object(model) registered_model = await self.register_object(model)
return registered_model return registered_model
@ -298,16 +309,29 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
raise ValueError( raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id." "No provider specified and multiple providers available. Please specify a provider_id."
) )
memory_bank = parse_obj_as( model = await self.get_object_by_identifier("model", params.embedding_model)
MemoryBank, if model is None:
{ raise ValueError(f"Model {params.embedding_model} not found")
"identifier": memory_bank_id, if model.model_type != ModelType.embedding_model:
"type": ResourceType.memory_bank.value, raise ValueError(
"provider_id": provider_id, f"Model {params.embedding_model} is not an embedding model"
"provider_resource_id": provider_memory_bank_id, )
**params.model_dump(), if "embedding_dimension" not in model.metadata:
}, raise ValueError(
) f"Model {params.embedding_model} does not have an embedding dimension"
)
memory_bank_data = {
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
}
if params.memory_bank_type == MemoryBankType.vector.value:
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
await self.register_object(memory_bank) await self.register_object(memory_bank)
return memory_bank return memory_bank

View file

@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry" REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v2" KEY_VERSION = "v3"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@ -9,6 +9,7 @@ from typing import List, Optional
from llama_models.sku_list import all_registered_models from llama_models.sku_list import all_registered_models
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import ( from llama_stack.providers.utils.inference import (
@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
return None return None
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id) if model.model_type == ModelType.embedding_model:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(
model.provider_resource_id
)
if provider_resource_id: if provider_resource_id:
model.provider_resource_id = provider_resource_id model.provider_resource_id = provider_resource_id
else: else: