Query generators for rag query

This commit is contained in:
Hardik Shah 2024-09-04 17:58:42 -07:00
parent dd1e1ceb13
commit 4a70f3d2ba
3 changed files with 133 additions and 2 deletions

View file

@ -116,10 +116,49 @@ MemoryBankConfig = Annotated[
]
@json_schema_type
class MemoryQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
class DefaultMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.default.value] = (
MemoryQueryGenerator.default.value
)
sep: str = " "
class LLMMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
model: str
template: str
host: str = "localhost"
port: int = 5000
class CustomMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
MemoryQueryGeneratorConfig = Annotated[
Union[
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
CustomMemoryQueryGeneratorConfig,
],
Field(discriminator="type"),
]
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig
)
max_tokens_in_context: int = 4096
max_chunks: int = 10