diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index e207b1a43..bbd3d2e10 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -48,7 +48,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): - overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses - download_images: If True, downloads images and converts to base64 for providers that require it - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata - - rerank_model_list: A list of model IDs for rerank models + - construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier - provider_data_api_key_field: Optional field name in provider data to look for API key - list_provider_model_ids: Method to list available models from the provider - get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client @@ -79,10 +79,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}} embedding_model_metadata: dict[str, dict[str, int]] = {} - # List of rerank model IDs for this provider - # Can be set by subclasses or instances to provide rerank models - rerank_model_list: list[str] = [] - # Cache of available models keyed by model ID # This is set in list_models() and used in check_model_availability() _model_cache: dict[str, Model] = {} @@ -126,6 +122,30 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): """ return {} + def construct_model_from_identifier(self, identifier: str) -> Model: + """ + Construct a Model instance corresponding to the given identifier + + Child classes can override this to customize model typing/metadata. + + :param identifier: The provider's model identifier + :return: A Model instance + """ + if metadata := self.embedding_model_metadata.get(identifier): + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.embedding, + metadata=metadata, + ) + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.llm, + ) + async def list_provider_model_ids(self) -> Iterable[str]: """ List available models from the provider. @@ -421,28 +441,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): if self.allowed_models and provider_model_id not in self.allowed_models: logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list") continue - if metadata := self.embedding_model_metadata.get(provider_model_id): - model = Model( - provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=provider_model_id, - identifier=provider_model_id, - model_type=ModelType.embedding, - metadata=metadata, - ) - elif provider_model_id in self.rerank_model_list: - model = Model( - provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=provider_model_id, - identifier=provider_model_id, - model_type=ModelType.rerank, - ) - else: - model = Model( - provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=provider_model_id, - identifier=provider_model_id, - model_type=ModelType.llm, - ) + model = self.construct_model_from_identifier(provider_model_id) self._model_cache[provider_model_id] = model return list(self._model_cache.values()) diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 277982af6..d98c096aa 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -38,21 +38,26 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl): } -class OpenAIMixinWithRerankImpl(OpenAIMixinImpl): - """Test implementation with rerank model list""" - - rerank_model_list: list[str] = ["rerank-model-1", "rerank-model-2"] - - -class OpenAIMixinWithEmbeddingsAndRerankImpl(OpenAIMixinImpl): - """Test implementation with both embedding model metadata and rerank model list""" +class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl): + """Test implementation that uses construct_model_from_identifier to add rerank models""" embedding_model_metadata: dict[str, dict[str, int]] = { "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, } - rerank_model_list: list[str] = ["rerank-model-1", "rerank-model-2"] + # Adds rerank models via construct_model_from_identifier + rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"} + + def construct_model_from_identifier(self, identifier: str) -> Model: + if identifier in self.rerank_model_ids: + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.rerank, + ) + return super().construct_model_from_identifier(identifier) @pytest.fixture @@ -80,17 +85,10 @@ def mixin_with_embeddings(): @pytest.fixture -def mixin_with_rerank(): - """Create a test instance of OpenAIMixin with rerank model list""" +def mixin_with_custom_model_construction(): + """Create a test instance using custom construct_model_from_identifier""" config = RemoteInferenceProviderConfig() - return OpenAIMixinWithRerankImpl(config=config) - - -@pytest.fixture -def mixin_with_embeddings_and_rerank(): - """Create a test instance of OpenAIMixin with both embedding model metadata and rerank model list""" - config = RemoteInferenceProviderConfig() - return OpenAIMixinWithEmbeddingsAndRerankImpl(config=config) + return OpenAIMixinWithCustomModelConstruction(config=config) @pytest.fixture @@ -404,52 +402,10 @@ class TestOpenAIMixinEmbeddingModelMetadata: _assert_models_match_expected(result, expected_models) -class TestOpenAIMixinRerankModelList: - """Test cases for rerank_model_list attribute functionality""" +class TestOpenAIMixinCustomModelConstruction: + """Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier""" - async def test_rerank_model_identified(self, mixin_with_rerank, mock_client_context): - """Test that models in rerank_model_list are correctly identified as rerank models""" - # Create mock models: 1 rerank model and 1 LLM - mock_rerank_model = MagicMock(id="rerank-model-1") - mock_llm_model = MagicMock(id="gpt-4") - mock_models = [mock_rerank_model, mock_llm_model] - - mock_client = MagicMock() - - async def mock_models_list(): - for model in mock_models: - yield model - - mock_client.models.list.return_value = mock_models_list() - - with mock_client_context(mixin_with_rerank, mock_client): - result = await mixin_with_rerank.list_models() - - assert result is not None - assert len(result) == 2 - - expected_models = { - "rerank-model-1": { - "model_type": ModelType.rerank, - "metadata": {}, - "provider_id": "test-provider", - "provider_resource_id": "rerank-model-1", - }, - "gpt-4": { - "model_type": ModelType.llm, - "metadata": {}, - "provider_id": "test-provider", - "provider_resource_id": "gpt-4", - }, - } - - _assert_models_match_expected(result, expected_models) - - -class TestOpenAIMixinMixedModelTypes: - """Test cases for mixed model types (LLM, embedding, rerank)""" - - async def test_mixed_model_types_identification(self, mixin_with_embeddings_and_rerank, mock_client_context): + async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context): """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata""" # Create mock models: 1 embedding, 1 rerank, 1 LLM mock_embedding_model = MagicMock(id="text-embedding-3-small") @@ -465,8 +421,8 @@ class TestOpenAIMixinMixedModelTypes: mock_client.models.list.return_value = mock_models_list() - with mock_client_context(mixin_with_embeddings_and_rerank, mock_client): - result = await mixin_with_embeddings_and_rerank.list_models() + with mock_client_context(mixin_with_custom_model_construction, mock_client): + result = await mixin_with_custom_model_construction.list_models() assert result is not None assert len(result) == 3