forked from phoenix-oss/llama-stack-mirror
Make embedding generation go through inference (#606)
This PR does the following: 1) adds the ability to generate embeddings in all supported inference providers. 2) Moves all the memory providers to use the inference API and improved the memory tests to setup the inference stack correctly and use the embedding models This is a merge from #589 and #598
This commit is contained in:
parent
a14785af46
commit
96e158eaac
37 changed files with 677 additions and 156 deletions
|
@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> Model:
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
|
@ -222,11 +223,21 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if model_type is None:
|
||||
model_type = ModelType.llm
|
||||
if (
|
||||
"embedding_dimension" not in metadata
|
||||
and model_type == ModelType.embedding_model
|
||||
):
|
||||
raise ValueError(
|
||||
"Embedding model must have an embedding dimension in its metadata"
|
||||
)
|
||||
model = Model(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
)
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
@ -298,16 +309,29 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
memory_bank = parse_obj_as(
|
||||
MemoryBank,
|
||||
{
|
||||
"identifier": memory_bank_id,
|
||||
"type": ResourceType.memory_bank.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_memory_bank_id,
|
||||
**params.model_dump(),
|
||||
},
|
||||
)
|
||||
model = await self.get_object_by_identifier("model", params.embedding_model)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {params.embedding_model} not found")
|
||||
if model.model_type != ModelType.embedding_model:
|
||||
raise ValueError(
|
||||
f"Model {params.embedding_model} is not an embedding model"
|
||||
)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(
|
||||
f"Model {params.embedding_model} does not have an embedding dimension"
|
||||
)
|
||||
memory_bank_data = {
|
||||
"identifier": memory_bank_id,
|
||||
"type": ResourceType.memory_bank.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_memory_bank_id,
|
||||
**params.model_dump(),
|
||||
}
|
||||
if params.memory_bank_type == MemoryBankType.vector.value:
|
||||
memory_bank_data["embedding_dimension"] = model.metadata[
|
||||
"embedding_dimension"
|
||||
]
|
||||
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
|
||||
await self.register_object(memory_bank)
|
||||
return memory_bank
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue