From f81b7a8478b9f752d38ffd8a8c0517ec5661f36f Mon Sep 17 00:00:00 2001 From: Kevin Cogan <44865890+kevincogan@users.noreply.github.com> Date: Thu, 20 Mar 2025 15:45:45 +0000 Subject: [PATCH] Update RFC-0003-reranker-api.md --- rfcs/RFC-0003-reranker-api.md | 51 +++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/rfcs/RFC-0003-reranker-api.md b/rfcs/RFC-0003-reranker-api.md index 9dedbc9bd..c7b698722 100644 --- a/rfcs/RFC-0003-reranker-api.md +++ b/rfcs/RFC-0003-reranker-api.md @@ -286,57 +286,68 @@ Reranks query–document pairs based on a specified strategy. It requires a `mod ```python class LocalReranker(RerankerProvider): - def __init__( self, model_id: Optional[str] = None, - rerank_strategy: RerankingStrategy = RerankingStrategy.DEFAULT, + rerank_strategy: RerankingStrategy = RerankingStrategy.DEFAULT ) -> None: - self.model_id = model_id self.rerank_strategy = rerank_strategy - self.model = asyncio.run(self._initialize_model()) + self.model = None + async def initialize(self) -> None: + """Asynchronously initializes the model.""" + self.model = await self._initialize_model() async def _initialize_model(self) -> Any: + """Loads or registers the model asynchronously.""" if not self.model_id: raise ValueError( "No model_id provided, but a model is required for this reranking strategy." ) - if self.rerank_strategy == RerankingStrategy.LLM_RERANK: - # Placeholder for an LLM-based model loader - return self.model_id + # Ensure models_registry is accessible + if not hasattr(self, "models_registry"): + raise AttributeError("models_registry is not defined.") - # Default: use model_id for cross-encoder/embedding-based model loading - return self.model_id + # Attempt to get a pre-registered model + model = await self.models_registry.get_model(self.model_id) + if model: + return model + + # Otherwise, register a new one + return await self.models_registry.register_model( + model_id=self.model_id, + provider_model_id=self.model_id, + metadata={"ranking": "512"} + ) async def compute_scores(self, query: str, chunks: List[Any]) -> np.ndarray: + """Computes reranking scores for the given query and chunks.""" if self.rerank_strategy == RerankingStrategy.NONE: return np.zeros(len(chunks)) + # Ensure the model is initialized before use + if not self.model: + raise RuntimeError("Model is not initialized. Call `await initialize()` first.") + # Create (query, content) pairs from chunks pairs = [(query, chunk.content) for chunk in chunks] - # Compute scores in a non-blocking way + # Compute scores asynchronously scores = await asyncio.to_thread(self.model.predict, pairs) + # Apply reranking strategy if self.rerank_strategy == RerankingStrategy.DEFAULT: return scores - elif self.rerank_strategy == RerankingStrategy.BOOST: - # Placeholder: apply a boost factor to scores - return scores - + return scores # Placeholder for boost logic elif self.rerank_strategy == RerankingStrategy.HYBRID: - # Placeholder: combine cross-encoder scores with embedding-based scores - return scores - + return scores # Placeholder for hybrid logic elif self.rerank_strategy == RerankingStrategy.LLM_RERANK: - # Placeholder: use an LLM to compute scores - return scores - + return scores # Placeholder for LLM rerank logic else: raise ValueError(f"Unknown reranking strategy: {self.rerank_strategy}") + ``` #### ExternalReranker