mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-26 17:23:00 +00:00 
			
		
		
		
	feat: Add rerank models and rerank API change (#3831)
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> - Extend the model type to include rerank models. - Implement `rerank()` method in inference router. - Add `rerank_model_list` to `OpenAIMixin` to enable providers to register and identify rerank models - Update documentation. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> ``` pytest tests/unit/providers/utils/inference/test_openai_mixin.py ```
This commit is contained in:
		
							parent
							
								
									f2598d30e6
								
							
						
					
					
						commit
						bb1ebb3c6b
					
				
					 12 changed files with 186 additions and 43 deletions
				
			
		|  | @ -38,6 +38,28 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl): | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl): | ||||
|     """Test implementation that uses construct_model_from_identifier to add rerank models""" | ||||
| 
 | ||||
|     embedding_model_metadata: dict[str, dict[str, int]] = { | ||||
|         "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, | ||||
|         "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, | ||||
|     } | ||||
| 
 | ||||
|     # Adds rerank models via construct_model_from_identifier | ||||
|     rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"} | ||||
| 
 | ||||
|     def construct_model_from_identifier(self, identifier: str) -> Model: | ||||
|         if identifier in self.rerank_model_ids: | ||||
|             return Model( | ||||
|                 provider_id=self.__provider_id__,  # type: ignore[attr-defined] | ||||
|                 provider_resource_id=identifier, | ||||
|                 identifier=identifier, | ||||
|                 model_type=ModelType.rerank, | ||||
|             ) | ||||
|         return super().construct_model_from_identifier(identifier) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def mixin(): | ||||
|     """Create a test instance of OpenAIMixin with mocked model_store""" | ||||
|  | @ -62,6 +84,13 @@ def mixin_with_embeddings(): | |||
|     return OpenAIMixinWithEmbeddingsImpl(config=config) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def mixin_with_custom_model_construction(): | ||||
|     """Create a test instance using custom construct_model_from_identifier""" | ||||
|     config = RemoteInferenceProviderConfig() | ||||
|     return OpenAIMixinWithCustomModelConstruction(config=config) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def mock_models(): | ||||
|     """Create multiple mock OpenAI model objects""" | ||||
|  | @ -113,6 +142,19 @@ def mock_client_context(): | |||
|     return _mock_client_context | ||||
| 
 | ||||
| 
 | ||||
| def _assert_models_match_expected(actual_models, expected_models): | ||||
|     """Verify the models match expected attributes. | ||||
| 
 | ||||
|     Args: | ||||
|         actual_models: List of models to verify | ||||
|         expected_models: Mapping of model identifier to expected attribute values | ||||
|     """ | ||||
|     for identifier, expected_attrs in expected_models.items(): | ||||
|         model = next(m for m in actual_models if m.identifier == identifier) | ||||
|         for attr_name, expected_value in expected_attrs.items(): | ||||
|             assert getattr(model, attr_name) == expected_value | ||||
| 
 | ||||
| 
 | ||||
| class TestOpenAIMixinListModels: | ||||
|     """Test cases for the list_models method""" | ||||
| 
 | ||||
|  | @ -342,21 +384,71 @@ class TestOpenAIMixinEmbeddingModelMetadata: | |||
|             assert result is not None | ||||
|             assert len(result) == 2 | ||||
| 
 | ||||
|             # Find the models in the result | ||||
|             embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small") | ||||
|             llm_model = next(m for m in result if m.identifier == "gpt-4") | ||||
|             expected_models = { | ||||
|                 "text-embedding-3-small": { | ||||
|                     "model_type": ModelType.embedding, | ||||
|                     "metadata": {"embedding_dimension": 1536, "context_length": 8192}, | ||||
|                     "provider_id": "test-provider", | ||||
|                     "provider_resource_id": "text-embedding-3-small", | ||||
|                 }, | ||||
|                 "gpt-4": { | ||||
|                     "model_type": ModelType.llm, | ||||
|                     "metadata": {}, | ||||
|                     "provider_id": "test-provider", | ||||
|                     "provider_resource_id": "gpt-4", | ||||
|                 }, | ||||
|             } | ||||
| 
 | ||||
|             # Check embedding model | ||||
|             assert embedding_model.model_type == ModelType.embedding | ||||
|             assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192} | ||||
|             assert embedding_model.provider_id == "test-provider" | ||||
|             assert embedding_model.provider_resource_id == "text-embedding-3-small" | ||||
|             _assert_models_match_expected(result, expected_models) | ||||
| 
 | ||||
|             # Check LLM model | ||||
|             assert llm_model.model_type == ModelType.llm | ||||
|             assert llm_model.metadata == {}  # No metadata for LLMs | ||||
|             assert llm_model.provider_id == "test-provider" | ||||
|             assert llm_model.provider_resource_id == "gpt-4" | ||||
| 
 | ||||
| class TestOpenAIMixinCustomModelConstruction: | ||||
|     """Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier""" | ||||
| 
 | ||||
|     async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context): | ||||
|         """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata""" | ||||
|         # Create mock models: 1 embedding, 1 rerank, 1 LLM | ||||
|         mock_embedding_model = MagicMock(id="text-embedding-3-small") | ||||
|         mock_rerank_model = MagicMock(id="rerank-model-1") | ||||
|         mock_llm_model = MagicMock(id="gpt-4") | ||||
|         mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model] | ||||
| 
 | ||||
|         mock_client = MagicMock() | ||||
| 
 | ||||
|         async def mock_models_list(): | ||||
|             for model in mock_models: | ||||
|                 yield model | ||||
| 
 | ||||
|         mock_client.models.list.return_value = mock_models_list() | ||||
| 
 | ||||
|         with mock_client_context(mixin_with_custom_model_construction, mock_client): | ||||
|             result = await mixin_with_custom_model_construction.list_models() | ||||
| 
 | ||||
|             assert result is not None | ||||
|             assert len(result) == 3 | ||||
| 
 | ||||
|             expected_models = { | ||||
|                 "text-embedding-3-small": { | ||||
|                     "model_type": ModelType.embedding, | ||||
|                     "metadata": {"embedding_dimension": 1536, "context_length": 8192}, | ||||
|                     "provider_id": "test-provider", | ||||
|                     "provider_resource_id": "text-embedding-3-small", | ||||
|                 }, | ||||
|                 "rerank-model-1": { | ||||
|                     "model_type": ModelType.rerank, | ||||
|                     "metadata": {}, | ||||
|                     "provider_id": "test-provider", | ||||
|                     "provider_resource_id": "rerank-model-1", | ||||
|                 }, | ||||
|                 "gpt-4": { | ||||
|                     "model_type": ModelType.llm, | ||||
|                     "metadata": {}, | ||||
|                     "provider_id": "test-provider", | ||||
|                     "provider_resource_id": "gpt-4", | ||||
|                 }, | ||||
|             } | ||||
| 
 | ||||
|             _assert_models_match_expected(result, expected_models) | ||||
| 
 | ||||
| 
 | ||||
| class TestOpenAIMixinAllowedModels: | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue