mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-26 09:15:40 +00:00 
			
		
		
		
	feat: Add rerank models and rerank API change (#3831)
# What does this PR do?

<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->

- Extend the model type to include rerank models.
- Implement `rerank()` method in the inference router.
- Add `rerank_model_list` to `OpenAIMixin` to enable providers to register and identify rerank models.
- Update documentation.

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan

<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->

```
pytest tests/unit/providers/utils/inference/test_openai_mixin.py
```
This commit is contained in:
		
							parent
							
								
									f2598d30e6
								
							
						
					
					
						commit
						bb1ebb3c6b
					
				
					 12 changed files with 186 additions and 43 deletions
				
			
		|  | @ -44,9 +44,14 @@ from llama_stack.apis.inference import ( | |||
|     OpenAIEmbeddingsResponse, | ||||
|     OpenAIMessageParam, | ||||
|     Order, | ||||
|     RerankResponse, | ||||
|     StopReason, | ||||
|     ToolPromptFormat, | ||||
| ) | ||||
| from llama_stack.apis.inference.inference import ( | ||||
|     OpenAIChatCompletionContentPartImageParam, | ||||
|     OpenAIChatCompletionContentPartTextParam, | ||||
| ) | ||||
| from llama_stack.apis.models import Model, ModelType | ||||
| from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry | ||||
| from llama_stack.log import get_logger | ||||
|  | @ -182,6 +187,23 @@ class InferenceRouter(Inference): | |||
|             raise ModelTypeError(model_id, model.model_type, expected_model_type) | ||||
|         return model | ||||
| 
 | ||||
|     async def rerank( | ||||
|         self, | ||||
|         model: str, | ||||
|         query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, | ||||
|         items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], | ||||
|         max_num_results: int | None = None, | ||||
|     ) -> RerankResponse: | ||||
|         logger.debug(f"InferenceRouter.rerank: {model}") | ||||
|         model_obj = await self._get_model(model, ModelType.rerank) | ||||
|         provider = await self.routing_table.get_provider_impl(model_obj.identifier) | ||||
|         return await provider.rerank( | ||||
|             model=model_obj.identifier, | ||||
|             query=query, | ||||
|             items=items, | ||||
|             max_num_results=max_num_results, | ||||
|         ) | ||||
| 
 | ||||
|     async def openai_completion( | ||||
|         self, | ||||
|         params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)], | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue