mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 23:22:37 +00:00
feat: Adding OCI Embeddings (#4300)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 11s
Python Package Build Test / build (3.12) (push) Successful in 15s
Python Package Build Test / build (3.13) (push) Successful in 18s
Test External API and Providers / test-external (venv) (push) Failing after 30s
UI Tests / ui-tests (22) (push) Successful in 56s
Vector IO Integration Tests / test-matrix (push) Failing after 1m1s
Unit Tests / unit-tests (3.13) (push) Failing after 1m44s
Unit Tests / unit-tests (3.12) (push) Failing after 1m48s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3m17s
Pre-commit / pre-commit (22) (push) Successful in 3m22s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 11s
Python Package Build Test / build (3.12) (push) Successful in 15s
Python Package Build Test / build (3.13) (push) Successful in 18s
Test External API and Providers / test-external (venv) (push) Failing after 30s
UI Tests / ui-tests (22) (push) Successful in 56s
Vector IO Integration Tests / test-matrix (push) Failing after 1m1s
Unit Tests / unit-tests (3.13) (push) Failing after 1m44s
Unit Tests / unit-tests (3.12) (push) Failing after 1m48s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3m17s
Pre-commit / pre-commit (22) (push) Successful in 3m22s
# What does this PR do? Enabling usage of OCI embedding models. ## Test Plan Testing embedding model: `OCI_COMPARTMENT_OCID="" OCI_REGION="us-chicago-1" OCI_AUTH_TYPE=config_file pytest -sv tests/integration/inference/test_openai_embeddings.py --stack-config oci --embedding-model oci/openai.text-embedding-3-small --inference-mode live` Testing chat model: `OCI_COMPARTMENT_OCID="" OCI_REGION="us-chicago-1" OCI_AUTH_TYPE=config_file pytest -sv tests/integration/inference/ --stack-config oci --text-model oci/openai.gpt-4.1-nano-2025-04-14 --inference-mode live` Testing curl for embeddings: `curl -X POST http://localhost:8321/v1/embeddings -H "Content-Type: application/json" -d '{ "model": "oci/openai.text-embedding-3-small", "input": ["First text", "Second text"], "encoding_format": "float" }'` `{"object":"list","data":[{"object":"embedding","embedding":[-0.017190756...0.025272394],"index":1}],"model":"oci/openai.text-embedding-3-small","usage":{"prompt_tokens":4,"total_tokens":4}}` --------- Co-authored-by: Omar Abdelwahab <omaryashraf10@gmail.com>
This commit is contained in:
parent
d82a2cd6f8
commit
6ad5fb5577
2 changed files with 32 additions and 11 deletions
|
|
@ -18,11 +18,7 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
|
||||
from llama_stack.providers.remote.inference.oci.config import OCIConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack_api import (
|
||||
ModelType,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack_api import Model, ModelType
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::oci")
|
||||
|
||||
|
|
@ -37,6 +33,8 @@ MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS"
|
|||
class OCIInferenceAdapter(OpenAIMixin):
|
||||
config: OCIConfig
|
||||
|
||||
embedding_models: list[str] = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize and validate OCI configuration."""
|
||||
if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
|
||||
|
|
@ -113,7 +111,9 @@ class OCIInferenceAdapter(OpenAIMixin):
|
|||
client = GenerativeAiClient(config=oci_config, signer=oci_signer)
|
||||
|
||||
models: ModelCollection = client.list_models(
|
||||
compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
|
||||
compartment_id=compartment_id,
|
||||
# capability=MODEL_CAPABILITIES,
|
||||
lifecycle_state="ACTIVE",
|
||||
).data
|
||||
|
||||
seen_models = set()
|
||||
|
|
@ -122,7 +122,7 @@ class OCIInferenceAdapter(OpenAIMixin):
|
|||
if model.time_deprecated or model.time_on_demand_retired:
|
||||
continue
|
||||
|
||||
if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
|
||||
if "UNKNOWN_ENUM_VALUE" in model.capabilities or "FINE_TUNE" in model.capabilities:
|
||||
continue
|
||||
|
||||
# Use display_name + model_type as the key to avoid conflicts
|
||||
|
|
@ -133,8 +133,30 @@ class OCIInferenceAdapter(OpenAIMixin):
|
|||
seen_models.add(model_key)
|
||||
model_ids.append(model.display_name)
|
||||
|
||||
if "TEXT_EMBEDDINGS" in model.capabilities:
|
||||
self.embedding_models.append(model.display_name)
|
||||
|
||||
return model_ids
|
||||
|
||||
async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
|
||||
# The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
|
||||
raise NotImplementedError("OCI Provider does not (currently) support embeddings")
|
||||
def construct_model_from_identifier(self, identifier: str) -> Model:
|
||||
"""
|
||||
Construct a Model instance corresponding to the given identifier
|
||||
|
||||
Child classes can override this to customize model typing/metadata.
|
||||
|
||||
:param identifier: The provider's model identifier
|
||||
:return: A Model instance
|
||||
"""
|
||||
if identifier in self.embedding_models:
|
||||
return Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=identifier,
|
||||
identifier=identifier,
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
return Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=identifier,
|
||||
identifier=identifier,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -138,7 +138,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
|
|||
"remote::runpod",
|
||||
"remote::sambanova",
|
||||
"remote::tgi",
|
||||
"remote::oci",
|
||||
):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue