forked from phoenix-oss/llama-stack-mirror
Revert "add model type to APIs" (#605)
Reverts meta-llama/llama-stack#588
This commit is contained in:
parent
8e33db6015
commit
47b2dc8ae3
6 changed files with 13 additions and 77 deletions
|
@ -89,7 +89,6 @@ class VectorMemoryBank(MemoryBankResourceMixin):
|
||||||
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
chunk_size_in_tokens: int
|
chunk_size_in_tokens: int
|
||||||
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
|
|
||||||
overlap_size_in_tokens: Optional[int] = None
|
overlap_size_in_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
@ -21,11 +20,6 @@ class CommonModelFields(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelType(Enum):
|
|
||||||
llm = "llm"
|
|
||||||
embedding_model = "embedding"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Model(CommonModelFields, Resource):
|
class Model(CommonModelFields, Resource):
|
||||||
type: Literal[ResourceType.model.value] = ResourceType.model.value
|
type: Literal[ResourceType.model.value] = ResourceType.model.value
|
||||||
|
@ -40,14 +34,11 @@ class Model(CommonModelFields, Resource):
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
model_type: ModelType = Field(default=ModelType.llm)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInput(CommonModelFields):
|
class ModelInput(CommonModelFields):
|
||||||
model_id: str
|
model_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
provider_model_id: Optional[str] = None
|
provider_model_id: Optional[str] = None
|
||||||
model_type: Optional[ModelType] = ModelType.llm
|
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
@ -68,7 +59,6 @@ class Models(Protocol):
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
|
||||||
) -> Model: ...
|
) -> Model: ...
|
||||||
|
|
||||||
@webmethod(route="/models/unregister", method="POST")
|
@webmethod(route="/models/unregister", method="POST")
|
||||||
|
|
|
@ -88,10 +88,9 @@ class InferenceRouter(Inference):
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.routing_table.register_model(
|
await self.routing_table.register_model(
|
||||||
model_id, provider_model_id, provider_id, metadata, model_type
|
model_id, provider_model_id, provider_id, metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
|
@ -106,13 +105,6 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
model = await self.routing_table.get_model(model_id)
|
|
||||||
if model is None:
|
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
|
||||||
if model.model_type == ModelType.embedding_model:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
|
||||||
)
|
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -139,13 +131,6 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
model = await self.routing_table.get_model(model_id)
|
|
||||||
if model is None:
|
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
|
||||||
if model.model_type == ModelType.embedding_model:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
|
||||||
)
|
|
||||||
provider = self.routing_table.get_provider_impl(model_id)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -165,13 +150,6 @@ class InferenceRouter(Inference):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.routing_table.get_model(model_id)
|
|
||||||
if model is None:
|
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
|
||||||
if model.model_type == ModelType.llm:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
|
||||||
)
|
|
||||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
|
|
|
@ -209,7 +209,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
|
||||||
) -> Model:
|
) -> Model:
|
||||||
if provider_model_id is None:
|
if provider_model_id is None:
|
||||||
provider_model_id = model_id
|
provider_model_id = model_id
|
||||||
|
@ -223,21 +222,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
)
|
)
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
if model_type is None:
|
|
||||||
model_type = ModelType.llm
|
|
||||||
if (
|
|
||||||
"embedding_dimension" not in metadata
|
|
||||||
and model_type == ModelType.embedding_model
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Embedding model must have an embedding dimension in its metadata"
|
|
||||||
)
|
|
||||||
model = Model(
|
model = Model(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
model_type=model_type,
|
|
||||||
)
|
)
|
||||||
registered_model = await self.register_object(model)
|
registered_model = await self.register_object(model)
|
||||||
return registered_model
|
return registered_model
|
||||||
|
@ -309,29 +298,16 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||||
)
|
)
|
||||||
model = await self.get_object_by_identifier("model", params.embedding_model)
|
memory_bank = parse_obj_as(
|
||||||
if model is None:
|
MemoryBank,
|
||||||
raise ValueError(f"Model {params.embedding_model} not found")
|
{
|
||||||
if model.model_type != ModelType.embedding_model:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model {params.embedding_model} is not an embedding model"
|
|
||||||
)
|
|
||||||
if "embedding_dimension" not in model.metadata:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model {params.embedding_model} does not have an embedding dimension"
|
|
||||||
)
|
|
||||||
memory_bank_data = {
|
|
||||||
"identifier": memory_bank_id,
|
"identifier": memory_bank_id,
|
||||||
"type": ResourceType.memory_bank.value,
|
"type": ResourceType.memory_bank.value,
|
||||||
"provider_id": provider_id,
|
"provider_id": provider_id,
|
||||||
"provider_resource_id": provider_memory_bank_id,
|
"provider_resource_id": provider_memory_bank_id,
|
||||||
**params.model_dump(),
|
**params.model_dump(),
|
||||||
}
|
},
|
||||||
if params.memory_bank_type == MemoryBankType.vector.value:
|
)
|
||||||
memory_bank_data["embedding_dimension"] = model.metadata[
|
|
||||||
"embedding_dimension"
|
|
||||||
]
|
|
||||||
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
|
|
||||||
await self.register_object(memory_bank)
|
await self.register_object(memory_bank)
|
||||||
return memory_bank
|
return memory_bank
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
|
||||||
|
|
||||||
|
|
||||||
REGISTER_PREFIX = "distributions:registry"
|
REGISTER_PREFIX = "distributions:registry"
|
||||||
KEY_VERSION = "v3"
|
KEY_VERSION = "v2"
|
||||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,6 @@ from typing import List, Optional
|
||||||
|
|
||||||
from llama_models.sku_list import all_registered_models
|
from llama_models.sku_list import all_registered_models
|
||||||
|
|
||||||
from llama_stack.apis.models.models import ModelType
|
|
||||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference import (
|
from llama_stack.providers.utils.inference import (
|
||||||
|
@ -78,13 +77,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
if model.model_type == ModelType.embedding_model:
|
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
|
||||||
provider_resource_id = model.provider_resource_id
|
|
||||||
else:
|
|
||||||
provider_resource_id = self.get_provider_model_id(
|
|
||||||
model.provider_resource_id
|
|
||||||
)
|
|
||||||
if provider_resource_id:
|
if provider_resource_id:
|
||||||
model.provider_resource_id = provider_resource_id
|
model.provider_resource_id = provider_resource_id
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue