mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
Update RFC-0003-reranker-api.md
This commit is contained in:
parent
70a56ce869
commit
f81b7a8478
1 changed files with 31 additions and 20 deletions
|
@ -286,57 +286,68 @@ Reranks query–document pairs based on a specified strategy. It requires a `mod
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class LocalReranker(RerankerProvider):
|
class LocalReranker(RerankerProvider):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: Optional[str] = None,
|
model_id: Optional[str] = None,
|
||||||
rerank_strategy: RerankingStrategy = RerankingStrategy.DEFAULT,
|
rerank_strategy: RerankingStrategy = RerankingStrategy.DEFAULT
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.rerank_strategy = rerank_strategy
|
self.rerank_strategy = rerank_strategy
|
||||||
self.model = asyncio.run(self._initialize_model())
|
self.model = None
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Asynchronously initializes the model."""
|
||||||
|
self.model = await self._initialize_model()
|
||||||
|
|
||||||
async def _initialize_model(self) -> Any:
|
async def _initialize_model(self) -> Any:
|
||||||
|
"""Loads or registers the model asynchronously."""
|
||||||
if not self.model_id:
|
if not self.model_id:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No model_id provided, but a model is required for this reranking strategy."
|
"No model_id provided, but a model is required for this reranking strategy."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.rerank_strategy == RerankingStrategy.LLM_RERANK:
|
# Ensure models_registry is accessible
|
||||||
# Placeholder for an LLM-based model loader
|
if not hasattr(self, "models_registry"):
|
||||||
return self.model_id
|
raise AttributeError("models_registry is not defined.")
|
||||||
|
|
||||||
# Default: use model_id for cross-encoder/embedding-based model loading
|
# Attempt to get a pre-registered model
|
||||||
return self.model_id
|
model = await self.models_registry.get_model(self.model_id)
|
||||||
|
if model:
|
||||||
|
return model
|
||||||
|
|
||||||
|
# Otherwise, register a new one
|
||||||
|
return await self.models_registry.register_model(
|
||||||
|
model_id=self.model_id,
|
||||||
|
provider_model_id=self.model_id,
|
||||||
|
metadata={"ranking": "512"}
|
||||||
|
)
|
||||||
|
|
||||||
async def compute_scores(self, query: str, chunks: List[Any]) -> np.ndarray:
|
async def compute_scores(self, query: str, chunks: List[Any]) -> np.ndarray:
|
||||||
|
"""Computes reranking scores for the given query and chunks."""
|
||||||
if self.rerank_strategy == RerankingStrategy.NONE:
|
if self.rerank_strategy == RerankingStrategy.NONE:
|
||||||
return np.zeros(len(chunks))
|
return np.zeros(len(chunks))
|
||||||
|
|
||||||
|
# Ensure the model is initialized before use
|
||||||
|
if not self.model:
|
||||||
|
raise RuntimeError("Model is not initialized. Call `await initialize()` first.")
|
||||||
|
|
||||||
# Create (query, content) pairs from chunks
|
# Create (query, content) pairs from chunks
|
||||||
pairs = [(query, chunk.content) for chunk in chunks]
|
pairs = [(query, chunk.content) for chunk in chunks]
|
||||||
|
|
||||||
# Compute scores in a non-blocking way
|
# Compute scores asynchronously
|
||||||
scores = await asyncio.to_thread(self.model.predict, pairs)
|
scores = await asyncio.to_thread(self.model.predict, pairs)
|
||||||
|
|
||||||
|
# Apply reranking strategy
|
||||||
if self.rerank_strategy == RerankingStrategy.DEFAULT:
|
if self.rerank_strategy == RerankingStrategy.DEFAULT:
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
elif self.rerank_strategy == RerankingStrategy.BOOST:
|
elif self.rerank_strategy == RerankingStrategy.BOOST:
|
||||||
# Placeholder: apply a boost factor to scores
|
return scores # Placeholder for boost logic
|
||||||
return scores
|
|
||||||
|
|
||||||
elif self.rerank_strategy == RerankingStrategy.HYBRID:
|
elif self.rerank_strategy == RerankingStrategy.HYBRID:
|
||||||
# Placeholder: combine cross-encoder scores with embedding-based scores
|
return scores # Placeholder for hybrid logic
|
||||||
return scores
|
|
||||||
|
|
||||||
elif self.rerank_strategy == RerankingStrategy.LLM_RERANK:
|
elif self.rerank_strategy == RerankingStrategy.LLM_RERANK:
|
||||||
# Placeholder: use an LLM to compute scores
|
return scores # Placeholder for LLM rerank logic
|
||||||
return scores
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown reranking strategy: {self.rerank_strategy}")
|
raise ValueError(f"Unknown reranking strategy: {self.rerank_strategy}")
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### ExternalReranker
|
#### ExternalReranker
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue