mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
Update RFC-0003-reranker-api.md
This commit is contained in:
parent
70a56ce869
commit
f81b7a8478
1 changed file with 31 additions and 20 deletions
|
@ -286,57 +286,68 @@ Reranks query–document pairs based on a specified strategy. It requires a `mod
|
|||
|
||||
```python
|
||||
class LocalReranker(RerankerProvider):
    """Reranker that scores (query, chunk) pairs with a locally held model.

    Construction is synchronous and does not load anything: ``self.model``
    starts as ``None`` and is populated by awaiting :meth:`initialize`.
    Loading is kept out of ``__init__`` because ``asyncio.run()`` cannot be
    called from code that is already running inside an event loop.

    NOTE(review): ``self.models_registry`` is read by ``_initialize_model``
    but is not assigned anywhere in this class; presumably it is supplied by
    the base class or injected by the caller -- confirm.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        rerank_strategy: RerankingStrategy = RerankingStrategy.DEFAULT,
    ) -> None:
        """Store the configuration; the model itself is loaded lazily.

        Args:
            model_id: Identifier of the model to look up or register.
            rerank_strategy: Strategy applied by :meth:`compute_scores`.
        """
        self.model_id = model_id
        self.rerank_strategy = rerank_strategy
        # Populated by `await initialize()`; None means "not loaded yet".
        self.model = None

    async def initialize(self) -> None:
        """Asynchronously initializes the model."""
        self.model = await self._initialize_model()

    async def _initialize_model(self) -> Any:
        """Loads or registers the model asynchronously.

        Returns:
            The model returned by the registry lookup, or, when the lookup
            yields nothing truthy, the result of registering ``model_id``.

        Raises:
            ValueError: If no ``model_id`` was provided.
            AttributeError: If ``models_registry`` is not set on the instance.
        """
        if not self.model_id:
            raise ValueError(
                "No model_id provided, but a model is required for this reranking strategy."
            )

        # Ensure models_registry is accessible
        if not hasattr(self, "models_registry"):
            raise AttributeError("models_registry is not defined.")

        # Attempt to get a pre-registered model
        model = await self.models_registry.get_model(self.model_id)
        if model:
            return model

        # Otherwise, register a new one
        # NOTE(review): the meaning of the "ranking" metadata entry is not
        # defined in this block -- confirm against the registry's schema.
        return await self.models_registry.register_model(
            model_id=self.model_id,
            provider_model_id=self.model_id,
            metadata={"ranking": "512"}
        )

    async def compute_scores(self, query: str, chunks: List[Any]) -> np.ndarray:
        """Computes reranking scores for the given query and chunks.

        Args:
            query: The query text paired with every chunk.
            chunks: Items to score; each must expose a ``content`` attribute.

        Returns:
            An array of zeros (one per chunk) for ``RerankingStrategy.NONE``;
            otherwise whatever ``self.model.predict`` returns for the pairs.

        Raises:
            RuntimeError: If a model is needed but :meth:`initialize` has not
                been awaited yet.
            ValueError: If ``rerank_strategy`` is not a known strategy.
        """
        # NONE needs no model, so it is handled before the initialization check.
        if self.rerank_strategy == RerankingStrategy.NONE:
            return np.zeros(len(chunks))

        # Ensure the model is initialized before use
        if self.model is None:
            raise RuntimeError("Model is not initialized. Call `await initialize()` first.")

        # Create (query, content) pairs from chunks
        pairs = [(query, chunk.content) for chunk in chunks]

        # Run the synchronous predict call in a worker thread so the event
        # loop is not blocked while scores are computed.
        scores = await asyncio.to_thread(self.model.predict, pairs)

        # Apply reranking strategy
        if self.rerank_strategy == RerankingStrategy.DEFAULT:
            return scores
        elif self.rerank_strategy == RerankingStrategy.BOOST:
            return scores  # Placeholder for boost logic
        elif self.rerank_strategy == RerankingStrategy.HYBRID:
            return scores  # Placeholder for hybrid logic
        elif self.rerank_strategy == RerankingStrategy.LLM_RERANK:
            return scores  # Placeholder for LLM rerank logic
        else:
            raise ValueError(f"Unknown reranking strategy: {self.rerank_strategy}")
|
||||
|
||||
```
|
||||
|
||||
#### ExternalReranker
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue