Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-16 23:22:37 +00:00)
feat: Adding OCI Embeddings (#4300)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 11s
Python Package Build Test / build (3.12) (push) Successful in 15s
Python Package Build Test / build (3.13) (push) Successful in 18s
Test External API and Providers / test-external (venv) (push) Failing after 30s
UI Tests / ui-tests (22) (push) Successful in 56s
Vector IO Integration Tests / test-matrix (push) Failing after 1m1s
Unit Tests / unit-tests (3.13) (push) Failing after 1m44s
Unit Tests / unit-tests (3.12) (push) Failing after 1m48s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3m17s
Pre-commit / pre-commit (22) (push) Successful in 3m22s
# What does this PR do?

Enables usage of OCI embedding models.

## Test Plan

Testing embedding model:
`OCI_COMPARTMENT_OCID="" OCI_REGION="us-chicago-1" OCI_AUTH_TYPE=config_file pytest -sv tests/integration/inference/test_openai_embeddings.py --stack-config oci --embedding-model oci/openai.text-embedding-3-small --inference-mode live`

Testing chat model:
`OCI_COMPARTMENT_OCID="" OCI_REGION="us-chicago-1" OCI_AUTH_TYPE=config_file pytest -sv tests/integration/inference/ --stack-config oci --text-model oci/openai.gpt-4.1-nano-2025-04-14 --inference-mode live`

Testing curl for embeddings:
`curl -X POST http://localhost:8321/v1/embeddings -H "Content-Type: application/json" -d '{ "model": "oci/openai.text-embedding-3-small", "input": ["First text", "Second text"], "encoding_format": "float" }'`

Response:
`{"object":"list","data":[{"object":"embedding","embedding":[-0.017190756...0.025272394],"index":1}],"model":"oci/openai.text-embedding-3-small","usage":{"prompt_tokens":4,"total_tokens":4}}`

---------

Co-authored-by: Omar Abdelwahab <omaryashraf10@gmail.com>
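For reference, a rough Python equivalent of the curl check above, using the OpenAI client pointed at the local Llama Stack server. This is a sketch, not part of the PR: the base URL, model id, and placeholder API key mirror the curl example and are assumptions about the local setup.

```python
# Hypothetical Python counterpart of the curl command above, assuming a Llama Stack
# server running locally on port 8321 with the OCI provider configured.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/v1",  # assumed local Llama Stack endpoint, as in the curl example
    api_key="not-needed",                 # placeholder; local auth setup may differ
)

response = client.embeddings.create(
    model="oci/openai.text-embedding-3-small",
    input=["First text", "Second text"],
    encoding_format="float",
)

for item in response.data:
    print(item.index, len(item.embedding))  # index and vector length for each input
```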
Commit 6ad5fb5577 (parent d82a2cd6f8)
2 changed files with 32 additions and 11 deletions
@@ -18,11 +18,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
 from llama_stack.providers.remote.inference.oci.config import OCIConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import (
-    ModelType,
-    OpenAIEmbeddingsRequestWithExtraBody,
-    OpenAIEmbeddingsResponse,
-)
+from llama_stack_api import Model, ModelType
 
 logger = get_logger(name=__name__, category="inference::oci")
 
@@ -37,6 +33,8 @@ MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS"]
 class OCIInferenceAdapter(OpenAIMixin):
     config: OCIConfig
 
+    embedding_models: list[str] = []
+
     async def initialize(self) -> None:
         """Initialize and validate OCI configuration."""
         if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
@@ -113,7 +111,9 @@ class OCIInferenceAdapter(OpenAIMixin):
         client = GenerativeAiClient(config=oci_config, signer=oci_signer)
 
         models: ModelCollection = client.list_models(
-            compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
+            compartment_id=compartment_id,
+            # capability=MODEL_CAPABILITIES,
+            lifecycle_state="ACTIVE",
         ).data
 
         seen_models = set()
@@ -122,7 +122,7 @@ class OCIInferenceAdapter(OpenAIMixin):
             if model.time_deprecated or model.time_on_demand_retired:
                 continue
 
-            if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
+            if "UNKNOWN_ENUM_VALUE" in model.capabilities or "FINE_TUNE" in model.capabilities:
                 continue
 
             # Use display_name + model_type as the key to avoid conflicts
@@ -133,8 +133,30 @@ class OCIInferenceAdapter(OpenAIMixin):
             seen_models.add(model_key)
             model_ids.append(model.display_name)
 
+            if "TEXT_EMBEDDINGS" in model.capabilities:
+                self.embedding_models.append(model.display_name)
+
         return model_ids
 
-    async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
-        # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
-        raise NotImplementedError("OCI Provider does not (currently) support embeddings")
+    def construct_model_from_identifier(self, identifier: str) -> Model:
+        """
+        Construct a Model instance corresponding to the given identifier
+
+        Child classes can override this to customize model typing/metadata.
+
+        :param identifier: The provider's model identifier
+        :return: A Model instance
+        """
+        if identifier in self.embedding_models:
+            return Model(
+                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                provider_resource_id=identifier,
+                identifier=identifier,
+                model_type=ModelType.embedding,
+            )
+        return Model(
+            provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+            provider_resource_id=identifier,
+            identifier=identifier,
+            model_type=ModelType.llm,
+        )
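The typing rule introduced by `construct_model_from_identifier` above can be illustrated with a minimal, self-contained sketch; the stub class and sample identifiers below are hypothetical stand-ins, not part of the actual adapter.

```python
# Minimal sketch of the model-typing rule added above: identifiers collected into
# embedding_models (models advertising TEXT_EMBEDDINGS) map to ModelType.embedding,
# everything else defaults to ModelType.llm. _StubAdapter and the sample names are
# illustrative only.
from llama_stack_api import Model, ModelType


class _StubAdapter:
    __provider_id__ = "oci"
    embedding_models = ["openai.text-embedding-3-small"]

    def construct_model_from_identifier(self, identifier: str) -> Model:
        model_type = ModelType.embedding if identifier in self.embedding_models else ModelType.llm
        return Model(
            provider_id=self.__provider_id__,
            provider_resource_id=identifier,
            identifier=identifier,
            model_type=model_type,
        )


adapter = _StubAdapter()
print(adapter.construct_model_from_identifier("openai.text-embedding-3-small").model_type)   # embedding
print(adapter.construct_model_from_identifier("openai.gpt-4.1-nano-2025-04-14").model_type)  # llm
```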
@@ -138,7 +138,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
         "remote::runpod",
         "remote::sambanova",
         "remote::tgi",
-        "remote::oci",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
 