forked from phoenix-oss/llama-stack-mirror
# What does this PR do? This PR adds a new model type field so that embedding models can be registered. Summary of changes: 1) Each registered model is an LLM model by default. 2) Users can specify an embedding model type while registering. If specified, the model bypasses the Llama model checks, since embedding models can be of any type and are not necessarily based on Llama. 3) Users need to include the required embedding dimension in the metadata; embedding generation uses it to produce embeddings of the required size. ## Test Plan This PR needs to be merged together with two follow-up PRs that will include test plans.
152 lines
4.1 KiB
Python
152 lines
4.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from enum import Enum
|
|
from typing import (
|
|
Annotated,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Protocol,
|
|
runtime_checkable,
|
|
Union,
|
|
)
|
|
|
|
from llama_models.schema_utils import json_schema_type, webmethod
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from llama_stack.apis.resource import Resource, ResourceType
|
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|
|
|
|
|
@json_schema_type
class MemoryBankType(Enum):
    """Kinds of memory bank supported by the memory bank API.

    Each member's string value is used as the ``memory_bank_type``
    discriminator on the ``*MemoryBankParams`` and ``*MemoryBank`` models.
    """

    vector = "vector"
    keyvalue = "keyvalue"
    keyword = "keyword"
    graph = "graph"
|
|
|
|
|
|
# Define the params for each type of memory bank. Together they form a tagged
# union that is accepted as input from the API or from the config.
|
|
@json_schema_type
class VectorMemoryBankParams(BaseModel):
    """Registration parameters for a ``vector`` memory bank."""

    # Discriminator tag for ``BankParams``; fixed to ``"vector"``.
    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
    # Identifier of the embedding model, given as a plain string.
    embedding_model: str
    chunk_size_in_tokens: int
    # Optional; None when the caller does not supply it.
    overlap_size_in_tokens: Optional[int] = None
    # NOTE(review): unlike VectorMemoryBank there is no embedding_dimension
    # field here -- presumably it is filled in elsewhere (e.g. from model
    # metadata); confirm against the registration code.
|
|
|
|
|
|
@json_schema_type
class KeyValueMemoryBankParams(BaseModel):
    """Registration parameters for a ``keyvalue`` memory bank (tag only)."""

    # Discriminator tag for ``BankParams``; fixed to ``"keyvalue"``.
    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
        MemoryBankType.keyvalue.value
    )
|
|
|
|
|
|
@json_schema_type
class KeywordMemoryBankParams(BaseModel):
    """Registration parameters for a ``keyword`` memory bank (tag only)."""

    # Discriminator tag for ``BankParams``; fixed to ``"keyword"``.
    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
        MemoryBankType.keyword.value
    )
|
|
|
|
|
|
@json_schema_type
class GraphMemoryBankParams(BaseModel):
    """Registration parameters for a ``graph`` memory bank (tag only)."""

    # Discriminator tag for ``BankParams``; fixed to ``"graph"``.
    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
|
|
|
|
|
# All registration-parameter variants, in schema order. The
# ``memory_bank_type`` tag carried by each variant is what pydantic uses to
# pick the matching model when validating input.
_BankParamsVariants = Union[
    VectorMemoryBankParams,
    KeyValueMemoryBankParams,
    KeywordMemoryBankParams,
    GraphMemoryBankParams,
]

# Tagged union of registration parameters, discriminated by memory_bank_type.
BankParams = Annotated[_BankParamsVariants, Field(discriminator="memory_bank_type")]
|
|
|
|
|
|
# Shared base for all memory bank resource types.
class MemoryBankResourceMixin(Resource):
    """Common base of every memory bank resource model.

    Fixes the resource ``type`` to ``ResourceType.memory_bank`` and exposes
    memory-bank-specific aliases for the generic ``Resource`` identifiers.
    """

    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value

    @property
    def memory_bank_id(self) -> str:
        """Alias for the resource's ``identifier``."""
        return self.identifier

    @property
    def provider_memory_bank_id(self) -> str:
        """Alias for the resource's ``provider_resource_id``."""
        return self.provider_resource_id
|
|
|
|
|
|
@json_schema_type
class VectorMemoryBank(MemoryBankResourceMixin):
    """Memory bank resource of type ``vector``."""

    # Discriminator tag for ``MemoryBank``; fixed to ``"vector"``.
    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
    embedding_model: str
    chunk_size_in_tokens: int
    # Defaults to 384, matching MiniLM-L6-v2 per the trailing comment.
    # NOTE(review): typed Optional, so an explicit None is accepted -- confirm
    # that consumers of embedding_dimension handle None.
    embedding_dimension: Optional[int] = 384  # default to minilm-l6-v2
    overlap_size_in_tokens: Optional[int] = None
|
|
|
|
|
|
@json_schema_type
class KeyValueMemoryBank(MemoryBankResourceMixin):
    """Memory bank resource of type ``keyvalue``."""

    # Discriminator tag for ``MemoryBank``; fixed to ``"keyvalue"``.
    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
        MemoryBankType.keyvalue.value
    )
|
|
|
|
|
|
# TODO: the names KeyValue and Keyword are easily confused; adopt a clearer
# naming convention.
@json_schema_type
class KeywordMemoryBank(MemoryBankResourceMixin):
    """Memory bank resource of type ``keyword``."""

    # Discriminator tag for ``MemoryBank``; fixed to ``"keyword"``.
    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
        MemoryBankType.keyword.value
    )
|
|
|
|
|
|
@json_schema_type
class GraphMemoryBank(MemoryBankResourceMixin):
    """Memory bank resource of type ``graph``."""

    # Discriminator tag for ``MemoryBank``; fixed to ``"graph"``.
    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
|
|
|
|
|
# All memory bank resource variants, in schema order. The ``memory_bank_type``
# tag carried by each variant is what pydantic uses to pick the matching model
# when validating data.
_MemoryBankVariants = Union[
    VectorMemoryBank,
    KeyValueMemoryBank,
    KeywordMemoryBank,
    GraphMemoryBank,
]

# Tagged union of memory bank resources, discriminated by memory_bank_type.
MemoryBank = Annotated[_MemoryBankVariants, Field(discriminator="memory_bank_type")]
|
|
|
|
|
|
class MemoryBankInput(BaseModel):
    """Input description of a memory bank to register.

    The fields mirror the parameters of ``MemoryBanks.register_memory_bank``.
    """

    memory_bank_id: str
    params: BankParams
    # Added for parity with ``register_memory_bank(provider_id=...)``, which
    # this model otherwise could not express. It is optional with a None
    # default, so existing inputs still validate. Under pydantic's default
    # ``extra="ignore"``, a provider_id key was previously dropped silently.
    provider_id: Optional[str] = None
    provider_memory_bank_id: Optional[str] = None
|
|
|
|
|
|
@runtime_checkable
@trace_protocol
class MemoryBanks(Protocol):
    """API for listing, fetching, registering and unregistering memory banks."""

    @webmethod(route="/memory-banks/list", method="GET")
    async def list_memory_banks(self) -> List[MemoryBank]:
        """List memory banks.

        :returns: A list of memory bank resources.
        """
        ...

    @webmethod(route="/memory-banks/get", method="GET")
    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
        """Look up a single memory bank by its identifier.

        :param memory_bank_id: Identifier of the memory bank to fetch.
        :returns: The matching memory bank, or None (presumably when no bank
            has that identifier).
        """
        ...

    @webmethod(route="/memory-banks/register", method="POST")
    async def register_memory_bank(
        self,
        memory_bank_id: str,
        params: BankParams,
        provider_id: Optional[str] = None,
        provider_memory_bank_id: Optional[str] = None,
    ) -> MemoryBank:
        """Register a memory bank.

        :param memory_bank_id: Identifier to register the memory bank under.
        :param params: Type-specific parameters; ``params.memory_bank_type``
            selects which kind of memory bank is created.
        :param provider_id: Optional identifier of the provider to use.
        :param provider_memory_bank_id: Optional identifier of the bank on the
            provider side.
        :returns: The registered memory bank resource.
        """
        ...

    @webmethod(route="/memory-banks/unregister", method="POST")
    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        """Unregister a memory bank.

        :param memory_bank_id: Identifier of the memory bank to unregister.
        """
        ...
|