mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat: add oci genai service as chat inference provider (#3876)
# What does this PR do? Adds OCI GenAI PaaS models for openai chat completion endpoints. ## Test Plan In an OCI tenancy with access to GenAI PaaS, perform the following steps: 1. Ensure you have IAM policies in place to use service (check docs included in this PR) 2. For local development, [setup OCI cli](https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/cliinstall.htm) and configure the CLI with your region, tenancy, and auth [here](https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/cliconfigure.htm) 3. Once configured, go through llama-stack setup and run llama-stack (uses config based auth) like: ```bash OCI_AUTH_TYPE=config_file \ OCI_CLI_PROFILE=CHICAGO \ OCI_REGION=us-chicago-1 \ OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..aaaaaaaa5...5a \ llama stack run oci ``` 4. Hit the `models` endpoint to list models after server is running: ```bash curl http://localhost:8321/v1/models | jq ... { "identifier": "meta.llama-4-scout-17b-16e-instruct", "provider_resource_id": "ocid1.generativeaimodel.oc1.us-chicago-1.am...q", "provider_id": "oci", "type": "model", "metadata": { "display_name": "meta.llama-4-scout-17b-16e-instruct", "capabilities": [ "CHAT" ], "oci_model_id": "ocid1.generativeaimodel.oc1.us-chicago-1.a...q" }, "model_type": "llm" }, ... ``` 5. Use the "display_name" field to use the model in a `/chat/completions` request: ```bash # Streaming result curl -X POST http://localhost:8321/v1/chat/completions -H "Content-Type: application/json" -d '{ "model": "meta.llama-4-scout-17b-16e-instruct", "stream": true, "temperature": 0.9, "messages": [ { "role": "system", "content": "You are a funny comedian. You can be crass." }, { "role": "user", "content": "Tell me a funny joke about programming." 
} ] }' # Non-streaming result curl -X POST http://localhost:8321/v1/chat/completions -H "Content-Type: application/json" -d '{ "model": "meta.llama-4-scout-17b-16e-instruct", "stream": false, "temperature": 0.9, "messages": [ { "role": "system", "content": "You are a funny comedian. You can be crass." }, { "role": "user", "content": "Tell me a funny joke about programming." } ] }' ``` 6. Try out other models from the `/models` endpoint.
This commit is contained in:
parent
fadf17daf3
commit
209a78b618
15 changed files with 938 additions and 0 deletions
17
src/llama_stack/providers/remote/inference/oci/__init__.py
Normal file
17
src/llama_stack/providers/remote/inference/oci/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import OCIConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OCIConfig, _deps) -> InferenceProvider:
    """Build the OCI inference adapter from its config and run its startup validation."""
    from .oci import OCIInferenceAdapter

    impl = OCIInferenceAdapter(config=config)
    await impl.initialize()
    return impl
|
||||
79
src/llama_stack/providers/remote/inference/oci/auth.py
Normal file
79
src/llama_stack/providers/remote/inference/oci/auth.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, override
|
||||
|
||||
import httpx
|
||||
import oci
|
||||
import requests
|
||||
from oci.config import DEFAULT_LOCATION, DEFAULT_PROFILE
|
||||
|
||||
# Alias for any OCI signer implementation. This must be the *instance* type,
# not ``type[...]``: ``HttpxOciAuth`` stores a signer object and calls
# ``do_request_sign`` on it, it is never handed the signer class itself.
OciAuthSigner = oci.signer.AbstractBaseSigner


class HttpxOciAuth(httpx.Auth):
    """
    Custom HTTPX authentication class that implements OCI request signing.

    This class handles the authentication flow for HTTPX requests by signing them
    using the OCI Signer, which adds the necessary authentication headers for
    OCI API calls.

    Attributes:
        signer (oci.signer.AbstractBaseSigner): The OCI signer instance used for
            request signing
    """

    def __init__(self, signer: OciAuthSigner):
        self.signer = signer

    @override
    def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
        """Sign the outgoing request with OCI auth headers, then yield it onward."""
        # Read the request content to handle streaming requests properly
        try:
            content = request.content
        except httpx.RequestNotRead:
            # For streaming requests, we need to read the content first
            content = request.read()

        # The OCI signer operates on ``requests``-style objects, so mirror the
        # httpx request into a prepared ``requests.Request`` for signing.
        req = requests.Request(
            method=request.method,
            url=str(request.url),
            headers=dict(request.headers),
            data=content,
        )
        prepared_request = req.prepare()

        # Sign the request using the OCI Signer
        self.signer.do_request_sign(prepared_request)  # type: ignore

        # Update the original HTTPX request with the signed headers
        request.headers.update(prepared_request.headers)

        yield request
|
||||
|
||||
|
||||
class OciInstancePrincipalAuth(HttpxOciAuth):
    """HTTPX auth that signs requests with an OCI instance-principal token signer.

    Intended for code running on OCI compute, where credentials are obtained
    from the instance metadata service rather than a local config file.
    """

    def __init__(self, **kwargs: Any):
        # A ``**kwargs`` annotation describes each *value*, so ``Any`` is the
        # correct annotation (``Mapping[str, Any]`` would claim every keyword
        # argument is itself a mapping). Delegate to the base class so signer
        # storage stays in one place.
        super().__init__(oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs))
|
||||
|
||||
|
||||
class OciUserPrincipalAuth(HttpxOciAuth):
    """HTTPX auth that signs requests with user-principal (API key) credentials
    loaded from an OCI config file profile."""

    def __init__(self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE):
        config = oci.config.from_file(config_file, profile_name)
        oci.config.validate_config(config)  # type: ignore

        # Load the private key contents referenced by the profile.
        with open(config["key_file"]) as key_fp:
            key_content = key_fp.read()

        # NOTE(review): pass_phrase is the literal string "none", not None —
        # presumably only unencrypted key files are supported; confirm before
        # relying on passphrase-protected keys.
        self.signer = oci.signer.Signer(
            tenancy=config["tenancy"],
            user=config["user"],
            fingerprint=config["fingerprint"],
            private_key_file_location=config.get("key_file"),
            pass_phrase="none",  # type: ignore
            private_key_content=key_content,
        )
|
||||
75
src/llama_stack/providers/remote/inference/oci/config.py
Normal file
75
src/llama_stack/providers/remote/inference/oci/config.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
# Validates per-request provider data for the OCI provider. Field names mirror
# OCIConfig below; a plain comment is used instead of a docstring so the
# generated JSON schema description is not affected.
class OCIProviderDataValidator(BaseModel):
    # Selects the signing strategy; must match one of the auth types the
    # adapter accepts (instance_principal or config_file).
    oci_auth_type: str = Field(
        description="OCI authentication type (must be one of: instance_principal, config_file)",
    )
    oci_region: str = Field(
        description="OCI region (e.g., us-ashburn-1)",
    )
    oci_compartment_id: str = Field(
        description="OCI compartment ID for the Generative AI service",
    )
    # Only consulted when oci_auth_type == "config_file".
    oci_config_file_path: str | None = Field(
        default="~/.oci/config",
        description="OCI config file path (required if oci_auth_type is config_file)",
    )
    # Only consulted when oci_auth_type == "config_file".
    oci_config_profile: str | None = Field(
        default="DEFAULT",
        description="OCI config profile (required if oci_auth_type is config_file)",
    )
|
||||
|
||||
|
||||
@json_schema_type
# Runtime configuration for the OCI inference provider. Every field defaults
# from an environment variable so the provider can be configured entirely from
# the environment (e.g. `OCI_AUTH_TYPE=config_file llama stack run oci`).
class OCIConfig(RemoteInferenceProviderConfig):
    oci_auth_type: str = Field(
        description="OCI authentication type (must be one of: instance_principal, config_file)",
        default_factory=lambda: os.getenv("OCI_AUTH_TYPE", "instance_principal"),
    )
    oci_region: str = Field(
        default_factory=lambda: os.getenv("OCI_REGION", "us-ashburn-1"),
        description="OCI region (e.g., us-ashburn-1)",
    )
    oci_compartment_id: str = Field(
        # Empty-string default: the adapter treats a missing compartment as a
        # hard error during initialize() rather than at config-parse time.
        default_factory=lambda: os.getenv("OCI_COMPARTMENT_OCID", ""),
        description="OCI compartment ID for the Generative AI service",
    )
    # Only consulted when oci_auth_type == "config_file".
    oci_config_file_path: str = Field(
        default_factory=lambda: os.getenv("OCI_CONFIG_FILE_PATH", "~/.oci/config"),
        description="OCI config file path (required if oci_auth_type is config_file)",
    )
    # Only consulted when oci_auth_type == "config_file".
    oci_config_profile: str = Field(
        default_factory=lambda: os.getenv("OCI_CLI_PROFILE", "DEFAULT"),
        description="OCI config profile (required if oci_auth_type is config_file)",
    )

    @classmethod
    def sample_run_config(
        cls,
        oci_auth_type: str = "${env.OCI_AUTH_TYPE:=instance_principal}",
        oci_config_file_path: str = "${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}",
        oci_config_profile: str = "${env.OCI_CLI_PROFILE:=DEFAULT}",
        oci_region: str = "${env.OCI_REGION:=us-ashburn-1}",
        oci_compartment_id: str = "${env.OCI_COMPARTMENT_OCID:=}",
        **kwargs,
    ) -> dict[str, Any]:
        """Return a sample run config mapping each field to a `${env.VAR:=default}` placeholder."""
        return {
            "oci_auth_type": oci_auth_type,
            "oci_config_file_path": oci_config_file_path,
            "oci_config_profile": oci_config_profile,
            "oci_region": oci_region,
            "oci_compartment_id": oci_compartment_id,
        }
|
||||
140
src/llama_stack/providers/remote/inference/oci/oci.py
Normal file
140
src/llama_stack/providers/remote/inference/oci/oci.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import oci
|
||||
from oci.generative_ai.generative_ai_client import GenerativeAiClient
|
||||
from oci.generative_ai.models import ModelCollection
|
||||
from openai._base_client import DefaultAsyncHttpxClient
|
||||
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
|
||||
from llama_stack.providers.remote.inference.oci.config import OCIConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::oci")

# Authentication modes the adapter accepts (validated in initialize()).
OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal"
OCI_AUTH_TYPE_CONFIG_FILE = "config_file"
VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE]
# Fallback region when the config does not specify one.
DEFAULT_OCI_REGION = "us-ashburn-1"

# Capability filter passed to GenerativeAiClient.list_models during discovery;
# non-CHAT models are filtered out again client-side in list_provider_model_ids.
MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS", "CHAT"]
|
||||
|
||||
|
||||
class OCIInferenceAdapter(OpenAIMixin):
    """OpenAI-compatible chat inference adapter backed by OCI Generative AI.

    OCI does not use API keys; outgoing requests are authenticated by OCI
    request signing, injected via a custom httpx auth hook (see `_get_auth`).
    """

    config: OCIConfig

    async def initialize(self) -> None:
        """Initialize and validate OCI configuration.

        Raises:
            ValueError: if the auth type is not recognized or the compartment
                OCID is missing.
        """
        if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
            # Note the trailing space: these adjacent f-strings are concatenated,
            # so without it the message would read "...config_file.Valid types...".
            raise ValueError(
                f"Invalid OCI authentication type: {self.config.oci_auth_type}. "
                f"Valid types are one of: {VALID_OCI_AUTH_TYPES}"
            )

        if not self.config.oci_compartment_id:
            raise ValueError("OCI_COMPARTMENT_OCID is a required parameter. Either set in env variable or config.")

    def get_base_url(self) -> str:
        """Return the regional OCI GenAI endpoint (OpenAI-compatible `actions/v1` path)."""
        region = self.config.oci_region or DEFAULT_OCI_REGION
        return f"https://inference.generativeai.{region}.oci.oraclecloud.com/20231130/actions/v1"

    def get_api_key(self) -> str | None:
        # OCI doesn't use API keys, it uses request signing; return a placeholder
        # so the OpenAI client does not complain about a missing key.
        return "<NOTUSED>"

    def get_extra_client_params(self) -> dict[str, Any]:
        """
        Get extra parameters for the AsyncOpenAI client, including OCI-specific auth and headers.
        """
        auth = self._get_auth()
        compartment_id = self.config.oci_compartment_id or ""

        return {
            "http_client": DefaultAsyncHttpxClient(
                auth=auth,
                headers={
                    # OCI GenAI authorizes/routes inference requests by compartment.
                    "CompartmentId": compartment_id,
                },
            ),
        }

    def _get_oci_signer(self) -> oci.signer.AbstractBaseSigner | None:
        """Return a signer for control-plane calls, or None to let the SDK
        derive credentials from the config dict (config_file auth)."""
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
        return None

    def _get_oci_config(self) -> dict:
        """Build the OCI SDK config dict for the configured auth type.

        Raises:
            ValueError: if the auth type is unknown, or the config file does
                not specify a region.
        """
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            config = {"region": self.config.oci_region}
        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile)
            if not config.get("region"):
                raise ValueError(
                    "Region not specified in config. Please specify in config or with OCI_REGION env variable."
                )
        else:
            # Defensive: initialize() already validates the auth type, but fail
            # loudly here rather than hitting NameError on an unbound `config`
            # if this is ever called before initialize().
            raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")

        return config

    def _get_auth(self) -> httpx.Auth:
        """Return the httpx auth hook that signs data-plane requests."""
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return OciInstancePrincipalAuth()
        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            return OciUserPrincipalAuth(
                config_file=self.config.oci_config_file_path, profile_name=self.config.oci_config_profile
            )
        else:
            raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")

    async def list_provider_model_ids(self) -> Iterable[str]:
        """
        List available models from OCI Generative AI service.

        Returns the display names of ACTIVE, chat-capable base models in the
        configured compartment, skipping deprecated/retired and fine-tune-only
        entries, deduplicated by display name.
        """
        oci_config = self._get_oci_config()
        oci_signer = self._get_oci_signer()
        compartment_id = self.config.oci_compartment_id or ""

        if oci_signer is None:
            client = GenerativeAiClient(config=oci_config)
        else:
            client = GenerativeAiClient(config=oci_config, signer=oci_signer)

        models: ModelCollection = client.list_models(
            compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
        ).data

        seen_models = set()
        model_ids = []
        for model in models.items:
            # Skip models OCI has deprecated or retired from on-demand serving.
            if model.time_deprecated or model.time_on_demand_retired:
                continue

            # Only expose chat-capable base models (no fine-tune-only entries).
            if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
                continue

            # Use display_name + model_type as the key to avoid conflicts
            model_key = (model.display_name, ModelType.llm)
            if model_key in seen_models:
                continue

            seen_models.add(model_key)
            model_ids.append(model.display_name)

        return model_ids

    async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
        # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
        raise NotImplementedError("OCI Provider does not (currently) support embeddings")
|
||||
Loading…
Add table
Add a link
Reference in a new issue