mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
# What does this PR do? Adds OCI GenAI PaaS models for OpenAI chat completion endpoints. ## Test Plan In an OCI tenancy with access to GenAI PaaS, perform the following steps: 1. Ensure you have IAM policies in place to use the service (check docs included in this PR) 2. For local development, [set up the OCI CLI](https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/cliinstall.htm) and configure the CLI with your region, tenancy, and auth [here](https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/cliconfigure.htm) 3. Once configured, go through llama-stack setup and run llama-stack (uses config-based auth) like: ```bash OCI_AUTH_TYPE=config_file \ OCI_CLI_PROFILE=CHICAGO \ OCI_REGION=us-chicago-1 \ OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..aaaaaaaa5...5a \ llama stack run oci ``` 4. Hit the `models` endpoint to list models after the server is running: ```bash curl http://localhost:8321/v1/models | jq ... { "identifier": "meta.llama-4-scout-17b-16e-instruct", "provider_resource_id": "ocid1.generativeaimodel.oc1.us-chicago-1.am...q", "provider_id": "oci", "type": "model", "metadata": { "display_name": "meta.llama-4-scout-17b-16e-instruct", "capabilities": [ "CHAT" ], "oci_model_id": "ocid1.generativeaimodel.oc1.us-chicago-1.a...q" }, "model_type": "llm" }, ... ``` 5. Use the "display_name" field to use the model in a `/chat/completions` request: ```bash # Streaming result curl -X POST http://localhost:8321/v1/chat/completions -H "Content-Type: application/json" -d '{ "model": "meta.llama-4-scout-17b-16e-instruct", "stream": true, "temperature": 0.9, "messages": [ { "role": "system", "content": "You are a funny comedian. You can be crass." }, { "role": "user", "content": "Tell me a funny joke about programming." 
} ] }' # Non-streaming result curl -X POST http://localhost:8321/v1/chat/completions -H "Content-Type: application/json" -d '{ "model": "meta.llama-4-scout-17b-16e-instruct", "stream": false, "temperature": 0.9, "messages": [ { "role": "system", "content": "You are a funny comedian. You can be crass." }, { "role": "user", "content": "Tell me a funny joke about programming." } ] }' ``` 6. Try out other models from the `/models` endpoint.
140 lines
5.5 KiB
Python
140 lines
5.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
|
|
from collections.abc import Iterable
|
|
from typing import Any
|
|
|
|
import httpx
|
|
import oci
|
|
from oci.generative_ai.generative_ai_client import GenerativeAiClient
|
|
from oci.generative_ai.models import ModelCollection
|
|
from openai._base_client import DefaultAsyncHttpxClient
|
|
|
|
from llama_stack.apis.inference.inference import (
|
|
OpenAIEmbeddingsRequestWithExtraBody,
|
|
OpenAIEmbeddingsResponse,
|
|
)
|
|
from llama_stack.apis.models import ModelType
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
|
|
from llama_stack.providers.remote.inference.oci.config import OCIConfig
|
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|
|
|
# Module-level logger for the OCI inference provider.
logger = get_logger(name=__name__, category="inference::oci")

# Authentication modes accepted in OCIConfig.oci_auth_type.
OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal"
OCI_AUTH_TYPE_CONFIG_FILE = "config_file"
VALID_OCI_AUTH_TYPES = [
    OCI_AUTH_TYPE_INSTANCE_PRINCIPAL,
    OCI_AUTH_TYPE_CONFIG_FILE,
]

# Region used to build the inference URL when the configuration names none.
DEFAULT_OCI_REGION = "us-ashburn-1"

# Capability filter passed to the OCI model-listing call.
MODEL_CAPABILITIES = [
    "TEXT_GENERATION",
    "TEXT_SUMMARIZATION",
    "TEXT_EMBEDDINGS",
    "CHAT",
]
|
|
|
|
|
|
class OCIInferenceAdapter(OpenAIMixin):
    """Inference adapter for OCI Generative AI built on top of ``OpenAIMixin``.

    Requests are authenticated with OCI request signing (an ``httpx.Auth``
    object injected into the HTTP client) rather than an API key, and model
    discovery goes through the OCI ``GenerativeAiClient``.
    """

    config: OCIConfig

    async def initialize(self) -> None:
        """Validate the OCI configuration.

        Raises:
            ValueError: If ``oci_auth_type`` is not one of
                ``VALID_OCI_AUTH_TYPES`` or ``oci_compartment_id`` is empty.
        """
        if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
            # Trailing space keeps the two sentences from running together.
            raise ValueError(
                f"Invalid OCI authentication type: {self.config.oci_auth_type}. "
                f"Valid types are one of: {VALID_OCI_AUTH_TYPES}"
            )

        if not self.config.oci_compartment_id:
            raise ValueError("OCI_COMPARTMENT_OCID is a required parameter. Either set in env variable or config.")

    def get_base_url(self) -> str:
        """Return the regional OCI Generative AI inference URL.

        Falls back to ``DEFAULT_OCI_REGION`` when no region is configured.
        """
        region = self.config.oci_region or DEFAULT_OCI_REGION
        return f"https://inference.generativeai.{region}.oci.oraclecloud.com/20231130/actions/v1"

    def get_api_key(self) -> str | None:
        """Return a placeholder key; OCI uses request signing, not API keys."""
        return "<NOTUSED>"

    def get_extra_client_params(self) -> dict[str, Any]:
        """Return extra keyword arguments for the AsyncOpenAI client.

        Supplies an ``http_client`` that carries the OCI auth handler and a
        ``CompartmentId`` header (empty string when no compartment is set).
        """
        auth = self._get_auth()
        compartment_id = self.config.oci_compartment_id or ""

        return {
            "http_client": DefaultAsyncHttpxClient(
                auth=auth,
                headers={
                    "CompartmentId": compartment_id,
                },
            ),
        }

    def _get_oci_signer(self) -> oci.signer.AbstractBaseSigner | None:
        """Return an instance-principal signer, or ``None`` for other auth types."""
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
        return None

    def _get_oci_config(self) -> dict:
        """Build the config dict handed to the OCI SDK client.

        Returns:
            For instance-principal auth, a dict holding only ``region``; for
            config-file auth, the dict loaded by ``oci.config.from_file``.

        Raises:
            ValueError: If no region can be determined for config-file auth,
                or if the auth type is not recognised.
        """
        auth_type = self.config.oci_auth_type

        if auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            # Same fallback as get_base_url() so listing and inference agree.
            return {"region": self.config.oci_region or DEFAULT_OCI_REGION}

        if auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile)
            if not config.get("region"):
                # The error text promises OCI_REGION is honoured; use the
                # configured region before giving up.
                if not self.config.oci_region:
                    raise ValueError(
                        "Region not specified in config. Please specify in config or with OCI_REGION env variable."
                    )
                config["region"] = self.config.oci_region
            return config

        # Previously fell through to an UnboundLocalError on `config`.
        raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")

    def _get_auth(self) -> httpx.Auth:
        """Return the ``httpx.Auth`` handler matching the configured auth type.

        Raises:
            ValueError: If the auth type is not recognised.
        """
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return OciInstancePrincipalAuth()
        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            return OciUserPrincipalAuth(
                config_file=self.config.oci_config_file_path, profile_name=self.config.oci_config_profile
            )
        else:
            raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")

    def _fetch_active_models(self) -> ModelCollection:
        """Synchronously list ACTIVE models in the configured compartment."""
        oci_config = self._get_oci_config()
        oci_signer = self._get_oci_signer()
        compartment_id = self.config.oci_compartment_id or ""

        if oci_signer is None:
            client = GenerativeAiClient(config=oci_config)
        else:
            client = GenerativeAiClient(config=oci_config, signer=oci_signer)

        return client.list_models(
            compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
        ).data

    async def list_provider_model_ids(self) -> Iterable[str]:
        """List the display names of usable chat models from OCI Generative AI.

        Models are skipped when they are deprecated or on-demand retired,
        lack the ``CHAT`` capability, or carry the ``FINE_TUNE`` capability.
        Duplicate display names are returned once, in first-seen order.
        """
        import asyncio

        # The OCI SDK client is blocking; run it off the event loop.
        models: ModelCollection = await asyncio.to_thread(self._fetch_active_models)

        seen_names: set[str] = set()
        model_ids: list[str] = []
        for model in models.items:
            if model.time_deprecated or model.time_on_demand_retired:
                continue

            capabilities = model.capabilities or []
            if "CHAT" not in capabilities or "FINE_TUNE" in capabilities:
                continue

            # Every listed model is an LLM, so the display name alone is the key.
            if model.display_name in seen_names:
                continue

            seen_names.add(model.display_name)
            model_ids.append(model.display_name)

        return model_ids

    async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
        """Embeddings are not supported by this provider.

        Raises:
            NotImplementedError: Always.
        """
        # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
        raise NotImplementedError("OCI Provider does not (currently) support embeddings")
|