Merge remote-tracking branch 'upstream/main' into elasticsearch-integration

Enrico Zimuel 2025-10-31 18:23:42 +01:00
commit 2407115ee8
No known key found for this signature in database
GPG key ID: 6CB203F6934A69F1
1050 changed files with 65153 additions and 2821 deletions

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import HuggingfaceDatasetIOConfig
async def get_adapter_impl(
config: HuggingfaceDatasetIOConfig,
_deps,
):
from .huggingface import HuggingfaceDatasetIOImpl
impl = HuggingfaceDatasetIOImpl(config)
await impl.initialize()
return impl

@@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
from llama_stack.core.storage.datatypes import KVStoreReference
class HuggingfaceDatasetIOConfig(BaseModel):
kvstore: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": KVStoreReference(
backend="kv_default",
namespace="datasetio::huggingface",
).model_dump(exclude_none=True)
}

@@ -1,99 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from urllib.parse import parse_qs, urlparse
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .config import HuggingfaceDatasetIOConfig
DATASETS_PREFIX = "datasets:"
def parse_hf_params(dataset_def: Dataset):
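    """Split an hf:// dataset URI into the Hugging Face Hub path and load_dataset kwargs.

    For example, "hf://datasets/llamastack/evals?split=train" yields
    ("llamastack/evals", {"split": "train"}); the repo name is illustrative.
    """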
uri = dataset_def.source.uri
parsed_uri = urlparse(uri)
params = parse_qs(parsed_uri.query)
params = {k: v[0] for k, v in params.items()}
path = parsed_uri.path.lstrip("/")
return path, params
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
self.config = config
# local registry for keeping track of datasets within the provider
self.dataset_infos = {}
self.kvstore = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)
self.dataset_infos[dataset.identifier] = dataset
async def shutdown(self) -> None: ...
async def register_dataset(
self,
dataset_def: Dataset,
) -> None:
# Store in kvstore
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set(
key=key,
value=dataset_def.model_dump_json(),
)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
import datasets as hf_datasets
dataset_def = self.dataset_infos[dataset_id]
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
records = [loaded_dataset[i] for i in range(len(loaded_dataset))]
return paginate_records(records, start_index, limit)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
import datasets as hf_datasets
dataset_def = self.dataset_infos[dataset_id]
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
# Convert rows to HF Dataset format
new_dataset = hf_datasets.Dataset.from_list(rows)
# Concatenate the new rows with existing dataset
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
if dataset_def.metadata.get("path", None):
updated_dataset.push_to_hub(dataset_def.metadata["path"])
else:
raise NotImplementedError("Uploading to URL-based datasets is not supported yet")

@@ -1,73 +0,0 @@
# NVIDIA DatasetIO Provider for LlamaStack
This provider enables dataset management using NVIDIA's NeMo Customizer service.
## Features
- Register datasets for fine-tuning LLMs
- Unregister datasets
## Getting Started
### Prerequisites
- LlamaStack with NVIDIA configuration
- Access to a hosted NVIDIA NeMo Microservice
- API key for authentication with the NVIDIA service
### Setup
Build the NVIDIA environment:
```bash
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage with the LlamaStack Python Client
#### Initialize the client
```python
import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
#### Register a dataset
```python
client.datasets.register(
purpose="post-training/messages",
dataset_id="my-training-dataset",
source={"type": "uri", "uri": "hf://datasets/default/sample-dataset"},
metadata={
"format": "json",
"description": "Dataset for LLM fine-tuning",
"provider": "nvidia",
},
)
```
#### Get a list of all registered datasets
```python
datasets = client.datasets.list()
for dataset in datasets:
print(f"Dataset ID: {dataset.identifier}")
print(f"Description: {dataset.metadata.get('description', '')}")
print(f"Source: {dataset.source.uri}")
print("---")
```
#### Unregister a dataset
```python
client.datasets.unregister(dataset_id="my-training-dataset")
```

@@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import NvidiaDatasetIOConfig
async def get_adapter_impl(
config: NvidiaDatasetIOConfig,
_deps,
):
from .datasetio import NvidiaDatasetIOAdapter
if not isinstance(config, NvidiaDatasetIOConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
impl = NvidiaDatasetIOAdapter(config)
return impl
__all__ = ["get_adapter_impl", "NvidiaDatasetIOAdapter"]

@@ -1,61 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import warnings
from typing import Any
from pydantic import BaseModel, Field
class NvidiaDatasetIOConfig(BaseModel):
"""Configuration for NVIDIA DatasetIO implementation."""
api_key: str | None = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key.",
)
dataset_namespace: str | None = Field(
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
description="The NVIDIA dataset namespace.",
)
project_id: str | None = Field(
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"),
description="The NVIDIA project ID.",
)
datasets_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_DATASETS_URL", "http://nemo.test"),
description="Base URL for the NeMo Dataset API",
)
    # warn when default values are in use; model_post_init runs after pydantic validation
    def model_post_init(self, __context: Any) -> None:
default_values = []
if os.getenv("NVIDIA_PROJECT_ID") is None:
default_values.append("project_id='test-project'")
if os.getenv("NVIDIA_DATASET_NAMESPACE") is None:
default_values.append("dataset_namespace='default'")
if os.getenv("NVIDIA_DATASETS_URL") is None:
default_values.append("datasets_url='http://nemo.test'")
if default_values:
warnings.warn(
f"Using default values: {', '.join(default_values)}. \
Please set the environment variables to avoid this default behavior.",
stacklevel=2,
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"api_key": "${env.NVIDIA_API_KEY:=}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
"project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
"datasets_url": "${env.NVIDIA_DATASETS_URL:=http://nemo.test}",
}

@@ -1,116 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import aiohttp
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset
from .config import NvidiaDatasetIOConfig
class NvidiaDatasetIOAdapter:
"""Nvidia NeMo DatasetIO API."""
def __init__(self, config: NvidiaDatasetIOConfig):
self.config = config
self.headers = {}
async def _make_request(
self,
method: str,
path: str,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
json: dict[str, Any] | None = None,
**kwargs,
) -> dict[str, Any]:
"""Helper method to make HTTP requests to the Customizer API."""
url = f"{self.config.datasets_url}{path}"
request_headers = self.headers.copy()
# Set default Content-Type for JSON requests
if json is not None:
request_headers["Content-Type"] = "application/json"
if headers:
request_headers.update(headers)
async with aiohttp.ClientSession(headers=request_headers) as session:
async with session.request(method, url, params=params, json=json, **kwargs) as response:
if response.status != 200:
error_data = await response.json()
raise Exception(f"API request failed: {error_data}")
return await response.json()
async def register_dataset(
self,
dataset_def: Dataset,
) -> Dataset:
"""Register a new dataset.
Args:
dataset_def [Dataset]: The dataset definition.
dataset_id [str]: The ID of the dataset.
source [DataSource]: The source of the dataset.
metadata [Dict[str, Any]]: The metadata of the dataset.
format [str]: The format of the dataset.
description [str]: The description of the dataset.
Returns:
Dataset
"""
# add warnings for unsupported params
request_body = {
"name": dataset_def.identifier,
"namespace": self.config.dataset_namespace,
"files_url": dataset_def.source.uri,
"project": self.config.project_id,
}
if dataset_def.metadata:
request_body["format"] = dataset_def.metadata.get("format")
request_body["description"] = dataset_def.metadata.get("description")
await self._make_request(
"POST",
"/v1/datasets",
json=request_body,
)
return dataset_def
async def update_dataset(
self,
dataset_id: str,
dataset_schema: dict[str, ParamType],
url: URL,
provider_dataset_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
raise NotImplementedError("Not implemented")
async def unregister_dataset(
self,
dataset_id: str,
) -> None:
await self._make_request(
"DELETE",
f"/v1/datasets/{self.config.dataset_namespace}/{dataset_id}",
headers={"Accept": "application/json", "Content-Type": "application/json"},
)
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
raise NotImplementedError("Not implemented")
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
raise NotImplementedError("Not implemented")

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -1,134 +0,0 @@
# NVIDIA NeMo Evaluator Eval Provider
## Overview
For the first integration, Benchmarks are mapped to Evaluation Configs in NeMo Evaluator. The full evaluation config object is provided as part of the metadata. The `dataset_id` and `scoring_functions` fields are not used.
Below are a few examples of how to register a benchmark (which in turn creates an evaluation config in NeMo Evaluator) and how to trigger an evaluation.
### Example of registering an academic benchmark
```
POST /eval/benchmarks
```
```json
{
"benchmark_id": "mmlu",
"dataset_id": "",
"scoring_functions": [],
"metadata": {
"type": "mmlu"
}
}
```
### Example of registering a custom evaluation
```
POST /eval/benchmarks
```
```json
{
"benchmark_id": "my-custom-benchmark",
"dataset_id": "",
"scoring_functions": [],
"metadata": {
"type": "custom",
"params": {
"parallelism": 8
},
"tasks": {
"qa": {
"type": "completion",
"params": {
"template": {
"prompt": "{{prompt}}",
"max_tokens": 200
}
},
"dataset": {
"files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
},
"metrics": {
"bleu": {
"type": "bleu",
"params": {
"references": [
"{{ideal_response}}"
]
}
}
}
}
}
}
}
```
### Example of triggering a benchmark/custom evaluation
```
POST /eval/benchmarks/{benchmark_id}/jobs
```
```json
{
"benchmark_id": "my-custom-benchmark",
"benchmark_config": {
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama3.1-8B-Instruct",
"sampling_params": {
"max_tokens": 100,
"temperature": 0.7
}
},
"scoring_params": {}
}
}
```
Response example:
```json
{
"job_id": "eval-1234",
"status": "in_progress"
}
```
### Example of getting the status of a job
```
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
```
Response example:
```json
{
"job_id": "eval-1234",
"status": "in_progress"
}
```
### Example of cancelling a job
```
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
```
### Example of getting the results
```
GET /eval/benchmarks/{benchmark_id}/results
```
```json
{
"generations": [],
"scores": {
"{benchmark_id}": {
"score_rows": [],
"aggregated_results": {
"tasks": {},
"groups": {}
}
}
}
}
```
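### End-to-end example (Python)

The following is a minimal sketch that drives the endpoints above with `requests`. The server base URL and route prefix are placeholders (deployments may mount these endpoints under a versioned prefix), and the model name is only an example.
```python
import time

import requests

BASE_URL = "http://localhost:8321"  # placeholder: your Llama Stack server

# 1. Register a benchmark; this creates an evaluation config in NeMo Evaluator.
requests.post(
    f"{BASE_URL}/eval/benchmarks",
    json={
        "benchmark_id": "mmlu",
        "dataset_id": "",
        "scoring_functions": [],
        "metadata": {"type": "mmlu"},
    },
).raise_for_status()

# 2. Trigger an evaluation job for the benchmark.
job = requests.post(
    f"{BASE_URL}/eval/benchmarks/mmlu/jobs",
    json={
        "benchmark_id": "mmlu",
        "benchmark_config": {
            "eval_candidate": {
                "type": "model",
                "model": "meta-llama/Llama3.1-8B-Instruct",
                "sampling_params": {"max_tokens": 100, "temperature": 0.7},
            },
            "scoring_params": {},
        },
    },
).json()

# 3. Poll the job until it leaves the in_progress state.
while True:
    status = requests.get(f"{BASE_URL}/eval/benchmarks/mmlu/jobs/{job['job_id']}").json()
    if status["status"] != "in_progress":
        break
    time.sleep(10)

# 4. Fetch the aggregated results.
results = requests.get(f"{BASE_URL}/eval/benchmarks/mmlu/results").json()
print(results["scores"])
```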

@@ -1,31 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import NVIDIAEvalConfig
async def get_adapter_impl(
config: NVIDIAEvalConfig,
deps: dict[Api, Any],
):
from .eval import NVIDIAEvalImpl
impl = NVIDIAEvalImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
deps[Api.scoring],
deps[Api.inference],
deps[Api.agents],
)
await impl.initialize()
return impl
__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]

@@ -1,29 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field
class NVIDIAEvalConfig(BaseModel):
"""
Configuration for the NVIDIA NeMo Evaluator microservice endpoint.
Attributes:
evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
"""
evaluator_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
description="The url for accessing the evaluator service",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}",
}

@@ -1,162 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import requests
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring, ScoringResult
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import NVIDIAEvalConfig
DEFAULT_NAMESPACE = "nvidia"
class NVIDIAEvalImpl(
Eval,
BenchmarksProtocolPrivate,
ModelRegistryHelper,
):
def __init__(
self,
config: NVIDIAEvalConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
scoring_api: Scoring,
inference_api: Inference,
agents_api: Agents,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_api = scoring_api
self.inference_api = inference_api
self.agents_api = agents_api
ModelRegistryHelper.__init__(self)
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def _evaluator_get(self, path: str):
"""Helper for making GET requests to the evaluator service."""
response = requests.get(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
return response.json()
async def _evaluator_post(self, path: str, data: dict[str, Any]):
"""Helper for making POST requests to the evaluator service."""
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
response.raise_for_status()
return response.json()
async def _evaluator_delete(self, path: str) -> None:
"""Helper for making DELETE requests to the evaluator service."""
response = requests.delete(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
async def register_benchmark(self, task_def: Benchmark) -> None:
"""Register a benchmark as an evaluation configuration."""
await self._evaluator_post(
"/v1/evaluation/configs",
{
"namespace": DEFAULT_NAMESPACE,
"name": task_def.benchmark_id,
# metadata is copied to request body as-is
**task_def.metadata,
},
)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark evaluation configuration from NeMo Evaluator."""
await self._evaluator_delete(f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}")
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation job for a benchmark."""
model = (
benchmark_config.eval_candidate.model
if benchmark_config.eval_candidate.type == "model"
else benchmark_config.eval_candidate.config.model
)
nvidia_model = self.get_provider_model_id(model) or model
result = await self._evaluator_post(
"/v1/evaluation/jobs",
{
"config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
"target": {"type": "model", "model": nvidia_model},
},
)
return Job(job_id=result["id"], status=JobStatus.in_progress)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
raise NotImplementedError()
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of an evaluation job.
EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
"""
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
result_status = result["status"]
job_status = JobStatus.failed
if result_status in ["created", "pending"]:
job_status = JobStatus.scheduled
elif result_status in ["running"]:
job_status = JobStatus.in_progress
elif result_status in ["completed"]:
job_status = JobStatus.completed
elif result_status in ["cancelled"]:
job_status = JobStatus.cancelled
return Job(job_id=job_id, status=job_status)
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel the evaluation job."""
await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Returns the results of the evaluation job."""
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
return EvaluateResponse(
# TODO: these are stored in detailed results on NeMo Evaluator side; can be added
generations=[],
scores={
benchmark_id: ScoringResult(
score_rows=[],
aggregated_results=result,
)
},
)

@@ -1,237 +0,0 @@
# S3 Files Provider
A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.
## Features
- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
- **Metadata Management**: Uses SQL database for efficient file metadata queries
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
- **Flexible Authentication**: Support for IAM roles and access keys
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
## Configuration
### Basic Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Advanced Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
aws_access_key_id: YOUR_ACCESS_KEY
aws_secret_access_key: YOUR_SECRET_KEY
endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Environment Variables
The configuration supports environment variable substitution:
```yaml
config:
bucket_name: "${env.S3_BUCKET_NAME}"
region: "${env.AWS_REGION:=us-east-1}"
aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
endpoint_url: "${env.S3_ENDPOINT_URL:=}"
```
Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.
## Authentication
### IAM Roles (Recommended)
For production deployments, use IAM roles:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
# No credentials needed - will use IAM role
```
### Access Keys
For development or specific use cases:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
aws_access_key_id: AKIAIOSFODNN7EXAMPLE
aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
```
## S3 Bucket Setup
### Required Permissions
The S3 provider requires the following permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Automatic Bucket Creation
By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
```yaml
config:
bucket_name: my-bucket
auto_create_bucket: true # Will create bucket if it doesn't exist
region: us-east-1
```
**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket",
"s3:CreateBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Bucket Policy (Optional)
For additional security, you can add a bucket policy:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "LlamaStackAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject"
],
"Resource": "arn:aws:s3:::your-bucket-name/*"
},
{
"Sid": "LlamaStackBucketAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:ListBucket"
],
"Resource": "arn:aws:s3:::your-bucket-name"
}
]
}
```
## Features
### Metadata Persistence
File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
- File ID
- Original filename
- Purpose (assistants, batch, etc.)
- File size in bytes
- Created and expiration timestamps
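For a quick sanity check during development, you can inspect the metadata table directly. The sketch below assumes the sqlite `metadata_store` path from the basic configuration above and the `openai_files` table created by the provider.
```python
import sqlite3

# Path from the basic configuration example; adjust to your metadata_store settings.
conn = sqlite3.connect("./s3_files_metadata.db")
rows = conn.execute(
    "SELECT id, filename, purpose, bytes, created_at, expires_at FROM openai_files"
).fetchall()
for row in rows:
    print(row)
conn.close()
```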
### TTL and Cleanup
Files currently have a fixed long expiration time (100 years).
## Development and Testing
### Using MinIO
For self-hosted S3-compatible storage:
```yaml
config:
bucket_name: test-bucket
region: us-east-1
endpoint_url: http://localhost:9000
aws_access_key_id: minioadmin
aws_secret_access_key: minioadmin
```
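### Exercising the Files API
Because the provider implements the OpenAI Files API, you can exercise it with the OpenAI Python SDK pointed at your Llama Stack server. This is a rough sketch; the base URL, route prefix, and API key handling depend on your deployment and are placeholders here.
```python
from openai import OpenAI

# Placeholder endpoint; point this at the OpenAI-compatible route of your Llama Stack server.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# Upload a file (content stored in S3, metadata recorded in the SQL store).
uploaded = client.files.create(file=open("report.pdf", "rb"), purpose="assistants")
print(uploaded.id, uploaded.bytes)

# List files and print their metadata.
for f in client.files.list():
    print(f.id, f.filename, f.purpose)

# Download the content back from S3.
content = client.files.content(uploaded.id).read()

# Delete the file from S3 and the metadata store.
client.files.delete(uploaded.id)
```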
## Monitoring and Logging
The provider logs important operations and errors. For production deployments, consider:
- CloudWatch monitoring for S3 operations
- Custom metrics for file upload/download rates
- Error rate monitoring
- Performance metrics tracking
## Error Handling
The provider handles various error scenarios:
- S3 connectivity issues
- Bucket access permissions
- File not found errors
- Metadata consistency checks
## Known Limitations
- Fixed long TTL (100 years) instead of configurable expiration
- No server-side encryption enabled by default
- No support for AWS session tokens
- No S3 key prefix organization support
- No multipart upload support (all files uploaded as single objects)

@@ -1,19 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import AccessRule, Api
from .config import S3FilesImplConfig
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
from .files import S3FilesImpl
impl = S3FilesImpl(config, policy or [])
await impl.initialize()
return impl

@@ -1,42 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.core.storage.datatypes import SqlStoreReference
class S3FilesImplConfig(BaseModel):
"""Configuration for S3-based files provider."""
bucket_name: str = Field(description="S3 bucket name to store files")
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
aws_secret_access_key: str | None = Field(
default=None, description="AWS secret access key (optional if using IAM roles)"
)
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
auto_create_bucket: bool = Field(
default=False, description="Automatically create the S3 bucket if it doesn't exist"
)
metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique
"region": "${env.AWS_REGION:=us-east-1}",
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
"metadata_store": SqlStoreReference(
backend="sql_default",
table_name="s3_files_metadata",
).model_dump(exclude_none=True),
}

@@ -1,313 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from datetime import UTC, datetime
from typing import Annotated, Any
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
ExpiresAfter,
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import S3FilesImplConfig
# TODO: provider data for S3 credentials
def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
try:
s3_config = {
"region_name": config.region,
}
# endpoint URL if specified (for MinIO, LocalStack, etc.)
if config.endpoint_url:
s3_config["endpoint_url"] = config.endpoint_url
if config.aws_access_key_id and config.aws_secret_access_key:
s3_config.update(
{
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
}
)
return boto3.client("s3", **s3_config)
except (BotoCoreError, NoCredentialsError) as e:
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
try:
client.head_bucket(Bucket=config.bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code == "404":
if not config.auto_create_bucket:
raise RuntimeError(
f"S3 bucket '{config.bucket_name}' does not exist. "
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
) from e
try:
# For us-east-1, we can't specify LocationConstraint
if config.region == "us-east-1":
client.create_bucket(Bucket=config.bucket_name)
else:
client.create_bucket(
Bucket=config.bucket_name,
CreateBucketConfiguration={"LocationConstraint": config.region},
)
except ClientError as create_error:
raise RuntimeError(
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
) from create_error
elif error_code == "403":
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
else:
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
def _make_file_object(
*,
id: str,
filename: str,
purpose: str,
bytes: int,
created_at: int,
expires_at: int,
**kwargs: Any, # here to ignore any additional fields, e.g. extra fields from AuthorizedSqlStore
) -> OpenAIFileObject:
"""
Construct an OpenAIFileObject and normalize expires_at.
If expires_at is greater than the max we treat it as no-expiration and
return None for expires_at.
The OpenAI spec says expires_at type is Integer, but the implementation
will return None for no expiration.
"""
obj = OpenAIFileObject(
id=id,
filename=filename,
purpose=OpenAIFilePurpose(purpose),
bytes=bytes,
created_at=created_at,
expires_at=expires_at,
)
if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX):
obj.expires_at = None # type: ignore
return obj
class S3FilesImpl(Files):
"""S3-based implementation of the Files API."""
def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
self._config = config
self.policy = policy
self._client: boto3.client | None = None
self._sql_store: AuthorizedSqlStore | None = None
def _now(self) -> int:
"""Return current UTC timestamp as int seconds."""
return int(datetime.now(UTC).timestamp())
async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]:
where: dict[str, str | dict] = {"id": file_id}
if not return_expired:
where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()")
return row
async def _delete_file(self, file_id: str) -> None:
"""Delete a file from S3 and the database."""
try:
self.client.delete_object(
Bucket=self._config.bucket_name,
Key=file_id,
)
except ClientError as e:
if e.response["Error"]["Code"] != "NoSuchKey":
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
await self.sql_store.delete("openai_files", where={"id": file_id})
async def _delete_if_expired(self, file_id: str) -> None:
"""If the file exists and is expired, delete it."""
if row := await self._get_file(file_id, return_expired=True):
if (expires_at := row.get("expires_at")) and expires_at <= self._now():
await self._delete_file(file_id)
async def initialize(self) -> None:
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table(
"openai_files",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"filename": ColumnType.STRING,
"purpose": ColumnType.STRING,
"bytes": ColumnType.INTEGER,
"created_at": ColumnType.INTEGER,
"expires_at": ColumnType.INTEGER,
# TODO: add s3_etag field for integrity checking
},
)
async def shutdown(self) -> None:
pass
@property
def client(self) -> boto3.client:
assert self._client is not None, "Provider not initialized"
return self._client
@property
def sql_store(self) -> AuthorizedSqlStore:
assert self._sql_store is not None, "Provider not initialized"
return self._sql_store
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
filename = getattr(file, "filename", None) or "uploaded_file"
created_at = self._now()
# the default is no expiration.
# to implement no expiration we set an expiration beyond the max.
# we'll hide this fact from users when returning the file object.
expires_at = created_at + ExpiresAfter.MAX * 42
# the default for BATCH files is 30 days, which happens to be the expiration max.
if purpose == OpenAIFilePurpose.BATCH:
expires_at = created_at + ExpiresAfter.MAX
if expires_after is not None:
expires_at = created_at + expires_after.seconds
content = await file.read()
file_size = len(content)
entry: dict[str, Any] = {
"id": file_id,
"filename": filename,
"purpose": purpose.value,
"bytes": file_size,
"created_at": created_at,
"expires_at": expires_at,
}
await self.sql_store.insert("openai_files", entry)
try:
self.client.put_object(
Bucket=self._config.bucket_name,
Key=file_id,
Body=content,
# TODO: enable server-side encryption
)
except ClientError as e:
await self.sql_store.delete("openai_files", where={"id": file_id})
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
return _make_file_object(**entry)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
        # this is purely defensive; it should not happen because the router also defaults to Order.desc
if not order:
order = Order.desc
where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}}
if purpose:
where_conditions["purpose"] = purpose.value
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
where=where_conditions,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
files = [_make_file_object(**row) for row in paginated_result.data]
return ListOpenAIFileResponse(
data=files,
has_more=paginated_result.has_more,
# empty string or None? spec says str, ref impl returns str | None, we go with spec
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
await self._delete_if_expired(file_id)
row = await self._get_file(file_id)
return _make_file_object(**row)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
await self._delete_if_expired(file_id)
_ = await self._get_file(file_id) # raises if not found
await self._delete_file(file_id)
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
async def openai_retrieve_file_content(self, file_id: str) -> Response:
await self._delete_if_expired(file_id)
row = await self._get_file(file_id)
try:
response = self.client.get_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
# TODO: can we stream this instead of loading it into memory
content = response["Body"].read()
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
await self._delete_file(file_id)
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
raise RuntimeError(f"Failed to download file from S3: {e}") from e
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
)

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import AnthropicConfig
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter
impl = AnthropicInferenceAdapter(config=config)
await impl.initialize()
return impl

@@ -1,36 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from anthropic import AsyncAnthropic
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AnthropicConfig
class AnthropicInferenceAdapter(OpenAIMixin):
config: AnthropicConfig
provider_data_api_key_field: str = "anthropic_api_key"
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
# TODO: add support for voyageai, which is where these models are hosted
# embedding_model_metadata = {
# "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000},
# "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000},
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def get_base_url(self):
return "https://api.anthropic.com/v1"
async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]

@@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: str | None = Field(
default=None,
description="API key for Anthropic models",
)
@json_schema_type
class AnthropicConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"api_key": api_key,
}

@@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import AzureConfig
async def get_adapter_impl(config: AzureConfig, _deps):
from .azure import AzureInferenceAdapter
impl = AzureInferenceAdapter(config=config)
await impl.initialize()
return impl

@@ -1,25 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
class AzureInferenceAdapter(OpenAIMixin):
config: AzureConfig
provider_data_api_key_field: str = "azure_api_key"
def get_base_url(self) -> str:
"""
Get the Azure API base URL.
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")

@@ -1,61 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class AzureProviderDataValidator(BaseModel):
azure_api_key: SecretStr = Field(
description="Azure API key for Azure",
)
azure_api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)
azure_api_version: str | None = Field(
default=None,
description="Azure API version for Azure (e.g., 2024-06-01)",
)
azure_api_type: str | None = Field(
default="azure",
description="Azure API type for Azure (e.g., azure)",
)
@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)
api_version: str | None = Field(
default_factory=lambda: os.getenv("AZURE_API_VERSION"),
description="Azure API version for Azure (e.g., 2024-12-01-preview)",
)
api_type: str | None = Field(
default_factory=lambda: os.getenv("AZURE_API_TYPE", "azure"),
description="Azure API type for Azure (e.g., azure)",
)
@classmethod
def sample_run_config(
cls,
api_key: str = "${env.AZURE_API_KEY:=}",
api_base: str = "${env.AZURE_API_BASE:=}",
api_version: str = "${env.AZURE_API_VERSION:=}",
api_type: str = "${env.AZURE_API_TYPE:=}",
**kwargs,
) -> dict[str, Any]:
return {
"api_key": api_key,
"api_base": api_base,
"api_version": api_version,
"api_type": api_type,
}

@@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import BedrockConfig
async def get_adapter_impl(config: BedrockConfig, _deps):
from .bedrock import BedrockInferenceAdapter
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
impl = BedrockInferenceAdapter(config)
await impl.initialize()
return impl

@@ -1,142 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncIterator
from botocore.client import BaseClient
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .models import MODEL_ENTRIES
REGION_PREFIX_MAP = {
"us": "us.",
"eu": "eu.",
"ap": "ap.",
}
def _get_region_prefix(region: str | None) -> str:
# AWS requires region prefixes for inference profiles
if region is None:
return "us." # default to US when we don't know
# Handle case insensitive region matching
region_lower = region.lower()
for prefix in REGION_PREFIX_MAP:
if region_lower.startswith(f"{prefix}-"):
return REGION_PREFIX_MAP[prefix]
# Fallback to US for anything we don't recognize
return "us."
def _to_inference_profile_id(model_id: str, region: str | None = None) -> str:
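    """Convert a Bedrock foundation model ID into a region-prefixed inference profile ID.

    For example, "meta.llama3-1-8b-instruct-v1:0" with region "eu-west-1" becomes
    "eu.meta.llama3-1-8b-instruct-v1:0". ARNs and already-prefixed IDs are returned unchanged.
    """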
# Return ARNs unchanged
if model_id.startswith("arn:"):
return model_id
# Return inference profile IDs that already have regional prefixes
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
return model_id
# Default to US East when no region is provided
if region is None:
region = "us-east-1"
return _get_region_prefix(region) + model_id
class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self._config = config
self._client = None
@property
def client(self) -> BaseClient:
if self._client is None:
self._client = create_bedrock_client(self._config)
return self._client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
if self._client is not None:
self._client.close()
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
bedrock_model = request.model
sampling_params = request.sampling_params
options = get_sampling_strategy_options(sampling_params)
if sampling_params.max_tokens:
options["max_gen_len"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
options["repetition_penalty"] = sampling_params.repetition_penalty
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
# Convert foundation model ID to inference profile ID
region_name = self.client.meta.region_name
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
return {
"modelId": inference_profile_id,
"body": json.dumps(
{
"prompt": prompt,
**options,
}
),
}
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

@@ -1,11 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
class BedrockConfig(BedrockBaseConfig):
pass

@@ -1,29 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta.llama3-1-8b-instruct-v1:0",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta.llama3-1-70b-instruct-v1:0",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta.llama3-1-405b-instruct-v1:0",
CoreModelId.llama3_1_405b_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

@@ -1,19 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import CerebrasImplConfig
async def get_adapter_impl(config: CerebrasImplConfig, _deps):
from .cerebras import CerebrasInferenceAdapter
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config=config)
await impl.initialize()
return impl

@@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.apis.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import CerebrasImplConfig
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
def get_base_url(self) -> str:
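        # e.g. urljoin("https://api.cerebras.ai", "v1") -> "https://api.cerebras.ai/v1"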
return urljoin(self.config.base_url, "v1")
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

@@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai"
@json_schema_type
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"base_url": DEFAULT_BASE_URL,
"api_key": api_key,
}

@@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import DatabricksImplConfig
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config=config)
await impl.initialize()
return impl

@@ -1,37 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
)
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The Databricks API token",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_HOST:=}",
api_token: str = "${env.DATABRICKS_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {
"url": url,
"api_token": api_token,
}

@@ -1,44 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import OpenAICompletion, OpenAICompletionRequestWithExtraBody
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import DatabricksImplConfig
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError()

@@ -1,22 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import FireworksImplConfig
class FireworksProviderDataValidator(BaseModel):
fireworks_api_key: str
async def get_adapter_impl(config: FireworksImplConfig, _deps):
from .fireworks import FireworksInferenceAdapter
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
impl = FireworksInferenceAdapter(config=config)
await impl.initialize()
return impl

@@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,
}

@@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import FireworksImplConfig
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(OpenAIMixin):
config: FireworksImplConfig
embedding_model_metadata: dict[str, dict[str, int]] = {
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
}
provider_data_api_key_field: str = "fireworks_api_key"
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import GeminiConfig
async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter
impl = GeminiInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class GeminiProviderDataValidator(BaseModel):
gemini_api_key: str | None = Field(
default=None,
description="API key for Gemini models",
)
@json_schema_type
class GeminiConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -1,82 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from openai import NOT_GIVEN
from llama_stack.apis.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
class GeminiInferenceAdapter(OpenAIMixin):
config: GeminiConfig
provider_data_api_key_field: str = "gemini_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
"models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
}
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Override embeddings method to handle Gemini's missing usage statistics.
Gemini's embedding API doesn't return usage information, so we provide default values.
"""
# Prepare request parameters
request_params = {
"model": await self._get_provider_model_id(params.model),
"input": params.input,
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
"user": params.user if params.user is not None else NOT_GIVEN,
}
# Add extra_body if present
extra_body = params.model_extra
if extra_body:
request_params["extra_body"] = extra_body
# Call OpenAI embeddings API with properly typed parameters
response = await self.client.embeddings.create(**request_params)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
# Gemini doesn't return usage statistics - use default values
if hasattr(response, "usage") and response.usage:
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
else:
usage = OpenAIEmbeddingUsage(
prompt_tokens=0,
total_tokens=0,
)
return OpenAIEmbeddingsResponse(
data=data,
model=params.model,
usage=usage,
)

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import GroqConfig
async def get_adapter_impl(config: GroqConfig, _deps):
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter
adapter = GroqInferenceAdapter(config=config)
return adapter

View file

@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class GroqProviderDataValidator(BaseModel):
groq_api_key: str | None = Field(
default=None,
description="API key for Groq models",
)
@json_schema_type
class GroqConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.groq.com",
description="The URL for the Groq AI server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.groq.com",
"api_key": api_key,
}

View file

@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
class GroqInferenceAdapter(OpenAIMixin):
config: GroqConfig
provider_data_api_key_field: str = "groq_api_key"
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps):
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter
adapter = LlamaCompatInferenceAdapter(config=config)
return adapter

View file

@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class LlamaProviderDataValidator(BaseModel):
llama_api_key: str | None = Field(
default=None,
description="API key for api.llama models",
)
@json_schema_type
class LlamaCompatConfig(RemoteInferenceProviderConfig):
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
description="The URL for the Llama API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
"api_key": api_key,
}

View file

@ -1,46 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference.inference import (
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin):
config: LlamaCompatConfig
provider_data_api_key_field: str = "llama_api_key"
"""
Llama API Inference Adapter for Llama Stack.
"""
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.
:return: The Llama API base URL
"""
return self.config.openai_compat_api_base
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError()
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -1,183 +0,0 @@
# NVIDIA Inference Provider for LlamaStack
This provider enables running inference using NVIDIA NIM.
## Features
- Endpoints for completions, chat completions, and embeddings for registered models
## Getting Started
### Prerequisites
- LlamaStack with NVIDIA configuration
- Access to an NVIDIA NIM deployment
- A NIM deployed for the model you want to use for inference
### Setup
Build the NVIDIA environment:
```bash
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage using the LlamaStack Python Client
#### Initialize the client
```python
import os
os.environ["NVIDIA_API_KEY"] = (
"" # Required if using hosted NIM endpoint. If self-hosted, not required.
)
os.environ["NVIDIA_BASE_URL"] = "http://nim.test" # NIM URL
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
### Create Chat Completion
The following example shows how to create a chat completion for an NVIDIA NIM.
```python
response = client.chat.completions.create(
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[
{
"role": "system",
"content": "You must respond to each message with only one word",
},
{
"role": "user",
"content": "Complete the sentence using one word: Roses are red, violets are:",
},
],
stream=False,
max_tokens=50,
)
print(f"Response: {response.choices[0].message.content}")
```
### Tool Calling Example
The following example shows how to do tool calling for an NVIDIA NIM.
```python
tool_definition = {
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather information for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"description": "Temperature unit (celsius or fahrenheit)",
"default": "celsius",
},
},
"required": ["location"],
},
},
}
tool_response = client.chat.completions.create(
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Response content: {tool_response.choices[0].message.content}")
if tool_response.choices[0].message.tool_calls:
for tool_call in tool_response.choices[0].message.tool_calls:
print(f"Tool Called: {tool_call.function.name}")
print(f"Arguments: {tool_call.function.arguments}")
```
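When the model does request a tool call, the result can be sent back in a follow-up request using the OpenAI-compatible tool message format. The sketch below is illustrative: `run_get_weather` is a stand-in implementation, and it assumes the assistant message object returned by the client can be appended to `messages` as in the OpenAI SDK.
```python
# Hypothetical follow-up: execute the requested tool locally and return its result
# so the model can produce a final answer. Reuses `client`, `tool_definition` and
# `tool_response` from the examples above.
import json
def run_get_weather(location: str, unit: str = "celsius") -> str:
    # Stand-in for a real weather lookup.
    return json.dumps({"location": location, "temperature": 18, "unit": unit})
messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}]
if tool_response.choices[0].message.tool_calls:
    # Assistant turn that contains the tool_calls (assumed accepted as-is by the client).
    messages.append(tool_response.choices[0].message)
    for tool_call in tool_response.choices[0].message.tool_calls:
        args = json.loads(tool_call.function.arguments)
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": run_get_weather(**args),
            }
        )
    final = client.chat.completions.create(
        model="nvidia/meta/llama-3.1-8b-instruct",
        messages=messages,
        tools=[tool_definition],
    )
    print(f"Final Response: {final.choices[0].message.content}")
```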
### Structured Output Example
The following example shows how to do structured output for an NVIDIA NIM.
```python
person_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "number"},
"occupation": {"type": "string"},
},
"required": ["name", "age", "occupation"],
}
structured_response = client.chat.completions.create(
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[
{
"role": "user",
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
}
],
extra_body={"nvext": {"guided_json": person_schema}},
)
print(f"Structured Response: {structured_response.choices[0].message.content}")
```
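Because generation is constrained by `guided_json`, the returned content can normally be parsed directly into a Python object (a minimal sketch, assuming the model emitted valid JSON):
```python
import json
profile = json.loads(structured_response.choices[0].message.content)
print(profile["name"], profile["age"], profile["occupation"])
```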
### Create Embeddings
The following example shows how to create embeddings for an NVIDIA NIM.
```python
response = client.embeddings.create(
model="nvidia/nvidia/llama-3.2-nv-embedqa-1b-v2",
input=["What is the capital of France?"],
extra_body={"input_type": "query"},
)
print(f"Embeddings: {response.data}")
```
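For retrieval use cases, a query embedding is typically compared against passage embeddings with cosine similarity. The sketch below assumes the endpoint also accepts `input_type: "passage"` (the counterpart of the `"query"` value shown above); the passage text is illustrative.
```python
import math
# Reuses `response` from the query embedding example above.
query_vec = response.data[0].embedding
passage_resp = client.embeddings.create(
    model="nvidia/nvidia/llama-3.2-nv-embedqa-1b-v2",
    input=["Paris is the capital and largest city of France."],
    extra_body={"input_type": "passage"},  # assumed counterpart to "query"
)
passage_vec = passage_resp.data[0].embedding
dot = sum(q * p for q, p in zip(query_vec, passage_vec))
norm = math.sqrt(sum(q * q for q in query_vec)) * math.sqrt(sum(p * p for p in passage_vec))
print(f"Cosine similarity: {dot / norm:.3f}")
```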
### Vision Language Models Example
The following example shows how to run vision inference by using an NVIDIA NIM.
```python
import base64
def load_image_as_base64(image_path):
with open(image_path, "rb") as image_file:
img_bytes = image_file.read()
return base64.b64encode(img_bytes).decode("utf-8")
image_path = {path_to_the_image}
demo_image_b64 = load_image_as_base64(image_path)
vlm_response = client.chat.completions.create(
model="nvidia/meta/llama-3.2-11b-vision-instruct",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{demo_image_b64}",
},
},
{
"type": "text",
"text": "Please describe what you see in this image in detail.",
},
],
}
],
)
print(f"VLM Response: {vlm_response.choices[0].message.content}")
```

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import NVIDIAConfig
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
# import dynamically so `llama stack list-deps` does not fail due to missing dependencies
from .nvidia import NVIDIAInferenceAdapter
if not isinstance(config, NVIDIAConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = NVIDIAInferenceAdapter(config=config)
await adapter.initialize()
return adapter
__all__ = ["get_adapter_impl", "NVIDIAConfig"]

View file

@ -1,64 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIAConfig(RemoteInferenceProviderConfig):
"""
Configuration for the NVIDIA NIM inference endpoint.
Attributes:
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
There are two ways to access NVIDIA NIMs -
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
By default the configuration is set to use the hosted APIs. This requires
an API key which can be obtained from https://ngc.nvidia.com/.
By default the configuration will attempt to read the NVIDIA_API_KEY environment
variable to set the api_key. Please do not put your API key in code.
If you are using a self-hosted NVIDIA NIM, you can set the url to the
URL of your running NVIDIA NIM and do not need to set the api_key.
"""
url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
description="A base url for accessing the NVIDIA NIM",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
api_key: str = "${env.NVIDIA_API_KEY:=}",
append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
**kwargs,
) -> dict[str, Any]:
return {
"url": url,
"api_key": api_key,
"append_api_version": append_api_version,
}
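A minimal sketch of driving this configuration purely through environment variables for a self-hosted NIM (values are illustrative, and it assumes the inherited config fields also have usable defaults):
```python
# Illustrative only: a self-hosted NIM needs just a base URL, while the hosted
# default (https://integrate.api.nvidia.com) additionally expects NVIDIA_API_KEY.
import os
os.environ["NVIDIA_BASE_URL"] = "http://localhost:8000"  # self-hosted NIM, no key needed
config = NVIDIAConfig()
print(config.url)                 # http://localhost:8000
print(config.append_api_version)  # True -> requests go to http://localhost:8000/v1
```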

View file

@ -1,61 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from . import NVIDIAConfig
from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin):
config: NVIDIAConfig
"""
NVIDIA Inference Adapter for Llama Stack.
"""
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
embedding_model_metadata: dict[str, dict[str, int]] = {
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
}
async def initialize(self) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
if _is_nvidia_hosted(self.config):
if not self.config.auth_credential:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
)
def get_api_key(self) -> str | None:
"""
Get the API key for OpenAI mixin.
:return: The NVIDIA API key
"""
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
if not _is_nvidia_hosted(self.config):
return "NO KEY REQUIRED"
return None
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.
:return: The NVIDIA API base URL
"""
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url

View file

@ -1,11 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from . import NVIDIAConfig
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import OllamaImplConfig
async def get_adapter_impl(config: OllamaImplConfig, _deps):
from .ollama import OllamaInferenceAdapter
impl = OllamaInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -1,25 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = DEFAULT_OLLAMA_URL
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
return {
"url": url,
}

View file

@ -1,102 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from ollama import AsyncClient as AsyncOllamaClient
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(OpenAIMixin):
config: OllamaImplConfig
# automatically set by the resolver when instantiating the provider
__provider_id__: str
embedding_model_metadata: dict[str, dict[str, int]] = {
"all-minilm:l6-v2": {
"embedding_dimension": 384,
"context_length": 512,
},
"nomic-embed-text:latest": {
"embedding_dimension": 768,
"context_length": 8192,
},
"nomic-embed-text:v1.5": {
"embedding_dimension": 768,
"context_length": 8192,
},
"nomic-embed-text:137m-v1.5-fp16": {
"embedding_dimension": 768,
"context_length": 8192,
},
}
download_images: bool = True
_clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property
def ollama_client(self) -> AsyncOllamaClient:
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
return self._clients[loop]
def get_api_key(self):
return "NO KEY REQUIRED"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
r = await self.health()
if r["status"] == HealthStatus.ERROR:
logger.warning(
f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
)
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Ollama server.
This method is used by initialize() and the Provider API to verify that the service is running
correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
await self.ollama_client.ps()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def shutdown(self) -> None:
self._clients.clear()
async def register_model(self, model: Model) -> Model:
if await self.check_model_availability(model.provider_model_id):
return model
elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
model.provider_resource_id = f"{model.provider_model_id}:latest"
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
)
return model
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import OpenAIConfig
async def get_adapter_impl(config: OpenAIConfig, _deps):
from .openai import OpenAIInferenceAdapter
impl = OpenAIInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -1,39 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class OpenAIProviderDataValidator(BaseModel):
openai_api_key: str | None = Field(
default=None,
description="API key for OpenAI models",
)
@json_schema_type
class OpenAIConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for OpenAI API",
)
@classmethod
def sample_run_config(
cls,
api_key: str = "${env.OPENAI_API_KEY:=}",
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
**kwargs,
) -> dict[str, Any]:
return {
"api_key": api_key,
"base_url": base_url,
}

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
logger = get_logger(name=__name__, category="inference::openai")
#
# This OpenAI adapter implements Inference methods using OpenAIMixin
#
class OpenAIInferenceAdapter(OpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
"""
config: OpenAIConfig
provider_data_api_key_field: str = "openai_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
}
def get_base_url(self) -> str:
"""
Get the OpenAI API base URL.
Returns the OpenAI API base URL from the configuration.
"""
return self.config.base_url

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import PassthroughImplConfig
class PassthroughProviderDataValidator(BaseModel):
url: str
api_key: str
async def get_adapter_impl(config: PassthroughImplConfig, _deps):
from .passthrough import PassthroughInferenceAdapter
assert isinstance(config, PassthroughImplConfig), f"Unexpected config type: {type(config)}"
impl = PassthroughInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class PassthroughImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the passthrough endpoint",
)
api_key: SecretStr | None = Field(
default=None,
description="API Key for the passthrouth endpoint",
)
@classmethod
def sample_run_config(
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
) -> dict[str, Any]:
return {
"url": url,
"api_key": api_key,
}

View file

@ -1,122 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import Model
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import PassthroughImplConfig
class PassthroughInferenceAdapter(Inference):
def __init__(self, config: PassthroughImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: Model) -> Model:
return model
def _get_client(self) -> AsyncLlamaStackClient:
passthrough_url = None
passthrough_api_key = None
provider_data = None
if self.config.url is not None:
passthrough_url = self.config.url
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.passthrough_url:
raise ValueError(
'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
)
passthrough_url = provider_data.passthrough_url
if self.config.api_key is not None:
passthrough_api_key = self.config.api_key.get_secret_value()
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.passthrough_api_key:
raise ValueError(
'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
)
passthrough_api_key = provider_data.passthrough_api_key
return AsyncLlamaStackClient(
base_url=passthrough_url,
api_key=passthrough_api_key,
provider_data=provider_data,
)
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(params.model)
params = params.model_copy()
params.model = model_obj.provider_resource_id
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_completion(**request_params)
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
client = self._get_client()
model_obj = await self.model_store.get_model(params.model)
params = params.model_copy()
params.model = model_obj.provider_resource_id
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_chat_completion(**request_params)
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
json_params = {}
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)
if isinstance(json_input, dict):
json_input = {k: v for k, v in json_input.items() if v is not None}
elif isinstance(json_input, list):
json_input = [x for x in json_input if x is not None]
new_input = []
for x in json_input:
if isinstance(x, dict):
x = {k: v for k, v in x.items() if v is not None}
new_input.append(x)
json_input = new_input
json_params[key] = json_input
return json_params
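When `url` and `api_key` are not set in the provider config, the adapter expects them per request via the `X-LlamaStack-Provider-Data` header, as the error messages above describe. A minimal sketch of how a caller might build that header (the URL and key are placeholders):
```python
# Illustrative only: construct the per-request header parsed in _get_client().
import json
provider_data = {
    "passthrough_url": "http://localhost:8321",  # placeholder downstream endpoint
    "passthrough_api_key": "example-key",        # placeholder credential
}
headers = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
# Attach `headers` to any inference request sent to the Llama Stack server so the
# passthrough provider knows where to forward it and how to authenticate.
```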

View file

@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import RunpodImplConfig
async def get_adapter_impl(config: RunpodImplConfig, _deps):
from .runpod import RunpodInferenceAdapter
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
impl = RunpodInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -1,32 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the Runpod model serving endpoint",
)
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"url": "${env.RUNPOD_URL:=}",
"api_token": "${env.RUNPOD_API_TOKEN}",
}

View file

@ -1,42 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import RunpodImplConfig
class RunpodInferenceAdapter(OpenAIMixin):
"""
Adapter for RunPod's OpenAI-compatible API endpoints.
Supports vLLM serverless endpoints, whether self-hosted or public.
Works with any RunPod endpoint that exposes an OpenAI-compatible API.
"""
config: RunpodImplConfig
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return self.config.url
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Override to add RunPod-specific stream_options requirement."""
params = params.model_copy()
if params.stream and not params.stream_options:
params.stream_options = {"include_usage": True}
return await super().openai_chat_completion(params)

View file

@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SambaNovaImplConfig
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
impl = SambaNovaInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str | None = Field(
default=None,
description="Sambanova Cloud API key",
)
@json_schema_type
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": api_key,
}

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import SambaNovaImplConfig
class SambaNovaInferenceAdapter(OpenAIMixin):
config: SambaNovaImplConfig
provider_data_api_key_field: str = "sambanova_api_key"
download_images: bool = True  # SambaNova does not support image downloads server-side, perform them on the client
"""
SambaNova Inference Adapter for Llama Stack.
"""
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.
:return: The SambaNova base URL
"""
return self.config.url

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
async def get_adapter_impl(
config: InferenceAPIImplConfig | InferenceEndpointImplConfig | TGIImplConfig,
_deps,
):
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
if isinstance(config, TGIImplConfig):
impl = TGIAdapter()
elif isinstance(config, InferenceAPIImplConfig):
impl = InferenceAPIAdapter()
elif isinstance(config, InferenceEndpointImplConfig):
impl = InferenceEndpointAdapter()
else:
raise ValueError(
f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
)
await impl.initialize(config)
return impl

View file

@ -1,76 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TGIImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = Field(
description="The URL for the TGI serving endpoint",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.TGI_URL:=}",
**kwargs,
):
return {
"url": url,
}
@json_schema_type
class InferenceEndpointImplConfig(BaseModel):
endpoint_name: str = Field(
description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
)
api_token: SecretStr | None = Field(
default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)
@classmethod
def sample_run_config(
cls,
endpoint_name: str = "${env.INFERENCE_ENDPOINT_NAME}",
api_token: str = "${env.HF_API_TOKEN}",
**kwargs,
):
return {
"endpoint_name": endpoint_name,
"api_token": api_token,
}
@json_schema_type
class InferenceAPIImplConfig(BaseModel):
huggingface_repo: str = Field(
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
)
api_token: SecretStr | None = Field(
default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)
@classmethod
def sample_run_config(
cls,
repo: str = "${env.INFERENCE_MODEL}",
api_token: str = "${env.HF_API_TOKEN}",
**kwargs,
):
return {
"huggingface_repo": repo,
"api_token": api_token,
}

View file

@ -1,85 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = get_logger(name=__name__, category="inference::tgi")
class _HfAdapter(OpenAIMixin):
url: str
api_key: SecretStr
hf_client: AsyncInferenceClient
max_tokens: int
model_id: str
overwrite_completion_id = True # TGI always returns id=""
def get_api_key(self):
return "NO KEY REQUIRED"
def get_base_url(self):
return self.url
async def list_provider_model_ids(self) -> Iterable[str]:
return [self.model_id]
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1"
self.api_key = SecretStr("NO_KEY")
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value() if config.api_token else None)
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
# TODO: how do we set url for this?
class InferenceEndpointAdapter(_HfAdapter):
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
# Get the inference endpoint details
api = HfApi(token=config.api_token.get_secret_value() if config.api_token else None)
endpoint = api.get_inference_endpoint(config.endpoint_name)
# Wait for the endpoint to be ready (if not already)
endpoint.wait(timeout=60)
# Initialize the adapter
self.hf_client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
# TODO: how do we set url for this?

View file

@ -1,22 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import TogetherImplConfig
class TogetherProviderDataValidator(BaseModel):
together_api_key: str
async def get_adapter_impl(config: TogetherImplConfig, _deps):
from .together import TogetherInferenceAdapter
assert isinstance(config, TogetherImplConfig), f"Unexpected config type: {type(config)}"
impl = TogetherInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TogetherImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY:=}",
}

View file

@ -1,102 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from together import AsyncTogether
from together.constants import BASE_URL
from llama_stack.apis.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
from llama_stack.apis.models import Model
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import TogetherImplConfig
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
config: TogetherImplConfig
embedding_model_metadata: dict[str, dict[str, int]] = {
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
"Alibaba-NLP/gte-modernbert-base": {"embedding_dimension": 768, "context_length": 8192},
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
}
_model_cache: dict[str, Model] = {}
provider_data_api_key_field: str = "together_api_key"
def get_base_url(self):
return BASE_URL
def _get_client(self) -> AsyncTogether:
together_api_key = None
config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
if config_api_key:
together_api_key = config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
return AsyncTogether(api_key=together_api_key)
async def list_provider_model_ids(self) -> Iterable[str]:
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
return [m.id for m in await self._get_client().models.list()]
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Together's OpenAI-compatible embeddings endpoint is not compatible with
the standard OpenAI embeddings endpoint.
The endpoint -
- not all models return usage information
- does not support user param, returns 400 Unrecognized request arguments supplied: user
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
"""
# Together support ticket #13332 -> will not fix
if params.user is not None:
raise ValueError("Together's embeddings endpoint does not support user param.")
# Together support ticket #13333 -> escalated
if params.dimensions is not None:
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(params.model),
input=params.input,
encoding_format=params.encoding_format,
)
response.model = (
params.model
) # return the user the same model id they provided, avoid exposing the provider model id
# Together support ticket #13330 -> escalated
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
if not hasattr(response, "usage") or response.usage is None:
logger.warning(
f"Together's embedding endpoint for {params.model} did not return usage information, substituting -1s."
)
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return response # type: ignore[no-any-return]

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import VertexAIConfig
async def get_adapter_impl(config: VertexAIConfig, _deps):
from .vertexai import VertexAIInferenceAdapter
impl = VertexAIInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -1,48 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class VertexAIProviderDataValidator(BaseModel):
vertex_project: str | None = Field(
default=None,
description="Google Cloud project ID for Vertex AI",
)
vertex_location: str | None = Field(
default=None,
description="Google Cloud location for Vertex AI (e.g., us-central1)",
)
@json_schema_type
class VertexAIConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
project: str = Field(
description="Google Cloud project ID for Vertex AI",
)
location: str = Field(
default="us-central1",
description="Google Cloud location for Vertex AI",
)
@classmethod
def sample_run_config(
cls,
project: str = "${env.VERTEX_AI_PROJECT:=}",
location: str = "${env.VERTEX_AI_LOCATION:=us-central1}",
**kwargs,
) -> dict[str, Any]:
return {
"project": project,
"location": location,
}

View file

@ -1,44 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import google.auth.transport.requests
from google.auth import default
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import VertexAIConfig
class VertexAIInferenceAdapter(OpenAIMixin):
config: VertexAIConfig
provider_data_api_key_field: str = "vertex_project"
def get_api_key(self) -> str:
"""
Get an access token for Vertex AI using Application Default Credentials.
Vertex AI uses ADC instead of API keys. This method obtains an access token
from the default credentials and returns it for use with the OpenAI-compatible client.
"""
try:
# Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(google.auth.transport.requests.Request())
return str(credentials.token)
except Exception:
# If we can't get credentials, return empty string to let the env work with ADC directly
return ""
def get_base_url(self) -> str:
"""
Get the Vertex AI OpenAI-compatible API base URL.
Returns the Vertex AI OpenAI-compatible endpoint URL.
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"

View file

@ -1,22 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import VLLMInferenceAdapterConfig
class VLLMProviderDataValidator(BaseModel):
vllm_api_token: str | None = None
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter
assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
impl = VLLMInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -1,59 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from pydantic import Field, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the vLLM model serving endpoint",
)
max_tokens: int = Field(
default=4096,
description="Maximum number of tokens to generate.",
)
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)
tls_verify: bool | str = Field(
default=True,
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
)
@field_validator("tls_verify")
@classmethod
def validate_tls_verify(cls, v):
if isinstance(v, str):
# Treat the string as a path to a CA certificate file
cert_path = Path(v).expanduser().resolve()
if not cert_path.exists():
raise ValueError(f"TLS certificate file does not exist: {v}")
if not cert_path.is_file():
raise ValueError(f"TLS certificate path is not a file: {v}")
return v
return v
@classmethod
def sample_run_config(
cls,
url: str = "${env.VLLM_URL:=}",
**kwargs,
):
return {
"url": url,
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
"api_token": "${env.VLLM_API_TOKEN:=fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
}
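A small sketch of the two forms `tls_verify` accepts above: a boolean toggles verification, while a string is validated as a path to an existing CA bundle. Values are illustrative, and it assumes the inherited config fields have defaults.
```python
# Illustrative only; the validator requires the CA file to actually exist on disk.
insecure = VLLMInferenceAdapterConfig(url="http://localhost:8000/v1", tls_verify=False)
pinned = VLLMInferenceAdapterConfig(
    url="https://vllm.internal.example:8000/v1",   # hypothetical endpoint
    api_token="example-token",                     # stored as auth_credential via the alias
    tls_verify="/etc/ssl/certs/internal-ca.pem",   # must exist, or validation fails
)
```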

View file

@ -1,111 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from urllib.parse import urljoin
import httpx
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from pydantic import ConfigDict
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionRequestWithExtraBody,
ToolChoice,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference::vllm")
class VLLMInferenceAdapter(OpenAIMixin):
config: VLLMInferenceAdapterConfig
model_config = ConfigDict(arbitrary_types_allowed=True)
provider_data_api_key_field: str = "vllm_api_token"
def get_api_key(self) -> str | None:
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
return "NO KEY REQUIRED"
def get_base_url(self) -> str:
"""Get the base URL from config."""
if not self.config.url:
raise ValueError("No base URL configured")
return self.config.url
async def initialize(self) -> None:
if not self.config.url:
raise ValueError(
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the remote vLLM server.
This method is used by the Provider API to verify
that the service is running correctly.
Uses the unauthenticated /health endpoint.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
base_url = self.get_base_url()
health_url = urljoin(base_url, "health")
async with httpx.AsyncClient() as client:
response = await client.get(health_url)
response.raise_for_status()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def check_model_availability(self, model: str) -> bool:
"""
Skip the check when running without authentication.
"""
if not self.config.auth_credential:
model_ids = []
async for m in self.client.models.list():
if m.id == model: # Found exact match
return True
model_ids.append(m.id)
raise ValueError(f"Model '{model}' not found. Available models: {model_ids}")
log.warning(f"Not checking model availability for {model} as API token may trigger OAuth workflow")
return True
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
params = params.model_copy()
# Apply vLLM-specific defaults
if params.max_tokens is None and self.config.max_tokens:
params.max_tokens = self.config.max_tokens
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not params.tools and params.tool_choice is not None:
params.tool_choice = ToolChoice.none.value
return await super().openai_chat_completion(params)

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import WatsonXConfig
async def get_adapter_impl(config: WatsonXConfig, _deps):
# import dynamically so the import is used only when it is needed
from .watsonx import WatsonXInferenceAdapter
adapter = WatsonXInferenceAdapter(config)
return adapter

View file

@ -1,45 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel):
watsonx_project_id: str | None = Field(
default=None,
description="IBM WatsonX project ID",
)
watsonx_api_key: str | None = None
@json_schema_type
class WatsonXConfig(RemoteInferenceProviderConfig):
url: str = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
project_id: str | None = Field(
default=None,
description="The watsonx.ai project ID",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:=}",
"project_id": "${env.WATSONX_PROJECT_ID:=}",
}

View file

@ -1,340 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
import litellm
import requests
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionUsage,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.core.telemetry.tracing import get_current_span
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="providers::remote::watsonx")
class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
_model_cache: dict[str, Model] = {}
provider_data_api_key_field: str = "watsonx_api_key"
def __init__(self, config: WatsonXConfig):
self.available_models = None
self.config = config
api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="watsonx",
api_key_from_config=api_key,
provider_data_api_key_field="watsonx_api_key",
openai_compat_api_base=self.get_base_url(),
)
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Override parent method to add timeout and inject usage object when missing.
This works around a LiteLLM defect where usage block is sometimes dropped.
"""
# Add usage tracking for streaming when telemetry is active
stream_options = params.stream_options
if params.stream and get_current_span() is not None:
if stream_options is None:
stream_options = {"include_usage": True}
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}
model_obj = await self.model_store.get_model(params.model)
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=params.messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
functions=params.functions,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_completion_tokens=params.max_completion_tokens,
max_tokens=params.max_tokens,
n=params.n,
parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=params.presence_penalty,
response_format=params.response_format,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=stream_options,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
top_logprobs=params.top_logprobs,
top_p=params.top_p,
user=params.user,
api_key=self.get_api_key(),
api_base=self.api_base,
# These are watsonx-specific parameters
timeout=self.config.timeout,
project_id=self.config.project_id,
)
result = await litellm.acompletion(**request_params)
# If not streaming, check and inject usage if missing
if not params.stream:
# Use getattr to safely handle cases where usage attribute might not exist
if getattr(result, "usage", None) is None:
# Create usage object with zeros
usage_obj = OpenAIChatCompletionUsage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
# Use model_copy to create a new response with the usage injected
result = result.model_copy(update={"usage": usage_obj})
return result
# For streaming, wrap the iterator to normalize chunks
return self._normalize_stream(result)
def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
"""
Normalize a chunk to ensure it has all expected attributes.
This works around LiteLLM not always including all expected attributes.
"""
# Ensure chunk has usage attribute with zeros if missing
if not hasattr(chunk, "usage") or chunk.usage is None:
usage_obj = OpenAIChatCompletionUsage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
chunk = chunk.model_copy(update={"usage": usage_obj})
# Ensure all delta objects in choices have expected attributes
if hasattr(chunk, "choices") and chunk.choices:
normalized_choices = []
for choice in chunk.choices:
if hasattr(choice, "delta") and choice.delta:
delta = choice.delta
# Build update dict for missing attributes
delta_updates = {}
if not hasattr(delta, "refusal"):
delta_updates["refusal"] = None
if not hasattr(delta, "reasoning_content"):
delta_updates["reasoning_content"] = None
# If we need to update delta, create a new choice with updated delta
if delta_updates:
new_delta = delta.model_copy(update=delta_updates)
new_choice = choice.model_copy(update={"delta": new_delta})
normalized_choices.append(new_choice)
else:
normalized_choices.append(choice)
else:
normalized_choices.append(choice)
# If we modified any choices, create a new chunk with updated choices
if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
chunk = chunk.model_copy(update={"choices": normalized_choices})
return chunk
async def _normalize_stream(
self, stream: AsyncIterator[OpenAIChatCompletionChunk]
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""
Normalize all chunks in the stream to ensure they have expected attributes.
This works around LiteLLM sometimes not including expected attributes.
"""
try:
async for chunk in stream:
# Normalize and yield each chunk immediately
yield self._normalize_chunk(chunk)
except Exception as e:
logger.error(f"Error normalizing stream: {e}", exc_info=True)
raise
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
"""
Override parent method to add watsonx-specific parameters.
"""
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
model_obj = await self.model_store.get_model(params.model)
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
frequency_penalty=params.frequency_penalty,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_tokens=params.max_tokens,
n=params.n,
presence_penalty=params.presence_penalty,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
suffix=params.suffix,
api_key=self.get_api_key(),
api_base=self.api_base,
# These are watsonx-specific parameters
timeout=self.config.timeout,
project_id=self.config.project_id,
)
return await litellm.atext_completion(**request_params)
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Override parent method to add watsonx-specific parameters.
"""
model_obj = await self.model_store.get_model(params.model)
# Convert input to list if it's a string
input_list = [params.input] if isinstance(params.input, str) else params.input
# Call litellm embedding function with watsonx-specific parameters
response = litellm.embedding(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
input=input_list,
api_key=self.get_api_key(),
api_base=self.api_base,
dimensions=params.dimensions,
# These are watsonx-specific parameters
timeout=self.config.timeout,
project_id=self.config.project_id,
)
# Convert response to OpenAI format
from llama_stack.apis.inference import OpenAIEmbeddingUsage
from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response
data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response["usage"]["prompt_tokens"],
total_tokens=response["usage"]["total_tokens"],
)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
usage=usage,
)
def get_base_url(self) -> str:
return self.config.url
# Copied from OpenAIMixin
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the provider's /v1/models.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
if not self._model_cache:
await self.list_models()
return model in self._model_cache
async def list_models(self) -> list[Model] | None:
self._model_cache = {}
models = []
for model_spec in self._get_model_specs():
functions = [f["id"] for f in model_spec.get("functions", [])]
# Format: {"embedding_dimension": 1536, "context_length": 8192}
# Example of an embedding model:
# {'model_id': 'ibm/granite-embedding-278m-multilingual',
# 'label': 'granite-embedding-278m-multilingual',
# 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
# ...
provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
if "embedding" in functions:
embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
context_length = model_spec["model_limits"]["max_sequence_length"]
embedding_metadata = {
"embedding_dimension": embedding_dimension,
"context_length": context_length,
}
model = Model(
identifier=model_spec["model_id"],
provider_resource_id=provider_resource_id,
provider_id=self.__provider_id__,
metadata=embedding_metadata,
model_type=ModelType.embedding,
)
self._model_cache[provider_resource_id] = model
models.append(model)
if "text_chat" in functions:
model = Model(
identifier=model_spec["model_id"],
provider_resource_id=provider_resource_id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
# In theory, a model could be both an embedding model and a text chat model. In that case the cache ends up
# holding only the text chat Model object (it is written last), while the returned list contains both the
# embedding Model object and the text chat Model object. That is fine because the cache is only used by
# check_model_availability().
self._model_cache[provider_resource_id] = model
models.append(model)
return models
# LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
# So we need to implement our own method to list models by calling the watsonx.ai API.
def _get_model_specs(self) -> list[dict[str, Any]]:
"""
Retrieves foundation model specifications from the watsonx.ai API.
"""
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
headers = {
# Note that there is no authorization header. Listing models does not require authentication.
"Content-Type": "application/json",
}
response = requests.get(url, headers=headers)
# --- Process the Response ---
# Raise an exception for bad status codes (4xx or 5xx)
response.raise_for_status()
# If the request is successful, parse and return the JSON response.
# The response should contain a list of model specifications
response_data = response.json()
if "resources" not in response_data:
raise ValueError("Resources not found in response")
return response_data["resources"]

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,150 +0,0 @@
# NVIDIA Post-Training Provider for LlamaStack
This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service.
## Features
- Supervised fine-tuning of Llama models
- LoRA fine-tuning support
- Job management and status tracking
## Getting Started
### Prerequisites
- LlamaStack with NVIDIA configuration
- Access to Hosted NVIDIA NeMo Customizer service
- Dataset registered in the Hosted NVIDIA NeMo Customizer service
- Base model downloaded and available in the Hosted NVIDIA NeMo Customizer service
### Setup
Build the NVIDIA environment:
```bash
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage with the LlamaStack Python Client
### Create Customization Job
#### Initialize the client
```python
import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
#### Configure fine-tuning parameters
```python
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
```
#### Set up LoRA configuration
```python
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
```
#### Configure training data
```python
data_config = TrainingConfigDataConfig(
dataset_id="your-dataset-id", # Use client.datasets.list() to see available datasets
batch_size=16,
)
```
#### Configure optimizer
```python
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
```
#### Set up training configuration
```python
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
```
#### Start fine-tuning job
```python
training_job = client.post_training.supervised_fine_tune(
job_uuid="unique-job-id",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
```
### List all jobs
```python
jobs = client.post_training.job.list()
```
### Check job status
```python
job_status = client.post_training.job.status(job_uuid="your-job-id")
```
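The status call returns immediately; to block until training finishes you can poll it. A minimal sketch (this assumes the returned object exposes a `status` field with values such as `scheduled`, `in_progress`, `completed`, or `failed`):
```python
import time

job_uuid = "your-job-id"
while True:
    job_status = client.post_training.job.status(job_uuid=job_uuid)
    # Stop polling once the job leaves the scheduled/in-progress states
    if str(job_status.status) not in ("scheduled", "in_progress"):
        break
    time.sleep(30)
print(f"Job {job_uuid} finished with status: {job_status.status}")
```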
### Cancel a job
```python
client.post_training.job.cancel(job_uuid="your-job-id")
```
### Inference with the fine-tuned model
#### 1. Register the model
```python
from llama_stack.apis.models import ModelType
client.models.register(
model_id="test-example-model@v1",
provider_id="nvidia",
provider_model_id="test-example-model@v1",
model_type=ModelType.llm,
)
```
#### 2. Run inference
```python
response = client.completions.create(
prompt="Complete the sentence using one word: Roses are red, violets are ",
stream=False,
model="test-example-model@v1",
max_tokens=50,
)
print(response.choices[0].text)
```
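If your distribution also exposes the OpenAI-compatible chat completions endpoint (assumed here; the method name follows the OpenAI-style client surface), the same fine-tuned model can be queried in chat form:
```python
chat_response = client.chat.completions.create(
    model="test-example-model@v1",
    messages=[
        {"role": "user", "content": "Complete the sentence using one word: Roses are red, violets are "}
    ],
    max_tokens=50,
)
print(chat_response.choices[0].message.content)
```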

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import NvidiaPostTrainingConfig
async def get_adapter_impl(
config: NvidiaPostTrainingConfig,
_deps,
):
from .post_training import NvidiaPostTrainingAdapter
if not isinstance(config, NvidiaPostTrainingConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
impl = NvidiaPostTrainingAdapter(config)
return impl
__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]

View file

@ -1,113 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field
# TODO: add default values for all fields
class NvidiaPostTrainingConfig(BaseModel):
"""Configuration for NVIDIA Post Training implementation."""
api_key: str | None = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key.",
)
dataset_namespace: str | None = Field(
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
description="The NVIDIA dataset namespace.",
)
project_id: str | None = Field(
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-example-model@v1"),
description="The NVIDIA project ID.",
)
# ToDO: validate this, add default value
customizer_url: str | None = Field(
default_factory=lambda: os.getenv("NVIDIA_CUSTOMIZER_URL"),
description="Base URL for the NeMo Customizer API",
)
timeout: int = Field(
default=300,
description="Timeout for the NVIDIA Post Training API",
)
max_retries: int = Field(
default=3,
description="Maximum number of retries for the NVIDIA Post Training API",
)
# ToDo: validate this
output_model_dir: str = Field(
default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
description="Directory to save the output model",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"api_key": "${env.NVIDIA_API_KEY:=}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
"project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}",
}
class SFTLoRADefaultConfig(BaseModel):
"""NVIDIA-specific training configuration with default values."""
# ToDo: split into SFT and LoRA configs??
# General training parameters
n_epochs: int = 50
# NeMo customizer specific parameters
log_every_n_steps: int | None = None
val_check_interval: float = 0.25
sequence_packing_enabled: bool = False
weight_decay: float = 0.01
lr: float = 0.0001
# SFT specific parameters
hidden_dropout: float | None = None
attention_dropout: float | None = None
ffn_dropout: float | None = None
# LoRA default parameters
lora_adapter_dim: int = 8
lora_adapter_dropout: float | None = None
lora_alpha: int = 16
# Data config
batch_size: int = 8
@classmethod
def sample_config(cls) -> dict[str, Any]:
"""Return a sample configuration for NVIDIA training."""
return {
"n_epochs": 50,
"log_every_n_steps": 10,
"val_check_interval": 0.25,
"sequence_packing_enabled": False,
"weight_decay": 0.01,
"hidden_dropout": 0.1,
"attention_dropout": 0.1,
"lora_adapter_dim": 8,
"lora_alpha": 16,
"data_config": {
"dataset_id": "default",
"batch_size": 8,
},
"optimizer_config": {
"lr": 0.0001,
},
}

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
_MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
]
def get_model_entries() -> list[ProviderModelEntry]:
return _MODEL_ENTRIES

View file

@ -1,430 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from datetime import datetime, timezone
from typing import Any, Literal
import aiohttp
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .models import _MODEL_ENTRIES
# Map API status to JobStatus enum
STATUS_MAPPING = {
"running": JobStatus.in_progress.value,
"completed": JobStatus.completed.value,
"failed": JobStatus.failed.value,
"cancelled": JobStatus.cancelled.value,
"pending": JobStatus.scheduled.value,
"unknown": JobStatus.scheduled.value,
}
class NvidiaPostTrainingJob(PostTrainingJob):
"""Parse the response from the Customizer API.
Inherits job_uuid from PostTrainingJob.
Adds status, created_at, updated_at parameters.
Passes through all other parameters from data field in the response.
"""
model_config = ConfigDict(extra="allow")
status: JobStatus
created_at: datetime
updated_at: datetime
class ListNvidiaPostTrainingJobs(BaseModel):
data: list[NvidiaPostTrainingJob]
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
model_config = ConfigDict(extra="allow")
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
def __init__(self, config: NvidiaPostTrainingConfig):
self.config = config
self.headers = {}
if config.api_key:
self.headers["Authorization"] = f"Bearer {config.api_key}"
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
# TODO: filter by available models based on /config endpoint
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
self.session = None
self.customizer_url = config.customizer_url
if not self.customizer_url:
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
self.customizer_url = "http://nemo.test"
async def _get_session(self) -> aiohttp.ClientSession:
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
return self.session
async def _make_request(
self,
method: str,
path: str,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
json: dict[str, Any] | None = None,
**kwargs,
) -> dict[str, Any]:
"""Helper method to make HTTP requests to the Customizer API."""
url = f"{self.customizer_url}{path}"
request_headers = self.headers.copy()
if headers:
request_headers.update(headers)
# Add content-type header for JSON requests
if json and "Content-Type" not in request_headers:
request_headers["Content-Type"] = "application/json"
session = await self._get_session()
for _ in range(self.config.max_retries):
async with session.request(method, url, params=params, json=json, **kwargs) as response:
if response.status >= 400:
error_data = await response.json()
raise Exception(f"API request failed: {error_data}")
return await response.json()
async def get_training_jobs(
self,
page: int | None = 1,
page_size: int | None = 10,
sort: Literal["created_at", "-created_at"] | None = "created_at",
) -> ListNvidiaPostTrainingJobs:
"""Get all customization jobs.
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
Returns a ListNvidiaPostTrainingJobs object with the following fields:
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
ToDo: Support for schema input for filtering.
"""
params = {"page": page, "page_size": page_size, "sort": sort}
response = await self._make_request("GET", "/v1/customization/jobs", params=params)
jobs = []
for job in response.get("data", []):
job_id = job.pop("id")
job_status = job.pop("status", "scheduled").lower()
mapped_status = STATUS_MAPPING.get(job_status, "scheduled")
# Convert string timestamps to datetime objects
created_at = (
datetime.fromisoformat(job.pop("created_at"))
if "created_at" in job
else datetime.now(tz=timezone.utc)
)
updated_at = (
datetime.fromisoformat(job.pop("updated_at"))
if "updated_at" in job
else datetime.now(tz=timezone.utc)
)
# Create NvidiaPostTrainingJob instance
jobs.append(
NvidiaPostTrainingJob(
job_uuid=job_id,
status=JobStatus(mapped_status),
created_at=created_at,
updated_at=updated_at,
**job,
)
)
return ListNvidiaPostTrainingJobs(data=jobs)
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
"""Get the status of a customization job.
Updated the base class return type to NvidiaPostTrainingJobStatusResponse.
Returns a NvidiaPostTrainingJobStatusResponse object with the following fields:
- job_uuid: str - Unique identifier for the job
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
- created_at: datetime - The time when the job was created
- updated_at: datetime - The last time the job status was updated
Additional fields that may be included:
- steps_completed: Optional[int] - Number of training steps completed
- epochs_completed: Optional[int] - Number of epochs completed
- percentage_done: Optional[float] - Percentage of training completed (0-100)
- best_epoch: Optional[int] - The epoch with the best performance
- train_loss: Optional[float] - Training loss of the best checkpoint
- val_loss: Optional[float] - Validation loss of the best checkpoint
- metrics: Optional[Dict] - Additional training metrics
- status_logs: Optional[List] - Detailed logs of status changes
"""
response = await self._make_request(
"GET",
f"/v1/customization/jobs/{job_uuid}/status",
params={"job_id": job_uuid},
)
api_status = response.pop("status").lower()
mapped_status = STATUS_MAPPING.get(api_status, "scheduled")
return NvidiaPostTrainingJobStatusResponse(
status=JobStatus(mapped_status),
job_uuid=job_uuid,
started_at=datetime.fromisoformat(response.pop("created_at")),
updated_at=datetime.fromisoformat(response.pop("updated_at")),
**response,
)
async def cancel_training_job(self, job_uuid: str) -> None:
await self._make_request(
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: dict[str, Any],
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: str | None,
algorithm_config: AlgorithmConfig | None = None,
) -> NvidiaPostTrainingJob:
"""
Fine-tunes a model on a dataset.
Currently only supports LoRA finetuning for the standalone docker container.
Assumptions:
- nemo microservice is running and endpoint is set in config.customizer_url
- dataset is registered separately in nemo datastore
- model checkpoint is downloaded as per nemo customizer requirements
Parameters:
training_config: TrainingConfig - Configuration for training
model: str - NeMo Customizer configuration name
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
job_uuid: str - Unique identifier for the job, ignored atm
hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search, ignored atm
logger_config: Dict[str, Any] - Configuration for logging, ignored atm
Environment Variables:
- NVIDIA_API_KEY: str - API key for the NVIDIA API
Default: None
- NVIDIA_DATASET_NAMESPACE: str - Namespace of the dataset
Default: "default"
- NVIDIA_CUSTOMIZER_URL: str - URL of the NeMo Customizer API
Default: "http://nemo.test"
- NVIDIA_PROJECT_ID: str - ID of the project
Default: "test-project"
- NVIDIA_OUTPUT_MODEL_DIR: str - Directory to save the output model
Default: "test-example-model@v1"
Supported models:
- meta/llama-3.1-8b-instruct
- meta/llama-3.2-1b-instruct
Supported algorithm configs:
- LoRA, SFT
Supported Parameters:
- TrainingConfig:
- n_epochs: int - Number of epochs to train
Default: 50
- data_config: DataConfig - Configuration for the dataset
- optimizer_config: OptimizerConfig - Configuration for the optimizer
- dtype: str - Data type for training
not supported (users are informed via warnings)
- efficiency_config: EfficiencyConfig - Configuration for efficiency
not supported
- max_steps_per_epoch: int - Maximum number of steps per epoch
Default: 1000
## NeMo customizer specific parameters
- log_every_n_steps: int - Log every n steps
Default: None
- val_check_interval: float - Validation check interval
Default: 0.25
- sequence_packing_enabled: bool - Sequence packing enabled
Default: False
## NeMo customizer specific SFT parameters
- hidden_dropout: float - Hidden dropout
Default: None (0.0-1.0)
- attention_dropout: float - Attention dropout
Default: None (0.0-1.0)
- ffn_dropout: float - FFN dropout
Default: None (0.0-1.0)
- DataConfig:
- dataset_id: str - Dataset ID
- batch_size: int - Batch size
Default: 8
- OptimizerConfig:
- lr: float - Learning rate
Default: 0.0001
## NeMo customizer specific parameter
- weight_decay: float - Weight decay
Default: 0.01
- LoRA config:
## NeMo customizer specific LoRA parameters
- alpha: int - Scaling factor for the LoRA update
Default: 16
Note:
- checkpoint_dir, hyperparam_search_config, logger_config are not supported (users are informed via warnings)
- Some parameters from TrainingConfig, DataConfig, OptimizerConfig are not supported (users are informed via warnings)
User is informed about unsupported parameters via warnings.
"""
# Check for unsupported method parameters
unsupported_method_params = []
if checkpoint_dir:
unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
if hyperparam_search_config:
unsupported_method_params.append("hyperparam_search_config")
if logger_config:
unsupported_method_params.append("logger_config")
if unsupported_method_params:
warnings.warn(
f"Parameters: {', '.join(unsupported_method_params)} are not supported and will be ignored",
stacklevel=2,
)
# Define all supported parameters
supported_params = {
"training_config": {
"n_epochs",
"data_config",
"optimizer_config",
"log_every_n_steps",
"val_check_interval",
"sequence_packing_enabled",
"hidden_dropout",
"attention_dropout",
"ffn_dropout",
},
"data_config": {"dataset_id", "batch_size"},
"optimizer_config": {"lr", "weight_decay"},
"lora_config": {"type", "alpha"},
}
# Validate all parameters at once
warn_unsupported_params(training_config, supported_params["training_config"], "TrainingConfig")
warn_unsupported_params(training_config["data_config"], supported_params["data_config"], "DataConfig")
warn_unsupported_params(
training_config["optimizer_config"], supported_params["optimizer_config"], "OptimizerConfig"
)
output_model = self.config.output_model_dir
# Prepare base job configuration
job_config = {
"config": model,
"dataset": {
"name": training_config["data_config"]["dataset_id"],
"namespace": self.config.dataset_namespace,
},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
**{
k: v
for k, v in {
"epochs": training_config.get("n_epochs"),
"batch_size": training_config["data_config"].get("batch_size"),
"learning_rate": training_config["optimizer_config"].get("lr"),
"weight_decay": training_config["optimizer_config"].get("weight_decay"),
"log_every_n_steps": training_config.get("log_every_n_steps"),
"val_check_interval": training_config.get("val_check_interval"),
"sequence_packing_enabled": training_config.get("sequence_packing_enabled"),
}.items()
if v is not None
},
},
"project": self.config.project_id,
# TODO: ignored ownership, add it later
# "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
"output_model": output_model,
}
# Handle SFT-specific optional parameters
job_config["hyperparameters"]["sft"] = {
k: v
for k, v in {
"ffn_dropout": training_config.get("ffn_dropout"),
"hidden_dropout": training_config.get("hidden_dropout"),
"attention_dropout": training_config.get("attention_dropout"),
}.items()
if v is not None
}
# Remove the sft dictionary if it's empty
if not job_config["hyperparameters"]["sft"]:
job_config["hyperparameters"].pop("sft")
# Handle LoRA-specific configuration
if algorithm_config:
if algorithm_config.type == "LoRA":
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
job_config["hyperparameters"]["lora"] = {
k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
}
else:
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
# Create the customization job
response = await self._make_request(
method="POST",
path="/v1/customization/jobs",
headers={"Accept": "application/json"},
json=job_config,
)
job_uuid = response["id"]
response.pop("status")
created_at = datetime.fromisoformat(response.pop("created_at"))
updated_at = datetime.fromisoformat(response.pop("updated_at"))
return NvidiaPostTrainingJob(
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
"""Optimize a model based on preference data."""
raise NotImplementedError("Preference optimization is not implemented yet")
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
raise NotImplementedError("Job logs are not implemented yet")

View file

@ -1,63 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.post_training import TrainingConfig
from llama_stack.log import get_logger
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
from .config import NvidiaPostTrainingConfig
logger = get_logger(name=__name__, category="post_training::nvidia")
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
unsupported_params = [k for k in keys if k not in supported_keys]
if unsupported_params:
warnings.warn(
f"Parameters: {unsupported_params} in `{config_name}` not supported and will be ignored.", stacklevel=2
)
def validate_training_params(
training_config: dict[str, Any], supported_keys: set[str], config_name: str = "TrainingConfig"
) -> None:
"""
Validates training parameters against supported keys.
Args:
training_config: Dictionary containing training configuration parameters
supported_keys: Set of supported parameter keys
config_name: Name of the configuration for warning messages
"""
sft_lora_fields = set(SFTLoRADefaultConfig.__annotations__.keys())
training_config_fields = set(TrainingConfig.__annotations__.keys())
# Check for not supported parameters:
# - not in either of configs
# - in TrainingConfig but not in SFTLoRADefaultConfig
unsupported_params = []
for key in training_config:
if isinstance(key, str) and key not in (supported_keys.union(sft_lora_fields)):
if key in training_config_fields:
unsupported_params.append(key)
if unsupported_params:
warnings.warn(
f"Parameters: {unsupported_params} in `{config_name}` are not supported and will be ignored.", stacklevel=2
)
# ToDo: implement health checks to verify the Customizer endpoints are enabled
async def _get_health(url: str) -> tuple[bool, bool]: ...
async def check_health(config: NvidiaPostTrainingConfig) -> None: ...

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import BedrockSafetyConfig
async def get_adapter_impl(config: BedrockSafetyConfig, _deps) -> Any:
from .bedrock import BedrockSafetyAdapter
impl = BedrockSafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -1,111 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
logger = get_logger(name=__name__, category="safety::bedrock")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: BedrockSafetyConfig) -> None:
self.config = config
self.registered_shields = []
async def initialize(self) -> None:
try:
self.bedrock_runtime_client = create_bedrock_client(self.config)
self.bedrock_client = create_bedrock_client(self.config, "bedrock")
except Exception as e:
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
response = self.bedrock_client.list_guardrails(
guardrailIdentifier=shield.provider_resource_id,
)
if (
not response["guardrails"]
or len(response["guardrails"]) == 0
or response["guardrails"][0]["version"] != shield.params["guardrailVersion"]
):
raise ValueError(
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
)
async def unregister_shield(self, identifier: str) -> None:
pass
async def run_shield(
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
"""
This is the implementation for the Bedrock guardrails. The input to the guardrails must be in this format:
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
}
]```
Incoming messages contain content and role. For now we extract only the content and
default the "qualifiers" to ["query"].
"""
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = self.bedrock_runtime_client.apply_guardrail(
guardrailIdentifier=shield.provider_resource_id,
guardrailVersion=shield_params["guardrailVersion"],
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse()

View file

@ -1,14 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class BedrockSafetyConfig(BedrockBaseConfig):
pass

View file

@ -1,77 +0,0 @@
# NVIDIA Safety Provider for LlamaStack
This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.
## Features
- Run safety checks for messages
## Getting Started
### Prerequisites
- LlamaStack with NVIDIA configuration
- Access to NVIDIA NeMo Guardrails service
- A NIM deployment of the model to use for safety checks
### Setup
Build the NVIDIA environment:
```bash
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage with the LlamaStack Python Client
#### Initialize the client
```python
import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
#### Create a safety shield
```python
from llama_stack.apis.shields import Shield
from llama_stack.apis.inference import Message
# Create a safety shield
shield = Shield(
shield_id="your-shield-id",
provider_resource_id="safety-model-id", # The model to use for safety checks
description="Safety checks for content moderation",
)
# Register the shield
await client.safety.register_shield(shield)
```
#### Run safety checks
```python
# Messages to check
messages = [Message(role="user", content="Your message to check")]
# Run safety check
response = await client.safety.run_shield(
shield_id="your-shield-id",
messages=messages,
)
# Check for violations
if response.violation:
print(f"Safety violation detected: {response.violation.user_message}")
print(f"Violation level: {response.violation.violation_level}")
print(f"Metadata: {response.violation.metadata}")
else:
print("No safety violations detected")
```
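Under the hood, the provider forwards the conversation to the NeMo Guardrails `checks` endpoint. The sketch below reproduces that request with plain `requests`, which can help when debugging a Guardrails deployment directly; the endpoint path and payload fields mirror the adapter code later in this diff, while the URL, model ID, and config ID are placeholders:
```python
import requests

guardrails_url = "http://guardrails.test"  # placeholder, same value as NVIDIA_GUARDRAILS_URL above

payload = {
    "model": "safety-model-id",  # the shield's provider_resource_id
    "messages": [{"role": "user", "content": "Your message to check"}],
    "temperature": 1.0,
    "top_p": 1,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "max_tokens": 160,
    "stream": False,
    "guardrails": {"config_id": "self-check"},
}

response = requests.post(
    f"{guardrails_url}/v1/guardrail/checks",
    headers={"Accept": "application/json"},
    json=payload,
)
response.raise_for_status()
print(response.json().get("status"))  # "blocked" indicates a violation
```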

View file

@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import NVIDIASafetyConfig
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
from .nvidia import NVIDIASafetyAdapter
impl = NVIDIASafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIASafetyConfig(BaseModel):
"""
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
config_id (str): The ID of the guardrails configuration to use from the configuration store
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the Guardrails service",
)
config_id: str | None = Field(
default_factory=lambda: os.getenv("NVIDIA_GUARDRAILS_CONFIG_ID", "self-check"),
description="Guardrails configuration ID to use from the Guardrails configuration store",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}",
"config_id": "${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}",
}

View file

@ -1,161 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import requests
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import NVIDIASafetyConfig
logger = get_logger(name=__name__, category="safety::nvidia")
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
"""
Initialize the NVIDIASafetyAdapter with a given safety configuration.
Args:
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
"""
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def unregister_shield(self, identifier: str) -> None:
pass
async def run_shield(
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
Args:
shield_id (str): The unique identifier for the shield to be used.
messages (List[Message]): A list of Message objects representing the conversation history.
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
Returns:
RunShieldResponse: The response containing safety violation details if any.
Raises:
ValueError: If the shield with the provided shield_id is not found.
"""
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
class NeMoGuardrails:
"""
A class that encapsulates NVIDIA's guardrails safety logic.
Sends messages to the guardrails service and interprets the response to determine
if a safety violation has occurred.
"""
def __init__(
self,
config: NVIDIASafetyConfig,
model: str,
threshold: float = 0.9,
temperature: float = 1.0,
):
"""
Initialize a NeMoGuardrails instance with the provided parameters.
Args:
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
model (str): The identifier or name of the model to be used for safety checks.
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
Raises:
ValueError: If temperature is less than or equal to 0.
AssertionError: If config_id is not provided in the configuration.
"""
self.config_id = config.config_id
self.model = model
assert self.config_id is not None, "Must provide config id"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.temperature = temperature
self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url
async def _guardrails_post(self, path: str, data: Any | None):
"""Helper for making POST requests to the guardrails service."""
headers = {
"Accept": "application/json",
}
response = requests.post(url=f"{self.guardrails_service_url}{path}", headers=headers, json=data)
response.raise_for_status()
return response.json()
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
"""
Queries the /v1/guardrail/checks endpoint of the deployed NeMo Guardrails API.
Args:
messages (List[Message]): A list of Message objects to be checked for safety violations.
Returns:
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
Raises:
requests.HTTPError: If the POST request fails.
"""
request_data = {
"model": self.model,
"messages": [{"role": message.role, "content": message.content} for message in messages],
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": self.config_id,
},
}
response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data)
if response["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse(violation=None)

View file

@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SambaNovaSafetyConfig
async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any:
from .sambanova import SambaNovaSafetyAdapter
impl = SambaNovaSafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -1,37 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str | None = Field(
default=None,
description="Sambanova Cloud API key",
)
@json_schema_type
class SambaNovaSafetyConfig(BaseModel):
url: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": api_key,
}

Some files were not shown because too many files have changed in this diff.