From 6ad5fb5577cf74ef41b45f87fff0a4f856a1facd Mon Sep 17 00:00:00 2001
From: "Robert Riley (OCI)"
Date: Mon, 8 Dec 2025 15:05:39 -0600
Subject: [PATCH] feat: Adding OCI Embeddings (#4300)

# What does this PR do?

Enables usage of OCI embedding models.

## Test Plan

Testing the embedding model:

`OCI_COMPARTMENT_OCID="" OCI_REGION="us-chicago-1" OCI_AUTH_TYPE=config_file pytest -sv tests/integration/inference/test_openai_embeddings.py --stack-config oci --embedding-model oci/openai.text-embedding-3-small --inference-mode live`

Testing the chat model:

`OCI_COMPARTMENT_OCID="" OCI_REGION="us-chicago-1" OCI_AUTH_TYPE=config_file pytest -sv tests/integration/inference/ --stack-config oci --text-model oci/openai.gpt-4.1-nano-2025-04-14 --inference-mode live`

Testing embeddings via curl:

`curl -X POST http://localhost:8321/v1/embeddings -H "Content-Type: application/json" -d '{ "model": "oci/openai.text-embedding-3-small", "input": ["First text", "Second text"], "encoding_format": "float" }'`

Sample response (embedding values truncated):

`{"object":"list","data":[{"object":"embedding","embedding":[-0.017190756...0.025272394],"index":1}],"model":"oci/openai.text-embedding-3-small","usage":{"prompt_tokens":4,"total_tokens":4}}`
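For reference, the curl request above is equivalent to the following `openai` Python client call against a locally running Llama Stack server. This is an illustrative sketch, not part of this patch; the placeholder API key is an assumption (a local server does not validate it):

```python
# Illustrative sketch (not part of this patch): the same embeddings request as
# the curl above, issued through the openai client against a local Llama Stack
# server. The placeholder API key is an assumption for a local server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="not-needed")

response = client.embeddings.create(
    model="oci/openai.text-embedding-3-small",
    input=["First text", "Second text"],
    encoding_format="float",
)

# Two inputs in, two embedding vectors out.
print(len(response.data), response.usage.total_tokens)
```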
---------

Co-authored-by: Omar Abdelwahab
---
 .../providers/remote/inference/oci/oci.py | 42 ++++++++++++++-----
 .../inference/test_openai_embeddings.py   |  1 -
 2 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/src/llama_stack/providers/remote/inference/oci/oci.py b/src/llama_stack/providers/remote/inference/oci/oci.py
index 239443963..59544ca9c 100644
--- a/src/llama_stack/providers/remote/inference/oci/oci.py
+++ b/src/llama_stack/providers/remote/inference/oci/oci.py
@@ -18,11 +18,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
 from llama_stack.providers.remote.inference.oci.config import OCIConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import (
-    ModelType,
-    OpenAIEmbeddingsRequestWithExtraBody,
-    OpenAIEmbeddingsResponse,
-)
+from llama_stack_api import Model, ModelType
 
 logger = get_logger(name=__name__, category="inference::oci")
 
@@ -37,6 +33,8 @@ MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS"
 class OCIInferenceAdapter(OpenAIMixin):
     config: OCIConfig
 
+    embedding_models: list[str] = []
+
     async def initialize(self) -> None:
         """Initialize and validate OCI configuration."""
         if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
@@ -113,7 +111,9 @@ class OCIInferenceAdapter(OpenAIMixin):
         client = GenerativeAiClient(config=oci_config, signer=oci_signer)
 
         models: ModelCollection = client.list_models(
-            compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
+            compartment_id=compartment_id,
+            # capability=MODEL_CAPABILITIES,
+            lifecycle_state="ACTIVE",
         ).data
 
         seen_models = set()
@@ -122,7 +122,7 @@
             if model.time_deprecated or model.time_on_demand_retired:
                 continue
 
-            if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
+            if "UNKNOWN_ENUM_VALUE" in model.capabilities or "FINE_TUNE" in model.capabilities:
                 continue
 
             # Use display_name + model_type as the key to avoid conflicts
@@ -133,8 +133,30 @@
             seen_models.add(model_key)
             model_ids.append(model.display_name)
 
+            if "TEXT_EMBEDDINGS" in model.capabilities:
+                self.embedding_models.append(model.display_name)
+
         return model_ids
 
-    async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
-        # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
-        raise NotImplementedError("OCI Provider does not (currently) support embeddings")
+    def construct_model_from_identifier(self, identifier: str) -> Model:
+        """
+        Construct a Model instance corresponding to the given identifier
+
+        Child classes can override this to customize model typing/metadata.
+
+        :param identifier: The provider's model identifier
+        :return: A Model instance
+        """
+        if identifier in self.embedding_models:
+            return Model(
+                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                provider_resource_id=identifier,
+                identifier=identifier,
+                model_type=ModelType.embedding,
+            )
+        return Model(
+            provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+            provider_resource_id=identifier,
+            identifier=identifier,
+            model_type=ModelType.llm,
+        )
diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py
index fe8070162..704775716 100644
--- a/tests/integration/inference/test_openai_embeddings.py
+++ b/tests/integration/inference/test_openai_embeddings.py
@@ -138,7 +138,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
         "remote::runpod",
         "remote::sambanova",
         "remote::tgi",
-        "remote::oci",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
 