remove base MemoryBank inheritence

This commit is contained in:
Dinesh Yeduguru 2024-11-11 17:09:41 -08:00
parent 8436d386d3
commit 0789233dca
2 changed files with 10 additions and 13 deletions

View file

@ -31,13 +31,8 @@ class MemoryBankType(Enum):
@json_schema_type @json_schema_type
class MemoryBank(Resource): class VectorMemoryBank(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
memory_bank_type: MemoryBankType
@json_schema_type
class VectorMemoryBank(MemoryBank):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str embedding_model: str
chunk_size_in_tokens: int chunk_size_in_tokens: int
@ -45,21 +40,24 @@ class VectorMemoryBank(MemoryBank):
@json_schema_type @json_schema_type
class KeyValueMemoryBank(MemoryBank): class KeyValueMemoryBank(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = ( memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
MemoryBankType.keyvalue.value MemoryBankType.keyvalue.value
) )
@json_schema_type @json_schema_type
class KeywordMemoryBank(MemoryBank): class KeywordMemoryBank(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
memory_bank_type: Literal[MemoryBankType.keyword.value] = ( memory_bank_type: Literal[MemoryBankType.keyword.value] = (
MemoryBankType.keyword.value MemoryBankType.keyword.value
) )
@json_schema_type @json_schema_type
class GraphMemoryBank(MemoryBank): class GraphMemoryBank(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
@ -90,7 +88,7 @@ class GraphMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
AnyMemoryBank = Annotated[ MemoryBank = Annotated[
Union[ Union[
VectorMemoryBank, VectorMemoryBank,
KeyValueMemoryBank, KeyValueMemoryBank,

View file

@ -288,14 +288,13 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
"No provider specified and multiple providers available. Please specify a provider_id." "No provider specified and multiple providers available. Please specify a provider_id."
) )
memory_bank = parse_obj_as( memory_bank = parse_obj_as(
AnyMemoryBank, MemoryBank,
{ {
"identifier": memory_bank_id, "identifier": memory_bank_id,
"type": ResourceType.memory_bank.value, "type": ResourceType.memory_bank.value,
"provider_id": provider_id, "provider_id": provider_id,
"provider_resource_id": provider_memorybank_id, "provider_resource_id": provider_memorybank_id,
"memory_bank_type": params.memory_bank_type, **params.model_dump(),
**params.model_dump(exclude={"memory_bank_type"}),
}, },
) )
await self.register_object(memory_bank) await self.register_object(memory_bank)