add model type to APIs (#588)

# What does this PR do?

This PR adds a new model type field so that embedding models can be
registered. Summary of changes:
1) Each registered model is an LLM model by default.
2) Users can specify an embedding model type while registering. If
specified, the model bypasses the llama model checks, since embedding
models can be of any type and need not be based on llama.
3) Users need to include the required embedding dimension in the model's
metadata. This is used during embedding generation to produce embeddings
of the required size (see the sketch below).
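
As an illustration of the intended usage (not part of this PR's diff), here is a minimal sketch of registering an embedding model against the updated `register_model` signature; the model name, provider id, and the `models_impl` handle are hypothetical:

```python
from llama_stack.apis.models.models import ModelType


# Hypothetical sketch; "all-MiniLM-L6-v2" and the provider id are example values.
async def register_embedding_model(models_impl) -> None:
    await models_impl.register_model(
        model_id="all-MiniLM-L6-v2",
        provider_id="sentence-transformers",    # any provider serving this model
        model_type=ModelType.embedding_model,   # defaults to ModelType.llm if omitted
        metadata={"embedding_dimension": 384},  # required for embedding models
    )
```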


## Test Plan

This PR needs to be merged together with two follow-up PRs, which will
include the test plans.
Dinesh Yeduguru 2024-12-11 10:16:53 -08:00 committed by GitHub
parent 7e1d628864
commit 8e33db6015
6 changed files with 77 additions and 13 deletions

View file

@@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
overlap_size_in_tokens: Optional[int] = None

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
@@ -20,6 +21,11 @@ class CommonModelFields(BaseModel):
)
class ModelType(Enum):
llm = "llm"
embedding_model = "embedding"
@json_schema_type
class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value
@@ -34,11 +40,14 @@ class Model(CommonModelFields, Resource):
model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm)
class ModelInput(CommonModelFields):
model_id: str
provider_id: Optional[str] = None
provider_model_id: Optional[str] = None
model_type: Optional[ModelType] = ModelType.llm
model_config = ConfigDict(protected_namespaces=())
@@ -59,6 +68,7 @@ class Models(Protocol):
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model: ...
@webmethod(route="/models/unregister", method="POST")
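
For reference, a rough sketch of what the new field means for the `Model` resource (assuming `Model` is importable alongside `ModelType`; identifiers and provider ids are made up):

```python
from llama_stack.apis.models.models import Model, ModelType

# An embedding model carries its dimension in metadata.
embedding_model = Model(
    identifier="all-MiniLM-L6-v2",
    provider_resource_id="all-MiniLM-L6-v2",
    provider_id="sentence-transformers",   # hypothetical provider id
    metadata={"embedding_dimension": 384},
    model_type=ModelType.embedding_model,
)

# Models registered without a model_type keep the default and stay LLMs.
llm = Model(
    identifier="Llama3.1-8B-Instruct",
    provider_resource_id="Llama3.1-8B-Instruct",
    provider_id="meta-reference",          # hypothetical provider id
    metadata={},
)
assert llm.model_type == ModelType.llm
```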

View file

@@ -88,9 +88,10 @@ class InferenceRouter(Inference):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        model_type: Optional[ModelType] = None,
     ) -> None:
         await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata
+            model_id, provider_model_id, provider_id, metadata, model_type
         )

     async def chat_completion(
@@ -105,6 +106,13 @@
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
params = dict(
model_id=model_id,
messages=messages,
@@ -131,6 +139,13 @@
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
@@ -150,6 +165,13 @@
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
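
The net effect of these router guards, sketched as a hypothetical test (the `inference_router` fixture, the model identifiers, and the message shape are assumptions; the error messages are taken from the diff above):

```python
import pytest


@pytest.mark.asyncio
async def test_model_type_guards(inference_router):
    # Chat completions reject models registered with ModelType.embedding_model ...
    with pytest.raises(ValueError, match="does not support chat completions"):
        await inference_router.chat_completion(
            model_id="my-embedding-model",
            messages=[{"role": "user", "content": "hi"}],
        )

    # ... and embeddings reject models registered with the default ModelType.llm.
    with pytest.raises(ValueError, match="does not support embeddings"):
        await inference_router.embeddings(model_id="my-llm", contents=["hi"])
```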

View file

@@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
@@ -222,11 +223,21 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
if (
"embedding_dimension" not in metadata
and model_type == ModelType.embedding_model
):
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
model = Model(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
)
registered_model = await self.register_object(model)
return registered_model
@@ -298,16 +309,29 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
                 raise ValueError(
                     "No provider specified and multiple providers available. Please specify a provider_id."
                 )
-        memory_bank = parse_obj_as(
-            MemoryBank,
-            {
+        model = await self.get_object_by_identifier("model", params.embedding_model)
+        if model is None:
+            raise ValueError(f"Model {params.embedding_model} not found")
+        if model.model_type != ModelType.embedding_model:
+            raise ValueError(
+                f"Model {params.embedding_model} is not an embedding model"
+            )
+        if "embedding_dimension" not in model.metadata:
+            raise ValueError(
+                f"Model {params.embedding_model} does not have an embedding dimension"
+            )
+        memory_bank_data = {
             "identifier": memory_bank_id,
             "type": ResourceType.memory_bank.value,
             "provider_id": provider_id,
             "provider_resource_id": provider_memory_bank_id,
             **params.model_dump(),
-            },
-        )
+        }
+        if params.memory_bank_type == MemoryBankType.vector.value:
+            memory_bank_data["embedding_dimension"] = model.metadata[
+                "embedding_dimension"
+            ]
+        memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
         await self.register_object(memory_bank)
         return memory_bank
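
In other words, a vector memory bank can now only reference an already-registered embedding model, and it inherits that model's `embedding_dimension`. A hedged sketch of the expected flow (the identifiers, the `VectorMemoryBankParams` import path, and the routing-table handles are assumptions):

```python
from llama_stack.apis.memory_banks import VectorMemoryBankParams  # assumed import path
from llama_stack.apis.models.models import ModelType


async def register_bank_with_embedding_model(models_table, memory_banks_table) -> None:
    # The embedding model must be registered first, with embedding_dimension in metadata.
    await models_table.register_model(
        model_id="all-MiniLM-L6-v2",
        provider_id="sentence-transformers",   # hypothetical provider id
        model_type=ModelType.embedding_model,
        metadata={"embedding_dimension": 384},
    )

    # register_memory_bank then looks the model up, checks it is an embedding model,
    # and copies embedding_dimension=384 into the resulting vector memory bank.
    await memory_banks_table.register_memory_bank(
        memory_bank_id="my-docs",
        params=VectorMemoryBankParams(
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
        ),
    )
```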

View file

@@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v2"
+KEY_VERSION = "v3"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@@ -9,6 +9,7 @@ from typing import List, Optional
from llama_models.sku_list import all_registered_models
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
@@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return None

     async def register_model(self, model: Model) -> Model:
-        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+        if model.model_type == ModelType.embedding_model:
+            # embedding models are always registered by their provider model id and do not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(
+                model.provider_resource_id
+            )
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else: