Update RFC-0003-reranker-api.md

This commit is contained in:
Kevin Cogan 2025-03-20 15:45:45 +00:00 committed by GitHub
parent 70a56ce869
commit f81b7a8478
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -286,57 +286,68 @@ Reranks query-document pairs based on a specified strategy. It requires a `mod
```python
class LocalReranker(RerankerProvider):
    """Reranker that scores (query, chunk) pairs with a locally held model.

    Construction is synchronous and cheap: the model is NOT loaded in
    ``__init__``. Callers must ``await initialize()`` before calling
    ``compute_scores`` with any strategy other than ``RerankingStrategy.NONE``.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        rerank_strategy: RerankingStrategy = RerankingStrategy.DEFAULT,
    ) -> None:
        """Store the configuration; the model itself is loaded lazily.

        Args:
            model_id: Identifier of the model to load or register.
            rerank_strategy: Strategy applied when computing scores.
        """
        self.model_id = model_id
        self.rerank_strategy = rerank_strategy
        # Loading is deferred to `initialize()`: calling `asyncio.run()` here
        # would fail whenever the constructor runs inside an active event loop.
        self.model = None

    async def initialize(self) -> None:
        """Asynchronously initializes the model."""
        self.model = await self._initialize_model()

    async def _initialize_model(self) -> Any:
        """Loads or registers the model asynchronously.

        Returns:
            ``self.model_id`` for the ``LLM_RERANK`` strategy (placeholder),
            otherwise the model returned by ``self.models_registry``.

        Raises:
            ValueError: If no ``model_id`` was provided.
            AttributeError: If ``self.models_registry`` is not set.
        """
        if not self.model_id:
            raise ValueError(
                "No model_id provided, but a model is required for this reranking strategy."
            )

        if self.rerank_strategy == RerankingStrategy.LLM_RERANK:
            # Placeholder for an LLM-based model loader
            return self.model_id

        # Ensure models_registry is accessible.
        # NOTE(review): nothing in this class assigns `models_registry`;
        # presumably it is injected by the provider/host -- confirm.
        if not hasattr(self, "models_registry"):
            raise AttributeError("models_registry is not defined.")

        # Attempt to get a pre-registered model
        model = await self.models_registry.get_model(self.model_id)
        if model:
            return model

        # Otherwise, register a new one
        # NOTE(review): the meaning of the "ranking": "512" metadata entry is
        # not established here -- confirm against the registry schema.
        return await self.models_registry.register_model(
            model_id=self.model_id,
            provider_model_id=self.model_id,
            metadata={"ranking": "512"},
        )

    async def compute_scores(self, query: str, chunks: List[Any]) -> np.ndarray:
        """Computes reranking scores for the given query and chunks.

        Args:
            query: The query text.
            chunks: Items exposing a ``content`` attribute to pair with the query.

        Returns:
            One score per chunk; all zeros for ``RerankingStrategy.NONE``.

        Raises:
            RuntimeError: If the model has not been initialized.
            ValueError: If the configured strategy is not recognised.
        """
        if self.rerank_strategy == RerankingStrategy.NONE:
            return np.zeros(len(chunks))

        # Ensure the model is initialized before use
        if self.model is None:
            raise RuntimeError("Model is not initialized. Call `await initialize()` first.")

        # Create (query, content) pairs from chunks
        pairs = [(query, chunk.content) for chunk in chunks]

        # Run the blocking `predict` call in a worker thread so the event loop
        # stays responsive.
        scores = await asyncio.to_thread(self.model.predict, pairs)

        # Apply reranking strategy
        if self.rerank_strategy == RerankingStrategy.DEFAULT:
            return scores
        elif self.rerank_strategy == RerankingStrategy.BOOST:
            return scores  # Placeholder for boost logic
        elif self.rerank_strategy == RerankingStrategy.HYBRID:
            return scores  # Placeholder for hybrid logic
        elif self.rerank_strategy == RerankingStrategy.LLM_RERANK:
            return scores  # Placeholder for LLM rerank logic
        else:
            raise ValueError(f"Unknown reranking strategy: {self.rerank_strategy}")
```
#### ExternalReranker