feat (RAG): Implement configurable search mode in RAGQueryConfig

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-04-14 16:53:17 -07:00
parent 85b5f3172b
commit e2a7022d3c
14 changed files with 210 additions and 43 deletions

View file

@@ -55,7 +55,9 @@ class ChromaIndex(EmbeddingIndex):
)
)
-async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+async def query(
+self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
+) -> QueryChunksResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],

View file

@@ -73,7 +73,9 @@ class MilvusIndex(EmbeddingIndex):
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
-async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+async def query(
+self, embedding: NDArray, query_str: Optional[str], k: int, score_threshold: float, mode: str
+) -> QueryChunksResponse:
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,

View file

@@ -99,7 +99,9 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
-async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+async def query(
+self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
+) -> QueryChunksResponse:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
f"""

View file

@@ -68,7 +68,9 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
-async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+async def query(
+self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
+) -> QueryChunksResponse:
results = (
await self.client.query_points(
collection_name=self.collection_name,

View file

@@ -55,7 +55,9 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
-async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+async def query(
+self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
+) -> QueryChunksResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(