mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
add model type to APIs (#588)
# What does this PR do? This PR adds a new model type field to support embedding models to be registered. Summary of changes: 1) Each registered model by default is an llm model. 2) User can specify an embedding model type, while registering.If specified, the model bypass the llama model checks since embedding models can by of any type and based on llama. 3) User needs to include the required embedding dimension in metadata. This will be used by embedding generation to generate the requried size of embeddings. ## Test Plan This PR will go together will need to be merged with two follow up PRs that will include test plans.
This commit is contained in:
parent
7e1d628864
commit
8e33db6015
6 changed files with 77 additions and 13 deletions
|
@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
|
||||||
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
chunk_size_in_tokens: int
|
chunk_size_in_tokens: int
|
||||||
|
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
|
||||||
overlap_size_in_tokens: Optional[int] = None
|
overlap_size_in_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
@ -20,6 +21,11 @@ class CommonModelFields(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(Enum):
|
||||||
|
llm = "llm"
|
||||||
|
embedding_model = "embedding"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Model(CommonModelFields, Resource):
|
class Model(CommonModelFields, Resource):
|
||||||
type: Literal[ResourceType.model.value] = ResourceType.model.value
|
type: Literal[ResourceType.model.value] = ResourceType.model.value
|
||||||
|
@ -34,11 +40,14 @@ class Model(CommonModelFields, Resource):
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
model_type: ModelType = Field(default=ModelType.llm)
|
||||||
|
|
||||||
|
|
||||||
class ModelInput(CommonModelFields):
|
class ModelInput(CommonModelFields):
|
||||||
model_id: str
|
model_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
provider_model_id: Optional[str] = None
|
provider_model_id: Optional[str] = None
|
||||||
|
model_type: Optional[ModelType] = ModelType.llm
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
@ -59,6 +68,7 @@ class Models(Protocol):
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
model_type: Optional[ModelType] = None,
|
||||||
) -> Model: ...
|
) -> Model: ...
|
||||||
|
|
||||||
@webmethod(route="/models/unregister", method="POST")
|
@webmethod(route="/models/unregister", method="POST")
|
||||||
|
|
|
@ -88,9 +88,10 @@ class InferenceRouter(Inference):
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
model_type: Optional[ModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.routing_table.register_model(
|
await self.routing_table.register_model(
|
||||||
model_id, provider_model_id, provider_id, metadata
|
model_id, provider_model_id, provider_id, metadata, model_type
|
||||||
)
|
)
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
|
@ -105,6 +106,13 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
model = await self.routing_table.get_model(model_id)
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
if model.model_type == ModelType.embedding_model:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||||
|
)
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -131,6 +139,13 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
model = await self.routing_table.get_model(model_id)
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
if model.model_type == ModelType.embedding_model:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||||
|
)
|
||||||
provider = self.routing_table.get_provider_impl(model_id)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -150,6 +165,13 @@ class InferenceRouter(Inference):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
|
model = await self.routing_table.get_model(model_id)
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
if model.model_type == ModelType.llm:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
||||||
|
)
|
||||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
|
|
|
@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
model_type: Optional[ModelType] = None,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
if provider_model_id is None:
|
if provider_model_id is None:
|
||||||
provider_model_id = model_id
|
provider_model_id = model_id
|
||||||
|
@ -222,11 +223,21 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
)
|
)
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
if model_type is None:
|
||||||
|
model_type = ModelType.llm
|
||||||
|
if (
|
||||||
|
"embedding_dimension" not in metadata
|
||||||
|
and model_type == ModelType.embedding_model
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Embedding model must have an embedding dimension in its metadata"
|
||||||
|
)
|
||||||
model = Model(
|
model = Model(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
registered_model = await self.register_object(model)
|
registered_model = await self.register_object(model)
|
||||||
return registered_model
|
return registered_model
|
||||||
|
@ -298,16 +309,29 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||||
)
|
)
|
||||||
memory_bank = parse_obj_as(
|
model = await self.get_object_by_identifier("model", params.embedding_model)
|
||||||
MemoryBank,
|
if model is None:
|
||||||
{
|
raise ValueError(f"Model {params.embedding_model} not found")
|
||||||
"identifier": memory_bank_id,
|
if model.model_type != ModelType.embedding_model:
|
||||||
"type": ResourceType.memory_bank.value,
|
raise ValueError(
|
||||||
"provider_id": provider_id,
|
f"Model {params.embedding_model} is not an embedding model"
|
||||||
"provider_resource_id": provider_memory_bank_id,
|
)
|
||||||
**params.model_dump(),
|
if "embedding_dimension" not in model.metadata:
|
||||||
},
|
raise ValueError(
|
||||||
)
|
f"Model {params.embedding_model} does not have an embedding dimension"
|
||||||
|
)
|
||||||
|
memory_bank_data = {
|
||||||
|
"identifier": memory_bank_id,
|
||||||
|
"type": ResourceType.memory_bank.value,
|
||||||
|
"provider_id": provider_id,
|
||||||
|
"provider_resource_id": provider_memory_bank_id,
|
||||||
|
**params.model_dump(),
|
||||||
|
}
|
||||||
|
if params.memory_bank_type == MemoryBankType.vector.value:
|
||||||
|
memory_bank_data["embedding_dimension"] = model.metadata[
|
||||||
|
"embedding_dimension"
|
||||||
|
]
|
||||||
|
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
|
||||||
await self.register_object(memory_bank)
|
await self.register_object(memory_bank)
|
||||||
return memory_bank
|
return memory_bank
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
|
||||||
|
|
||||||
|
|
||||||
REGISTER_PREFIX = "distributions:registry"
|
REGISTER_PREFIX = "distributions:registry"
|
||||||
KEY_VERSION = "v2"
|
KEY_VERSION = "v3"
|
||||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from typing import List, Optional
|
||||||
|
|
||||||
from llama_models.sku_list import all_registered_models
|
from llama_models.sku_list import all_registered_models
|
||||||
|
|
||||||
|
from llama_stack.apis.models.models import ModelType
|
||||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference import (
|
from llama_stack.providers.utils.inference import (
|
||||||
|
@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
if model.model_type == ModelType.embedding_model:
|
||||||
|
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||||
|
provider_resource_id = model.provider_resource_id
|
||||||
|
else:
|
||||||
|
provider_resource_id = self.get_provider_model_id(
|
||||||
|
model.provider_resource_id
|
||||||
|
)
|
||||||
if provider_resource_id:
|
if provider_resource_id:
|
||||||
model.provider_resource_id = provider_resource_id
|
model.provider_resource_id = provider_resource_id
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue