diff --git a/src/llama_stack/providers/remote/inference/oci/oci.py b/src/llama_stack/providers/remote/inference/oci/oci.py index 239443963..59544ca9c 100644 --- a/src/llama_stack/providers/remote/inference/oci/oci.py +++ b/src/llama_stack/providers/remote/inference/oci/oci.py @@ -18,11 +18,7 @@ from llama_stack.log import get_logger from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth from llama_stack.providers.remote.inference.oci.config import OCIConfig from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin -from llama_stack_api import ( - ModelType, - OpenAIEmbeddingsRequestWithExtraBody, - OpenAIEmbeddingsResponse, -) +from llama_stack_api import Model, ModelType logger = get_logger(name=__name__, category="inference::oci") @@ -37,6 +33,8 @@ MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS" class OCIInferenceAdapter(OpenAIMixin): config: OCIConfig + embedding_models: list[str] = [] + async def initialize(self) -> None: """Initialize and validate OCI configuration.""" if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES: @@ -113,7 +111,9 @@ class OCIInferenceAdapter(OpenAIMixin): client = GenerativeAiClient(config=oci_config, signer=oci_signer) models: ModelCollection = client.list_models( - compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE" + compartment_id=compartment_id, + # capability=MODEL_CAPABILITIES, + lifecycle_state="ACTIVE", ).data seen_models = set() @@ -122,7 +122,7 @@ class OCIInferenceAdapter(OpenAIMixin): if model.time_deprecated or model.time_on_demand_retired: continue - if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities: + if "UNKNOWN_ENUM_VALUE" in model.capabilities or "FINE_TUNE" in model.capabilities: continue # Use display_name + model_type as the key to avoid conflicts @@ -133,8 +133,30 @@ class OCIInferenceAdapter(OpenAIMixin): seen_models.add(model_key) model_ids.append(model.display_name) + if "TEXT_EMBEDDINGS" in model.capabilities: + self.embedding_models.append(model.display_name) + return model_ids - async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse: - # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings. - raise NotImplementedError("OCI Provider does not (currently) support embeddings") + def construct_model_from_identifier(self, identifier: str) -> Model: + """ + Construct a Model instance corresponding to the given identifier + + Child classes can override this to customize model typing/metadata. + + :param identifier: The provider's model identifier + :return: A Model instance + """ + if identifier in self.embedding_models: + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.embedding, + ) + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.llm, + ) diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py index fe8070162..704775716 100644 --- a/tests/integration/inference/test_openai_embeddings.py +++ b/tests/integration/inference/test_openai_embeddings.py @@ -138,7 +138,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id): "remote::runpod", "remote::sambanova", "remote::tgi", - "remote::oci", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")