Revert "add model type to APIs" (#605)

Reverts meta-llama/llama-stack#588
This commit is contained in:
Dinesh Yeduguru 2024-12-11 10:17:54 -08:00 committed by GitHub
parent 8e33db6015
commit 47b2dc8ae3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 13 additions and 77 deletions

View file

@ -209,7 +209,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
@ -223,21 +222,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
if (
"embedding_dimension" not in metadata
and model_type == ModelType.embedding_model
):
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
model = Model(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
)
registered_model = await self.register_object(model)
return registered_model
@ -309,29 +298,16 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
model = await self.get_object_by_identifier("model", params.embedding_model)
if model is None:
raise ValueError(f"Model {params.embedding_model} not found")
if model.model_type != ModelType.embedding_model:
raise ValueError(
f"Model {params.embedding_model} is not an embedding model"
)
if "embedding_dimension" not in model.metadata:
raise ValueError(
f"Model {params.embedding_model} does not have an embedding dimension"
)
memory_bank_data = {
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
}
if params.memory_bank_type == MemoryBankType.vector.value:
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
memory_bank = parse_obj_as(
MemoryBank,
{
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
},
)
await self.register_object(memory_bank)
return memory_bank