mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-26 01:12:59 +00:00 
			
		
		
		
	feat: Add rerank models and rerank API change (#3831)
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> - Extend the model type to include rerank models. - Implement `rerank()` method in inference router. - Add `rerank_model_list` to `OpenAIMixin` to enable providers to register and identify rerank models - Update documentation. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> ``` pytest tests/unit/providers/utils/inference/test_openai_mixin.py ```
This commit is contained in:
		
							parent
							
								
									f2598d30e6
								
							
						
					
					
						commit
						bb1ebb3c6b
					
				
					 12 changed files with 186 additions and 43 deletions
				
			
		|  | @ -1234,9 +1234,10 @@ class Inference(InferenceProvider): | |||
| 
 | ||||
|     Llama Stack Inference API for generating completions, chat completions, and embeddings. | ||||
| 
 | ||||
|     This API provides the raw interface to the underlying models. Two kinds of models are supported: | ||||
|     This API provides the raw interface to the underlying models. Three kinds of models are supported: | ||||
|     - LLM models: these models generate "raw" and "chat" (conversational) completions. | ||||
|     - Embedding models: these models generate embeddings to be used for semantic search. | ||||
|     - Rerank models: these models reorder the documents based on their relevance to a query. | ||||
|     """ | ||||
| 
 | ||||
|     @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True) | ||||
|  |  | |||
|  | @ -27,10 +27,12 @@ class ModelType(StrEnum): | |||
|     """Enumeration of supported model types in Llama Stack. | ||||
|     :cvar llm: Large language model for text generation and completion | ||||
|     :cvar embedding: Embedding model for converting text to vector representations | ||||
|     :cvar rerank: Reranking model for reordering documents based on their relevance to a query | ||||
|     """ | ||||
| 
 | ||||
|     llm = "llm" | ||||
|     embedding = "embedding" | ||||
|     rerank = "rerank" | ||||
| 
 | ||||
| 
 | ||||
| @json_schema_type | ||||
|  |  | |||
|  | @ -44,9 +44,14 @@ from llama_stack.apis.inference import ( | |||
|     OpenAIEmbeddingsResponse, | ||||
|     OpenAIMessageParam, | ||||
|     Order, | ||||
|     RerankResponse, | ||||
|     StopReason, | ||||
|     ToolPromptFormat, | ||||
| ) | ||||
| from llama_stack.apis.inference.inference import ( | ||||
|     OpenAIChatCompletionContentPartImageParam, | ||||
|     OpenAIChatCompletionContentPartTextParam, | ||||
| ) | ||||
| from llama_stack.apis.models import Model, ModelType | ||||
| from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry | ||||
| from llama_stack.log import get_logger | ||||
|  | @ -182,6 +187,23 @@ class InferenceRouter(Inference): | |||
|             raise ModelTypeError(model_id, model.model_type, expected_model_type) | ||||
|         return model | ||||
| 
 | ||||
|     async def rerank( | ||||
|         self, | ||||
|         model: str, | ||||
|         query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, | ||||
|         items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], | ||||
|         max_num_results: int | None = None, | ||||
|     ) -> RerankResponse: | ||||
|         logger.debug(f"InferenceRouter.rerank: {model}") | ||||
|         model_obj = await self._get_model(model, ModelType.rerank) | ||||
|         provider = await self.routing_table.get_provider_impl(model_obj.identifier) | ||||
|         return await provider.rerank( | ||||
|             model=model_obj.identifier, | ||||
|             query=query, | ||||
|             items=items, | ||||
|             max_num_results=max_num_results, | ||||
|         ) | ||||
| 
 | ||||
|     async def openai_completion( | ||||
|         self, | ||||
|         params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)], | ||||
|  |  | |||
|  | @ -48,6 +48,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): | |||
|     - overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses | ||||
|     - download_images: If True, downloads images and converts to base64 for providers that require it | ||||
|     - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata | ||||
|     - construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier | ||||
|     - provider_data_api_key_field: Optional field name in provider data to look for API key | ||||
|     - list_provider_model_ids: Method to list available models from the provider | ||||
|     - get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client | ||||
|  | @ -121,6 +122,30 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): | |||
|         """ | ||||
|         return {} | ||||
| 
 | ||||
|     def construct_model_from_identifier(self, identifier: str) -> Model: | ||||
|         """ | ||||
|         Construct a Model instance corresponding to the given identifier | ||||
| 
 | ||||
|         Child classes can override this to customize model typing/metadata. | ||||
| 
 | ||||
|         :param identifier: The provider's model identifier | ||||
|         :return: A Model instance | ||||
|         """ | ||||
|         if metadata := self.embedding_model_metadata.get(identifier): | ||||
|             return Model( | ||||
|                 provider_id=self.__provider_id__,  # type: ignore[attr-defined] | ||||
|                 provider_resource_id=identifier, | ||||
|                 identifier=identifier, | ||||
|                 model_type=ModelType.embedding, | ||||
|                 metadata=metadata, | ||||
|             ) | ||||
|         return Model( | ||||
|             provider_id=self.__provider_id__,  # type: ignore[attr-defined] | ||||
|             provider_resource_id=identifier, | ||||
|             identifier=identifier, | ||||
|             model_type=ModelType.llm, | ||||
|         ) | ||||
| 
 | ||||
|     async def list_provider_model_ids(self) -> Iterable[str]: | ||||
|         """ | ||||
|         List available models from the provider. | ||||
|  | @ -416,21 +441,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): | |||
|             if self.allowed_models and provider_model_id not in self.allowed_models: | ||||
|                 logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list") | ||||
|                 continue | ||||
|             if metadata := self.embedding_model_metadata.get(provider_model_id): | ||||
|                 model = Model( | ||||
|                     provider_id=self.__provider_id__,  # type: ignore[attr-defined] | ||||
|                     provider_resource_id=provider_model_id, | ||||
|                     identifier=provider_model_id, | ||||
|                     model_type=ModelType.embedding, | ||||
|                     metadata=metadata, | ||||
|                 ) | ||||
|             else: | ||||
|                 model = Model( | ||||
|                     provider_id=self.__provider_id__,  # type: ignore[attr-defined] | ||||
|                     provider_resource_id=provider_model_id, | ||||
|                     identifier=provider_model_id, | ||||
|                     model_type=ModelType.llm, | ||||
|                 ) | ||||
|             model = self.construct_model_from_identifier(provider_model_id) | ||||
|             self._model_cache[provider_model_id] = model | ||||
| 
 | ||||
|         return list(self._model_cache.values()) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue