feat: add oci genai service as chat inference provider (#3876)

# What does this PR do?
Adds OCI GenAI PaaS models for OpenAI chat completion endpoints.

## Test Plan
In an OCI tenancy with access to GenAI PaaS, perform the following
steps:

1. Ensure you have IAM policies in place to use the service (check the docs
included in this PR)
2. For local development, [set up the OCI
CLI](https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/cliinstall.htm)
and configure it with your region, tenancy, and auth as described
[here](https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/cliconfigure.htm)
3. Once configured, go through llama-stack setup and run llama-stack
(uses config based auth) like:
```bash
OCI_AUTH_TYPE=config_file \
OCI_CLI_PROFILE=CHICAGO \
OCI_REGION=us-chicago-1 \
OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..aaaaaaaa5...5a \
llama stack run oci
```
4. Hit the `models` endpoint to list models after server is running:
```bash
curl http://localhost:8321/v1/models | jq
...
{
      "identifier": "meta.llama-4-scout-17b-16e-instruct",
      "provider_resource_id": "ocid1.generativeaimodel.oc1.us-chicago-1.am...q",
      "provider_id": "oci",
      "type": "model",
      "metadata": {
        "display_name": "meta.llama-4-scout-17b-16e-instruct",
        "capabilities": [
          "CHAT"
        ],
        "oci_model_id": "ocid1.generativeaimodel.oc1.us-chicago-1.a...q"
      },
      "model_type": "llm"
},
   ...
```
5. Use the "display_name" field to use the model in a
`/chat/completions` request:
```bash
# Streaming result
curl -X POST http://localhost:8321/v1/chat/completions   -H "Content-Type: application/json"   -d '{
        "model": "meta.llama-4-scout-17b-16e-instruct",
       "stream": true,
       "temperature": 0.9,
      "messages": [
         {
           "role": "system",
           "content": "You are a funny comedian. You can be crass."
         },
          {
           "role": "user",
          "content": "Tell me a funny joke about programming."
         }
       ]
}'

# Non-streaming result
curl -X POST http://localhost:8321/v1/chat/completions   -H "Content-Type: application/json"   -d '{
        "model": "meta.llama-4-scout-17b-16e-instruct",
       "stream": false,
       "temperature": 0.9,
      "messages": [
         {
           "role": "system",
           "content": "You are a funny comedian. You can be crass."
         },
          {
           "role": "user",
          "content": "Tell me a funny joke about programming."
         }
       ]
}'
```
6. Try out other models from the `/models` endpoint.
This commit is contained in:
Dennis Kennetz 2025-11-10 15:16:24 -06:00 committed by GitHub
parent fadf17daf3
commit 209a78b618
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 938 additions and 0 deletions

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import OCIConfig
async def get_adapter_impl(config: OCIConfig, _deps) -> InferenceProvider:
    """Build the OCI inference adapter for ``config`` and initialize it.

    ``_deps`` is accepted but not used. The adapter module is imported inside
    the function body rather than when this module is imported.
    """
    from .oci import OCIInferenceAdapter

    impl = OCIInferenceAdapter(config=config)
    await impl.initialize()
    return impl

View file

@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Generator, Mapping
from typing import Any, override
import httpx
import oci
import requests
from oci.config import DEFAULT_LOCATION, DEFAULT_PROFILE
# Alias used to annotate the ``signer`` argument of ``HttpxOciAuth``.
# NOTE(review): ``type[X]`` denotes the class object itself rather than an
# instance, while the auth classes in this module store signer instances --
# confirm whether the alias should be the plain signer base class.
OciAuthSigner = type[oci.signer.AbstractBaseSigner]
class HttpxOciAuth(httpx.Auth):
    """
    HTTPX auth hook that signs outgoing requests with an OCI signer.

    Each HTTPX request is mirrored into a ``requests`` prepared request, the
    signer is invoked on that mirror, and the mirror's headers are then copied
    back onto the HTTPX request before it is sent.

    Attributes:
        signer: The OCI signer instance used for request signing
    """

    def __init__(self, signer: OciAuthSigner):
        self.signer = signer

    @override
    def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
        # The body may not have been loaded yet (streaming request); in that
        # case accessing ``content`` raises and we read it explicitly.
        try:
            body = request.content
        except httpx.RequestNotRead:
            body = request.read()

        mirrored = requests.Request(
            method=request.method,
            url=str(request.url),
            headers=dict(request.headers),
            data=body,
        ).prepare()

        # Sign the mirror, then carry its headers over to the real request.
        self.signer.do_request_sign(mirrored)  # type: ignore
        request.headers.update(mirrored.headers)

        yield request
class OciInstancePrincipalAuth(HttpxOciAuth):
    """Signs requests with an instance-principals security token signer.

    Keyword arguments are forwarded unchanged to the signer constructor.
    """

    def __init__(self, **kwargs: Mapping[str, Any]):
        token_signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs)
        self.signer = token_signer
class OciUserPrincipalAuth(HttpxOciAuth):
    """Signs requests with the API signing key of a profile from an OCI config file."""

    def __init__(self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE):
        """
        Load ``profile_name`` from ``config_file``, validate it, and build a signer.

        Args:
            config_file: Path of the OCI config file to read.
            profile_name: Name of the profile inside that file.

        Any exception raised by ``oci.config.from_file``,
        ``oci.config.validate_config`` or the signer constructor propagates.
        """
        config = oci.config.from_file(config_file, profile_name)
        oci.config.validate_config(config)  # type: ignore

        # Build the signer from the profile values directly. The key location is
        # handed to the signer instead of being opened here, and the profile's
        # own pass phrase (None when unset) is used rather than a hard-coded
        # placeholder, so pass-phrase-protected keys can be loaded.
        self.signer = oci.signer.Signer(
            tenancy=config["tenancy"],
            user=config["user"],
            fingerprint=config["fingerprint"],
            private_key_file_location=config.get("key_file"),
            pass_phrase=config.get("pass_phrase"),
            private_key_content=config.get("key_content"),
        )

View file

@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
# Schema of the OCI settings carried in provider data (presumably supplied per
# request -- confirm against callers). The auth type, region and compartment ID
# have no default; the config-file path and profile default to the standard
# OCI CLI locations.
class OCIProviderDataValidator(BaseModel):
    oci_auth_type: str = Field(
        description="OCI authentication type (must be one of: instance_principal, config_file)",
    )
    oci_region: str = Field(
        description="OCI region (e.g., us-ashburn-1)",
    )
    oci_compartment_id: str = Field(
        description="OCI compartment ID for the Generative AI service",
    )
    oci_config_file_path: str | None = Field(
        description="OCI config file path (required if oci_auth_type is config_file)",
        default="~/.oci/config",
    )
    oci_config_profile: str | None = Field(
        description="OCI config profile (required if oci_auth_type is config_file)",
        default="DEFAULT",
    )
# Configuration of the OCI inference provider. Every field falls back to an
# environment variable (read when the model is instantiated) and then to a
# literal default.
@json_schema_type
class OCIConfig(RemoteInferenceProviderConfig):
    oci_auth_type: str = Field(
        default_factory=lambda: os.environ.get("OCI_AUTH_TYPE", "instance_principal"),
        description="OCI authentication type (must be one of: instance_principal, config_file)",
    )
    oci_region: str = Field(
        description="OCI region (e.g., us-ashburn-1)",
        default_factory=lambda: os.environ.get("OCI_REGION", "us-ashburn-1"),
    )
    oci_compartment_id: str = Field(
        description="OCI compartment ID for the Generative AI service",
        default_factory=lambda: os.environ.get("OCI_COMPARTMENT_OCID", ""),
    )
    oci_config_file_path: str = Field(
        description="OCI config file path (required if oci_auth_type is config_file)",
        default_factory=lambda: os.environ.get("OCI_CONFIG_FILE_PATH", "~/.oci/config"),
    )
    oci_config_profile: str = Field(
        description="OCI config profile (required if oci_auth_type is config_file)",
        default_factory=lambda: os.environ.get("OCI_CLI_PROFILE", "DEFAULT"),
    )

    @classmethod
    def sample_run_config(
        cls,
        oci_auth_type: str = "${env.OCI_AUTH_TYPE:=instance_principal}",
        oci_config_file_path: str = "${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}",
        oci_config_profile: str = "${env.OCI_CLI_PROFILE:=DEFAULT}",
        oci_region: str = "${env.OCI_REGION:=us-ashburn-1}",
        oci_compartment_id: str = "${env.OCI_COMPARTMENT_OCID:=}",
        **kwargs,
    ) -> dict[str, Any]:
        # Extra keyword arguments are accepted but do not appear in the result.
        sample: dict[str, Any] = {
            "oci_auth_type": oci_auth_type,
            "oci_config_file_path": oci_config_file_path,
            "oci_config_profile": oci_config_profile,
            "oci_region": oci_region,
            "oci_compartment_id": oci_compartment_id,
        }
        return sample

View file

@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from typing import Any
import httpx
import oci
from oci.generative_ai.generative_ai_client import GenerativeAiClient
from oci.generative_ai.models import ModelCollection
from openai._base_client import DefaultAsyncHttpxClient
from llama_stack.apis.inference.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
from llama_stack.providers.remote.inference.oci.config import OCIConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
# Module logger, registered under the "inference::oci" category.
logger = get_logger(name=__name__, category="inference::oci")
# Accepted values for the adapter config's ``oci_auth_type`` field.
OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal"
OCI_AUTH_TYPE_CONFIG_FILE = "config_file"
VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE]
# Region used to build the base URL when the configured region is empty.
DEFAULT_OCI_REGION = "us-ashburn-1"
# Capability filter passed to the model-listing call.
MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS", "CHAT"]
class OCIInferenceAdapter(OpenAIMixin):
    """Inference adapter for the OCI Generative AI service.

    Chat requests go through the OpenAI client machinery supplied by
    ``OpenAIMixin``; authentication is done by signing each HTTP request
    instead of sending an API key. Embeddings are not supported.
    """

    config: OCIConfig

    async def initialize(self) -> None:
        """Validate the OCI configuration.

        Raises:
            ValueError: If the auth type is not one of ``VALID_OCI_AUTH_TYPES``
                or no compartment ID is configured.
        """
        if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
            raise ValueError(
                f"Invalid OCI authentication type: {self.config.oci_auth_type}. "
                f"Valid types are one of: {VALID_OCI_AUTH_TYPES}"
            )

        if not self.config.oci_compartment_id:
            raise ValueError("OCI_COMPARTMENT_OCID is a required parameter. Either set in env variable or config.")

    def get_base_url(self) -> str:
        """Return the regional inference endpoint URL, defaulting the region when it is empty."""
        region = self.config.oci_region or DEFAULT_OCI_REGION
        return f"https://inference.generativeai.{region}.oci.oraclecloud.com/20231130/actions/v1"

    def get_api_key(self) -> str | None:
        """Return a placeholder key."""
        # OCI doesn't use API keys, it uses request signing
        return "<NOTUSED>"

    def get_extra_client_params(self) -> dict[str, Any]:
        """
        Get extra parameters for the AsyncOpenAI client, including OCI-specific auth and headers.

        A new HTTP client (with a freshly built auth object) is created on every call.
        """
        auth = self._get_auth()
        compartment_id = self.config.oci_compartment_id or ""

        return {
            "http_client": DefaultAsyncHttpxClient(
                auth=auth,
                headers={
                    "CompartmentId": compartment_id,
                },
            ),
        }

    def _get_oci_signer(self) -> oci.signer.AbstractBaseSigner | None:
        """Return an explicit signer for instance-principal auth, or None for config-file auth."""
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
        return None

    def _get_oci_config(self) -> dict:
        """Build the config dict handed to the OCI SDK client.

        For config-file auth the profile is loaded from disk; when the profile
        does not name a region, the adapter's ``oci_region`` setting is used.

        Raises:
            ValueError: If the auth type is unknown or no region can be determined.
        """
        auth_type = self.config.oci_auth_type
        if auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            config = {"region": self.config.oci_region}
        elif auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile)
            # Honour the region from the adapter settings (OCI_REGION) when the
            # profile itself does not specify one, as the error below promises.
            if not config.get("region") and self.config.oci_region:
                config["region"] = self.config.oci_region
        else:
            # Without this branch ``config`` would be unbound below.
            raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")

        if not config.get("region"):
            raise ValueError(
                "Region not specified in config. Please specify in config or with OCI_REGION env variable."
            )

        return config

    def _get_auth(self) -> httpx.Auth:
        """Return the request-signing auth object matching the configured auth type.

        Raises:
            ValueError: If the auth type is not recognised.
        """
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return OciInstancePrincipalAuth()
        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            return OciUserPrincipalAuth(
                config_file=self.config.oci_config_file_path, profile_name=self.config.oci_config_profile
            )
        else:
            raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")

    async def list_provider_model_ids(self) -> Iterable[str]:
        """
        List available models from OCI Generative AI service.

        Returns the display names of active models that list the "CHAT"
        capability, skipping deprecated, retired and fine-tune models as well
        as duplicate display names.
        """
        oci_config = self._get_oci_config()
        oci_signer = self._get_oci_signer()
        compartment_id = self.config.oci_compartment_id or ""

        if oci_signer is None:
            client = GenerativeAiClient(config=oci_config)
        else:
            client = GenerativeAiClient(config=oci_config, signer=oci_signer)

        # NOTE(review): this is a synchronous SDK call inside a coroutine and
        # only the single returned response is consumed -- confirm whether
        # pagination or offloading to a thread is needed.
        models: ModelCollection = client.list_models(
            compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
        ).data

        seen_models = set()
        model_ids = []
        for model in models.items:
            if model.time_deprecated or model.time_on_demand_retired:
                continue

            if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
                continue

            # Use display_name + model_type as the key to avoid conflicts
            model_key = (model.display_name, ModelType.llm)
            if model_key in seen_models:
                continue

            seen_models.add(model_key)
            model_ids.append(model.display_name)

        return model_ids

    async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
        """Always raise: embeddings are not supported by this provider.

        Raises:
            NotImplementedError: On every call.
        """
        # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
        raise NotImplementedError("OCI Provider does not (currently) support embeddings")