mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
Merge upstream/main into add-mongodb-vector_io
Resolved conflicts: - Integrated MongoDB provider with newly added Qdrant and Weaviate providers - Updated distribution configs to include all three providers - Merged build.yaml and run.yaml configs for ci-tests, starter, and starter-gpu distributions - Updated starter.py to include MongoDB, Qdrant, and Weaviate provider initialization - Added MongoDB provider files to src/ directory structure - Updated MongoDB provider to use new VectorStore API (was VectorDB) - Updated MongoDB config to use KVStoreReference instead of KVStoreConfig - Applied auto-formatting changes from pre-commit hooks
This commit is contained in:
commit
efe9c04849
1820 changed files with 402590 additions and 32499 deletions
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import HuggingfaceDatasetIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: HuggingfaceDatasetIOConfig,
|
||||
_deps,
|
||||
):
|
||||
from .huggingface import HuggingfaceDatasetIOImpl
|
||||
|
||||
impl = HuggingfaceDatasetIOImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
class HuggingfaceDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="huggingface_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
@ -1,99 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
|
||||
from .config import HuggingfaceDatasetIOConfig
|
||||
|
||||
DATASETS_PREFIX = "datasets:"
|
||||
|
||||
|
||||
def parse_hf_params(dataset_def: Dataset):
|
||||
uri = dataset_def.source.uri
|
||||
parsed_uri = urlparse(uri)
|
||||
params = parse_qs(parsed_uri.query)
|
||||
params = {k: v[0] for k, v in params.items()}
|
||||
path = parsed_uri.path.lstrip("/")
|
||||
|
||||
return path, params
|
||||
|
||||
|
||||
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
|
||||
self.config = config
|
||||
# local registry for keeping track of datasets within the provider
|
||||
self.dataset_infos = {}
|
||||
self.kvstore = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
# Load existing datasets from kvstore
|
||||
start_key = DATASETS_PREFIX
|
||||
end_key = f"{DATASETS_PREFIX}\xff"
|
||||
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for dataset in stored_datasets:
|
||||
dataset = Dataset.model_validate_json(dataset)
|
||||
self.dataset_infos[dataset.identifier] = dataset
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_def: Dataset,
|
||||
) -> None:
|
||||
# Store in kvstore
|
||||
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=dataset_def.model_dump_json(),
|
||||
)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
key = f"{DATASETS_PREFIX}{dataset_id}"
|
||||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
import datasets as hf_datasets
|
||||
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
path, params = parse_hf_params(dataset_def)
|
||||
loaded_dataset = hf_datasets.load_dataset(path, **params)
|
||||
|
||||
records = [loaded_dataset[i] for i in range(len(loaded_dataset))]
|
||||
return paginate_records(records, start_index, limit)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
import datasets as hf_datasets
|
||||
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
path, params = parse_hf_params(dataset_def)
|
||||
loaded_dataset = hf_datasets.load_dataset(path, **params)
|
||||
|
||||
# Convert rows to HF Dataset format
|
||||
new_dataset = hf_datasets.Dataset.from_list(rows)
|
||||
|
||||
# Concatenate the new rows with existing dataset
|
||||
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
|
||||
|
||||
if dataset_def.metadata.get("path", None):
|
||||
updated_dataset.push_to_hub(dataset_def.metadata["path"])
|
||||
else:
|
||||
raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
# NVIDIA DatasetIO Provider for LlamaStack
|
||||
|
||||
This provider enables dataset management using NVIDIA's NeMo Customizer service.
|
||||
|
||||
## Features
|
||||
|
||||
- Register datasets for fine-tuning LLMs
|
||||
- Unregister datasets
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- LlamaStack with NVIDIA configuration
|
||||
- Access to Hosted NVIDIA NeMo Microservice
|
||||
- API key for authentication with the NVIDIA service
|
||||
|
||||
### Setup
|
||||
|
||||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
||||
#### Initialize the client
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
|
||||
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.initialize()
|
||||
```
|
||||
|
||||
#### Register a dataset
|
||||
|
||||
```python
|
||||
client.datasets.register(
|
||||
purpose="post-training/messages",
|
||||
dataset_id="my-training-dataset",
|
||||
source={"type": "uri", "uri": "hf://datasets/default/sample-dataset"},
|
||||
metadata={
|
||||
"format": "json",
|
||||
"description": "Dataset for LLM fine-tuning",
|
||||
"provider": "nvidia",
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
#### Get a list of all registered datasets
|
||||
|
||||
```python
|
||||
datasets = client.datasets.list()
|
||||
for dataset in datasets:
|
||||
print(f"Dataset ID: {dataset.identifier}")
|
||||
print(f"Description: {dataset.metadata.get('description', '')}")
|
||||
print(f"Source: {dataset.source.uri}")
|
||||
print("---")
|
||||
```
|
||||
|
||||
#### Unregister a dataset
|
||||
|
||||
```python
|
||||
client.datasets.unregister(dataset_id="my-training-dataset")
|
||||
```
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import NvidiaDatasetIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: NvidiaDatasetIOConfig,
|
||||
_deps,
|
||||
):
|
||||
from .datasetio import NvidiaDatasetIOAdapter
|
||||
|
||||
if not isinstance(config, NvidiaDatasetIOConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
|
||||
impl = NvidiaDatasetIOAdapter(config)
|
||||
return impl
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "NvidiaDatasetIOAdapter"]
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NvidiaDatasetIOConfig(BaseModel):
|
||||
"""Configuration for NVIDIA DatasetIO implementation."""
|
||||
|
||||
api_key: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
|
||||
description="The NVIDIA API key.",
|
||||
)
|
||||
|
||||
dataset_namespace: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
|
||||
description="The NVIDIA dataset namespace.",
|
||||
)
|
||||
|
||||
project_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"),
|
||||
description="The NVIDIA project ID.",
|
||||
)
|
||||
|
||||
datasets_url: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_DATASETS_URL", "http://nemo.test"),
|
||||
description="Base URL for the NeMo Dataset API",
|
||||
)
|
||||
|
||||
# warning for default values
|
||||
def __post_init__(self):
|
||||
default_values = []
|
||||
if os.getenv("NVIDIA_PROJECT_ID") is None:
|
||||
default_values.append("project_id='test-project'")
|
||||
if os.getenv("NVIDIA_DATASET_NAMESPACE") is None:
|
||||
default_values.append("dataset_namespace='default'")
|
||||
if os.getenv("NVIDIA_DATASETS_URL") is None:
|
||||
default_values.append("datasets_url='http://nemo.test'")
|
||||
|
||||
if default_values:
|
||||
warnings.warn(
|
||||
f"Using default values: {', '.join(default_values)}. \
|
||||
Please set the environment variables to avoid this default behavior.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.NVIDIA_API_KEY:=}",
|
||||
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
|
||||
"project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
|
||||
"datasets_url": "${env.NVIDIA_DATASETS_URL:=http://nemo.test}",
|
||||
}
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
|
||||
from .config import NvidiaDatasetIOConfig
|
||||
|
||||
|
||||
class NvidiaDatasetIOAdapter:
|
||||
"""Nvidia NeMo DatasetIO API."""
|
||||
|
||||
def __init__(self, config: NvidiaDatasetIOConfig):
|
||||
self.config = config
|
||||
self.headers = {}
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
json: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""Helper method to make HTTP requests to the Customizer API."""
|
||||
url = f"{self.config.datasets_url}{path}"
|
||||
request_headers = self.headers.copy()
|
||||
|
||||
# Set default Content-Type for JSON requests
|
||||
if json is not None:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
async with aiohttp.ClientSession(headers=request_headers) as session:
|
||||
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
if response.status != 200:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"API request failed: {error_data}")
|
||||
return await response.json()
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_def: Dataset,
|
||||
) -> Dataset:
|
||||
"""Register a new dataset.
|
||||
|
||||
Args:
|
||||
dataset_def [Dataset]: The dataset definition.
|
||||
dataset_id [str]: The ID of the dataset.
|
||||
source [DataSource]: The source of the dataset.
|
||||
metadata [Dict[str, Any]]: The metadata of the dataset.
|
||||
format [str]: The format of the dataset.
|
||||
description [str]: The description of the dataset.
|
||||
Returns:
|
||||
Dataset
|
||||
"""
|
||||
# add warnings for unsupported params
|
||||
request_body = {
|
||||
"name": dataset_def.identifier,
|
||||
"namespace": self.config.dataset_namespace,
|
||||
"files_url": dataset_def.source.uri,
|
||||
"project": self.config.project_id,
|
||||
}
|
||||
if dataset_def.metadata:
|
||||
request_body["format"] = dataset_def.metadata.get("format")
|
||||
request_body["description"] = dataset_def.metadata.get("description")
|
||||
await self._make_request(
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
json=request_body,
|
||||
)
|
||||
return dataset_def
|
||||
|
||||
async def update_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
dataset_schema: dict[str, ParamType],
|
||||
url: URL,
|
||||
provider_dataset_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None:
|
||||
await self._make_request(
|
||||
"DELETE",
|
||||
f"/v1/datasets/{self.config.dataset_namespace}/{dataset_id}",
|
||||
headers={"Accept": "application/json", "Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
# NVIDIA NeMo Evaluator Eval Provider
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
For the first integration, Benchmarks are mapped to Evaluation Configs on in the NeMo Evaluator. The full evaluation config object is provided as part of the meta-data. The `dataset_id` and `scoring_functions` are not used.
|
||||
|
||||
Below are a few examples of how to register a benchmark, which in turn will create an evaluation config in NeMo Evaluator and how to trigger an evaluation.
|
||||
|
||||
### Example for register an academic benchmark
|
||||
|
||||
```
|
||||
POST /eval/benchmarks
|
||||
```
|
||||
```json
|
||||
{
|
||||
"benchmark_id": "mmlu",
|
||||
"dataset_id": "",
|
||||
"scoring_functions": [],
|
||||
"metadata": {
|
||||
"type": "mmlu"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example for register a custom evaluation
|
||||
|
||||
```
|
||||
POST /eval/benchmarks
|
||||
```
|
||||
```json
|
||||
{
|
||||
"benchmark_id": "my-custom-benchmark",
|
||||
"dataset_id": "",
|
||||
"scoring_functions": [],
|
||||
"metadata": {
|
||||
"type": "custom",
|
||||
"params": {
|
||||
"parallelism": 8
|
||||
},
|
||||
"tasks": {
|
||||
"qa": {
|
||||
"type": "completion",
|
||||
"params": {
|
||||
"template": {
|
||||
"prompt": "{{prompt}}",
|
||||
"max_tokens": 200
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
|
||||
},
|
||||
"metrics": {
|
||||
"bleu": {
|
||||
"type": "bleu",
|
||||
"params": {
|
||||
"references": [
|
||||
"{{ideal_response}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example for triggering a benchmark/custom evaluation
|
||||
|
||||
```
|
||||
POST /eval/benchmarks/{benchmark_id}/jobs
|
||||
```
|
||||
```json
|
||||
{
|
||||
"benchmark_id": "my-custom-benchmark",
|
||||
"benchmark_config": {
|
||||
"eval_candidate": {
|
||||
"type": "model",
|
||||
"model": "meta-llama/Llama3.1-8B-Instruct",
|
||||
"sampling_params": {
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.7
|
||||
}
|
||||
},
|
||||
"scoring_params": {}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Response example:
|
||||
```json
|
||||
{
|
||||
"job_id": "eval-1234",
|
||||
"status": "in_progress"
|
||||
}
|
||||
```
|
||||
|
||||
### Example for getting the status of a job
|
||||
```
|
||||
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
|
||||
```
|
||||
|
||||
Response example:
|
||||
```json
|
||||
{
|
||||
"job_id": "eval-1234",
|
||||
"status": "in_progress"
|
||||
}
|
||||
```
|
||||
|
||||
### Example for cancelling a job
|
||||
```
|
||||
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
|
||||
```
|
||||
|
||||
### Example for getting the results
|
||||
```
|
||||
GET /eval/benchmarks/{benchmark_id}/results
|
||||
```
|
||||
```json
|
||||
{
|
||||
"generations": [],
|
||||
"scores": {
|
||||
"{benchmark_id}": {
|
||||
"score_rows": [],
|
||||
"aggregated_results": {
|
||||
"tasks": {},
|
||||
"groups": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
|
||||
from .config import NVIDIAEvalConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: NVIDIAEvalConfig,
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .eval import NVIDIAEvalImpl
|
||||
|
||||
impl = NVIDIAEvalImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
deps[Api.scoring],
|
||||
deps[Api.inference],
|
||||
deps[Api.agents],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NVIDIAEvalConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the NVIDIA NeMo Evaluator microservice endpoint.
|
||||
|
||||
Attributes:
|
||||
evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
|
||||
"""
|
||||
|
||||
evaluator_url: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
|
||||
description="The url for accessing the evaluator service",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}",
|
||||
}
|
||||
|
|
@ -1,162 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring, ScoringResult
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
|
||||
from .config import NVIDIAEvalConfig
|
||||
|
||||
DEFAULT_NAMESPACE = "nvidia"
|
||||
|
||||
|
||||
class NVIDIAEvalImpl(
|
||||
Eval,
|
||||
BenchmarksProtocolPrivate,
|
||||
ModelRegistryHelper,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: NVIDIAEvalConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
scoring_api: Scoring,
|
||||
inference_api: Inference,
|
||||
agents_api: Agents,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.scoring_api = scoring_api
|
||||
self.inference_api = inference_api
|
||||
self.agents_api = agents_api
|
||||
|
||||
ModelRegistryHelper.__init__(self)
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def _evaluator_get(self, path: str):
|
||||
"""Helper for making GET requests to the evaluator service."""
|
||||
response = requests.get(url=f"{self.config.evaluator_url}{path}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_post(self, path: str, data: dict[str, Any]):
|
||||
"""Helper for making POST requests to the evaluator service."""
|
||||
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_delete(self, path: str) -> None:
|
||||
"""Helper for making DELETE requests to the evaluator service."""
|
||||
response = requests.delete(url=f"{self.config.evaluator_url}{path}")
|
||||
response.raise_for_status()
|
||||
|
||||
async def register_benchmark(self, task_def: Benchmark) -> None:
|
||||
"""Register a benchmark as an evaluation configuration."""
|
||||
await self._evaluator_post(
|
||||
"/v1/evaluation/configs",
|
||||
{
|
||||
"namespace": DEFAULT_NAMESPACE,
|
||||
"name": task_def.benchmark_id,
|
||||
# metadata is copied to request body as-is
|
||||
**task_def.metadata,
|
||||
},
|
||||
)
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
"""Unregister a benchmark evaluation configuration from NeMo Evaluator."""
|
||||
await self._evaluator_delete(f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}")
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
"""Run an evaluation job for a benchmark."""
|
||||
model = (
|
||||
benchmark_config.eval_candidate.model
|
||||
if benchmark_config.eval_candidate.type == "model"
|
||||
else benchmark_config.eval_candidate.config.model
|
||||
)
|
||||
nvidia_model = self.get_provider_model_id(model) or model
|
||||
|
||||
result = await self._evaluator_post(
|
||||
"/v1/evaluation/jobs",
|
||||
{
|
||||
"config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
|
||||
"target": {"type": "model", "model": nvidia_model},
|
||||
},
|
||||
)
|
||||
|
||||
return Job(job_id=result["id"], status=JobStatus.in_progress)
|
||||
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: list[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of an evaluation job.
|
||||
|
||||
EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
|
||||
JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
|
||||
"""
|
||||
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
|
||||
result_status = result["status"]
|
||||
|
||||
job_status = JobStatus.failed
|
||||
if result_status in ["created", "pending"]:
|
||||
job_status = JobStatus.scheduled
|
||||
elif result_status in ["running"]:
|
||||
job_status = JobStatus.in_progress
|
||||
elif result_status in ["completed"]:
|
||||
job_status = JobStatus.completed
|
||||
elif result_status in ["cancelled"]:
|
||||
job_status = JobStatus.cancelled
|
||||
|
||||
return Job(job_id=job_id, status=job_status)
|
||||
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel the evaluation job."""
|
||||
await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
|
||||
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Returns the results of the evaluation job."""
|
||||
|
||||
job = await self.job_status(benchmark_id, job_id)
|
||||
status = job.status
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
|
||||
|
||||
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
|
||||
|
||||
return EvaluateResponse(
|
||||
# TODO: these are stored in detailed results on NeMo Evaluator side; can be added
|
||||
generations=[],
|
||||
scores={
|
||||
benchmark_id: ScoringResult(
|
||||
score_rows=[],
|
||||
aggregated_results=result,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
|
@ -1,237 +0,0 @@
|
|||
# S3 Files Provider
|
||||
|
||||
A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.
|
||||
|
||||
## Features
|
||||
|
||||
- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
|
||||
- **Metadata Management**: Uses SQL database for efficient file metadata queries
|
||||
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
|
||||
- **Flexible Authentication**: Support for IAM roles and access keys
|
||||
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
api: files
|
||||
provider_type: remote::s3
|
||||
config:
|
||||
bucket_name: my-llama-stack-files
|
||||
region: us-east-1
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ./s3_files_metadata.db
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```yaml
|
||||
api: files
|
||||
provider_type: remote::s3
|
||||
config:
|
||||
bucket_name: my-llama-stack-files
|
||||
region: us-east-1
|
||||
aws_access_key_id: YOUR_ACCESS_KEY
|
||||
aws_secret_access_key: YOUR_SECRET_KEY
|
||||
endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ./s3_files_metadata.db
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The configuration supports environment variable substitution:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: "${env.S3_BUCKET_NAME}"
|
||||
region: "${env.AWS_REGION:=us-east-1}"
|
||||
aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
|
||||
aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
|
||||
endpoint_url: "${env.S3_ENDPOINT_URL:=}"
|
||||
```
|
||||
|
||||
Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.
|
||||
|
||||
## Authentication
|
||||
|
||||
### IAM Roles (Recommended)
|
||||
|
||||
For production deployments, use IAM roles:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: my-bucket
|
||||
region: us-east-1
|
||||
# No credentials needed - will use IAM role
|
||||
```
|
||||
|
||||
### Access Keys
|
||||
|
||||
For development or specific use cases:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: my-bucket
|
||||
region: us-east-1
|
||||
aws_access_key_id: AKIAIOSFODNN7EXAMPLE
|
||||
aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
|
||||
```
|
||||
|
||||
## S3 Bucket Setup
|
||||
|
||||
### Required Permissions
|
||||
|
||||
The S3 provider requires the following permissions:
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
"s3:DeleteObject",
|
||||
"s3:ListBucket"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:s3:::your-bucket-name",
|
||||
"arn:aws:s3:::your-bucket-name/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Automatic Bucket Creation
|
||||
|
||||
By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: my-bucket
|
||||
auto_create_bucket: true # Will create bucket if it doesn't exist
|
||||
region: us-east-1
|
||||
```
|
||||
|
||||
**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
"s3:DeleteObject",
|
||||
"s3:ListBucket",
|
||||
"s3:CreateBucket"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:s3:::your-bucket-name",
|
||||
"arn:aws:s3:::your-bucket-name/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Bucket Policy (Optional)
|
||||
|
||||
For additional security, you can add a bucket policy:
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Sid": "LlamaStackAccess",
|
||||
"Effect": "Allow",
|
||||
"Principal": {
|
||||
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
|
||||
},
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
"s3:DeleteObject"
|
||||
],
|
||||
"Resource": "arn:aws:s3:::your-bucket-name/*"
|
||||
},
|
||||
{
|
||||
"Sid": "LlamaStackBucketAccess",
|
||||
"Effect": "Allow",
|
||||
"Principal": {
|
||||
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
|
||||
},
|
||||
"Action": [
|
||||
"s3:ListBucket"
|
||||
],
|
||||
"Resource": "arn:aws:s3:::your-bucket-name"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### Metadata Persistence
|
||||
|
||||
File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
|
||||
|
||||
- File ID
|
||||
- Original filename
|
||||
- Purpose (assistants, batch, etc.)
|
||||
- File size in bytes
|
||||
- Created and expiration timestamps
|
||||
|
||||
### TTL and Cleanup
|
||||
|
||||
Files currently have a fixed long expiration time (100 years).
|
||||
|
||||
## Development and Testing
|
||||
|
||||
### Using MinIO
|
||||
|
||||
For self-hosted S3-compatible storage:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: test-bucket
|
||||
region: us-east-1
|
||||
endpoint_url: http://localhost:9000
|
||||
aws_access_key_id: minioadmin
|
||||
aws_secret_access_key: minioadmin
|
||||
```
|
||||
|
||||
## Monitoring and Logging
|
||||
|
||||
The provider logs important operations and errors. For production deployments, consider:
|
||||
|
||||
- CloudWatch monitoring for S3 operations
|
||||
- Custom metrics for file upload/download rates
|
||||
- Error rate monitoring
|
||||
- Performance metrics tracking
|
||||
|
||||
## Error Handling
|
||||
|
||||
The provider handles various error scenarios:
|
||||
|
||||
- S3 connectivity issues
|
||||
- Bucket access permissions
|
||||
- File not found errors
|
||||
- Metadata consistency checks
|
||||
|
||||
## Known Limitations
|
||||
|
||||
- Fixed long TTL (100 years) instead of configurable expiration
|
||||
- No server-side encryption enabled by default
|
||||
- No support for AWS session tokens
|
||||
- No S3 key prefix organization support
|
||||
- No multipart upload support (all files uploaded as single objects)
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
|
||||
from .files import S3FilesImpl
|
||||
|
||||
impl = S3FilesImpl(config, policy or [])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
|
||||
|
||||
class S3FilesImplConfig(BaseModel):
|
||||
"""Configuration for S3-based files provider."""
|
||||
|
||||
bucket_name: str = Field(description="S3 bucket name to store files")
|
||||
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
|
||||
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
|
||||
aws_secret_access_key: str | None = Field(
|
||||
default=None, description="AWS secret access key (optional if using IAM roles)"
|
||||
)
|
||||
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
|
||||
auto_create_bucket: bool = Field(
|
||||
default=False, description="Automatically create the S3 bucket if it doesn't exist"
|
||||
)
|
||||
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique
|
||||
"region": "${env.AWS_REGION:=us-east-1}",
|
||||
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
|
||||
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
|
||||
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
|
||||
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
|
||||
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="s3_files_metadata.db",
|
||||
),
|
||||
}
|
||||
|
|
@ -1,313 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
||||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
# TODO: provider data for S3 credentials
|
||||
|
||||
|
||||
def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
|
||||
try:
|
||||
s3_config = {
|
||||
"region_name": config.region,
|
||||
}
|
||||
|
||||
# endpoint URL if specified (for MinIO, LocalStack, etc.)
|
||||
if config.endpoint_url:
|
||||
s3_config["endpoint_url"] = config.endpoint_url
|
||||
|
||||
if config.aws_access_key_id and config.aws_secret_access_key:
|
||||
s3_config.update(
|
||||
{
|
||||
"aws_access_key_id": config.aws_access_key_id,
|
||||
"aws_secret_access_key": config.aws_secret_access_key,
|
||||
}
|
||||
)
|
||||
|
||||
return boto3.client("s3", **s3_config)
|
||||
|
||||
except (BotoCoreError, NoCredentialsError) as e:
|
||||
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
|
||||
|
||||
|
||||
async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
|
||||
try:
|
||||
client.head_bucket(Bucket=config.bucket_name)
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
if error_code == "404":
|
||||
if not config.auto_create_bucket:
|
||||
raise RuntimeError(
|
||||
f"S3 bucket '{config.bucket_name}' does not exist. "
|
||||
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
|
||||
) from e
|
||||
try:
|
||||
# For us-east-1, we can't specify LocationConstraint
|
||||
if config.region == "us-east-1":
|
||||
client.create_bucket(Bucket=config.bucket_name)
|
||||
else:
|
||||
client.create_bucket(
|
||||
Bucket=config.bucket_name,
|
||||
CreateBucketConfiguration={"LocationConstraint": config.region},
|
||||
)
|
||||
except ClientError as create_error:
|
||||
raise RuntimeError(
|
||||
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
|
||||
) from create_error
|
||||
elif error_code == "403":
|
||||
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
|
||||
else:
|
||||
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
|
||||
|
||||
|
||||
def _make_file_object(
|
||||
*,
|
||||
id: str,
|
||||
filename: str,
|
||||
purpose: str,
|
||||
bytes: int,
|
||||
created_at: int,
|
||||
expires_at: int,
|
||||
**kwargs: Any, # here to ignore any additional fields, e.g. extra fields from AuthorizedSqlStore
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Construct an OpenAIFileObject and normalize expires_at.
|
||||
|
||||
If expires_at is greater than the max we treat it as no-expiration and
|
||||
return None for expires_at.
|
||||
|
||||
The OpenAI spec says expires_at type is Integer, but the implementation
|
||||
will return None for no expiration.
|
||||
"""
|
||||
obj = OpenAIFileObject(
|
||||
id=id,
|
||||
filename=filename,
|
||||
purpose=OpenAIFilePurpose(purpose),
|
||||
bytes=bytes,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX):
|
||||
obj.expires_at = None # type: ignore
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
class S3FilesImpl(Files):
|
||||
"""S3-based implementation of the Files API."""
|
||||
|
||||
def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
self._config = config
|
||||
self.policy = policy
|
||||
self._client: boto3.client | None = None
|
||||
self._sql_store: AuthorizedSqlStore | None = None
|
||||
|
||||
def _now(self) -> int:
|
||||
"""Return current UTC timestamp as int seconds."""
|
||||
return int(datetime.now(UTC).timestamp())
|
||||
|
||||
async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]:
|
||||
where: dict[str, str | dict] = {"id": file_id}
|
||||
if not return_expired:
|
||||
where["expires_at"] = {">": self._now()}
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
return row
|
||||
|
||||
async def _delete_file(self, file_id: str) -> None:
|
||||
"""Delete a file from S3 and the database."""
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=file_id,
|
||||
)
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] != "NoSuchKey":
|
||||
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
|
||||
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
async def _delete_if_expired(self, file_id: str) -> None:
|
||||
"""If the file exists and is expired, delete it."""
|
||||
if row := await self._get_file(file_id, return_expired=True):
|
||||
if (expires_at := row.get("expires_at")) and expires_at <= self._now():
|
||||
await self._delete_file(file_id)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"filename": ColumnType.STRING,
|
||||
"purpose": ColumnType.STRING,
|
||||
"bytes": ColumnType.INTEGER,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"expires_at": ColumnType.INTEGER,
|
||||
# TODO: add s3_etag field for integrity checking
|
||||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def client(self) -> boto3.client:
|
||||
assert self._client is not None, "Provider not initialized"
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def sql_store(self) -> AuthorizedSqlStore:
|
||||
assert self._sql_store is not None, "Provider not initialized"
|
||||
return self._sql_store
|
||||
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
|
||||
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
|
||||
created_at = self._now()
|
||||
|
||||
# the default is no expiration.
|
||||
# to implement no expiration we set an expiration beyond the max.
|
||||
# we'll hide this fact from users when returning the file object.
|
||||
expires_at = created_at + ExpiresAfter.MAX * 42
|
||||
# the default for BATCH files is 30 days, which happens to be the expiration max.
|
||||
if purpose == OpenAIFilePurpose.BATCH:
|
||||
expires_at = created_at + ExpiresAfter.MAX
|
||||
|
||||
if expires_after is not None:
|
||||
expires_at = created_at + expires_after.seconds
|
||||
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
entry: dict[str, Any] = {
|
||||
"id": file_id,
|
||||
"filename": filename,
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
await self.sql_store.insert("openai_files", entry)
|
||||
|
||||
try:
|
||||
self.client.put_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=file_id,
|
||||
Body=content,
|
||||
# TODO: enable server-side encryption
|
||||
)
|
||||
except ClientError as e:
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
|
||||
|
||||
return _make_file_object(**entry)
|
||||
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 10000,
|
||||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
# this purely defensive. it should not happen because the router also default to Order.desc.
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}}
|
||||
if purpose:
|
||||
where_conditions["purpose"] = purpose.value
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
where=where_conditions,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
files = [_make_file_object(**row) for row in paginated_result.data]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
has_more=paginated_result.has_more,
|
||||
# empty string or None? spec says str, ref impl returns str | None, we go with spec
|
||||
first_id=files[0].id if files else "",
|
||||
last_id=files[-1].id if files else "",
|
||||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
await self._delete_if_expired(file_id)
|
||||
row = await self._get_file(file_id)
|
||||
return _make_file_object(**row)
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
await self._delete_if_expired(file_id)
|
||||
_ = await self._get_file(file_id) # raises if not found
|
||||
await self._delete_file(file_id)
|
||||
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
await self._delete_if_expired(file_id)
|
||||
|
||||
row = await self._get_file(file_id)
|
||||
|
||||
try:
|
||||
response = self.client.get_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=row["id"],
|
||||
)
|
||||
# TODO: can we stream this instead of loading it into memory
|
||||
content = response["Body"].read()
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
await self._delete_file(file_id)
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
|
||||
raise RuntimeError(f"Failed to download file from S3: {e}") from e
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
|
||||
)
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
impl = AnthropicInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicInferenceAdapter(OpenAIMixin):
|
||||
config: AnthropicConfig
|
||||
|
||||
provider_data_api_key_field: str = "anthropic_api_key"
|
||||
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
|
||||
# TODO: add support for voyageai, which is where these models are hosted
|
||||
# embedding_model_metadata = {
|
||||
# "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000},
|
||||
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# }
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://api.anthropic.com/v1"
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(RemoteInferenceProviderConfig):
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AzureConfig, _deps):
|
||||
from .azure import AzureInferenceAdapter
|
||||
|
||||
impl = AzureInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
class AzureInferenceAdapter(OpenAIMixin):
|
||||
config: AzureConfig
|
||||
|
||||
provider_data_api_key_field: str = "azure_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the Azure API base URL.
|
||||
|
||||
Returns the Azure API base URL from the configuration.
|
||||
"""
|
||||
return urljoin(str(self.config.api_base), "/openai/v1")
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AzureProviderDataValidator(BaseModel):
|
||||
azure_api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
azure_api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
azure_api_version: str | None = Field(
|
||||
default=None,
|
||||
description="Azure API version for Azure (e.g., 2024-06-01)",
|
||||
)
|
||||
azure_api_type: str | None = Field(
|
||||
default="azure",
|
||||
description="Azure API type for Azure (e.g., azure)",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AzureConfig(RemoteInferenceProviderConfig):
|
||||
api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AZURE_API_VERSION"),
|
||||
description="Azure API version for Azure (e.g., 2024-12-01-preview)",
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AZURE_API_TYPE", "azure"),
|
||||
description="Azure API type for Azure (e.g., azure)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
api_key: str = "${env.AZURE_API_KEY:=}",
|
||||
api_base: str = "${env.AZURE_API_BASE:=}",
|
||||
api_version: str = "${env.AZURE_API_VERSION:=}",
|
||||
api_type: str = "${env.AZURE_API_TYPE:=}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
"api_type": api_type,
|
||||
}
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from .config import BedrockConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BedrockConfig, _deps):
|
||||
from .bedrock import BedrockInferenceAdapter
|
||||
|
||||
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = BedrockInferenceAdapter(config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
return impl
|
||||
|
|
@ -1,190 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from botocore.client import BaseClient
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
Inference,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_strategy_options,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
REGION_PREFIX_MAP = {
|
||||
"us": "us.",
|
||||
"eu": "eu.",
|
||||
"ap": "ap.",
|
||||
}
|
||||
|
||||
|
||||
def _get_region_prefix(region: str | None) -> str:
|
||||
# AWS requires region prefixes for inference profiles
|
||||
if region is None:
|
||||
return "us." # default to US when we don't know
|
||||
|
||||
# Handle case insensitive region matching
|
||||
region_lower = region.lower()
|
||||
for prefix in REGION_PREFIX_MAP:
|
||||
if region_lower.startswith(f"{prefix}-"):
|
||||
return REGION_PREFIX_MAP[prefix]
|
||||
|
||||
# Fallback to US for anything we don't recognize
|
||||
return "us."
|
||||
|
||||
|
||||
def _to_inference_profile_id(model_id: str, region: str = None) -> str:
|
||||
# Return ARNs unchanged
|
||||
if model_id.startswith("arn:"):
|
||||
return model_id
|
||||
|
||||
# Return inference profile IDs that already have regional prefixes
|
||||
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
|
||||
return model_id
|
||||
|
||||
# Default to US East when no region is provided
|
||||
if region is None:
|
||||
region = "us-east-1"
|
||||
|
||||
return _get_region_prefix(region) + model_id
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
self._config = config
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
if self._client is None:
|
||||
self._client = create_bedrock_client(self._config)
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._client is not None:
|
||||
self._client.close()
|
||||
|
||||
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
|
||||
bedrock_model = request.model
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
options = get_sampling_strategy_options(sampling_params)
|
||||
|
||||
if sampling_params.max_tokens:
|
||||
options["max_gen_len"] = sampling_params.max_tokens
|
||||
if sampling_params.repetition_penalty > 0:
|
||||
options["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
|
||||
|
||||
# Convert foundation model ID to inference profile ID
|
||||
region_name = self.client.meta.region_name
|
||||
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
|
||||
|
||||
return {
|
||||
"modelId": inference_profile_id,
|
||||
"body": json.dumps(
|
||||
{
|
||||
"prompt": prompt,
|
||||
**options,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
# vLLM-specific parameters
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
|
||||
|
||||
class BedrockConfig(BedrockBaseConfig):
|
||||
pass
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta.llama3-1-70b-instruct-v1:0",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta.llama3-1-405b-instruct-v1:0",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: CerebrasImplConfig, _deps):
|
||||
from .cerebras import CerebrasInferenceAdapter
|
||||
|
||||
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = CerebrasInferenceAdapter(config=config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
return impl
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(OpenAIMixin):
|
||||
config: CerebrasImplConfig
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.cerebras.ai"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasImplConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"base_url": DEFAULT_BASE_URL,
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_HOST:=}",
|
||||
api_token: str = "${env.DATABRICKS_TOKEN:=}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
from llama_stack.apis.inference import OpenAICompletion
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::databricks")
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(OpenAIMixin):
|
||||
config: DatabricksImplConfig
|
||||
|
||||
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
|
||||
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/serving-endpoints"
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [
|
||||
endpoint.name
|
||||
for endpoint in WorkspaceClient(
|
||||
host=self.config.url, token=self.get_api_key()
|
||||
).serving_endpoints.list() # TODO: this is not async
|
||||
]
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
class FireworksProviderDataValidator(BaseModel):
|
||||
fireworks_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: FireworksImplConfig, _deps):
|
||||
from .fireworks import FireworksInferenceAdapter
|
||||
|
||||
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = FireworksInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(OpenAIMixin):
|
||||
config: FireworksImplConfig
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
|
||||
}
|
||||
|
||||
provider_data_api_key_field: str = "fireworks_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GeminiConfig, _deps):
|
||||
from .gemini import GeminiInferenceAdapter
|
||||
|
||||
impl = GeminiInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GeminiProviderDataValidator(BaseModel):
|
||||
gemini_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GeminiConfig(RemoteInferenceProviderConfig):
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
||||
|
||||
class GeminiInferenceAdapter(OpenAIMixin):
|
||||
config: GeminiConfig
|
||||
|
||||
provider_data_api_key_field: str = "gemini_api_key"
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
}
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import GroqConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqInferenceAdapter
|
||||
|
||||
adapter = GroqInferenceAdapter(config=config)
|
||||
return adapter
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Groq models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.groq.com",
|
||||
description="The URL for the Groq AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.groq.com",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class GroqInferenceAdapter(OpenAIMixin):
|
||||
config: GroqConfig
|
||||
|
||||
provider_data_api_key_field: str = "groq_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/openai/v1"
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import LlamaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .llama import LlamaCompatInferenceAdapter
|
||||
|
||||
adapter = LlamaCompatInferenceAdapter(config=config)
|
||||
return adapter
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class LlamaProviderDataValidator(BaseModel):
|
||||
llama_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for api.llama models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(RemoteInferenceProviderConfig):
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.llama.com/compat/v1/",
|
||||
description="The URL for the Llama API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin):
|
||||
config: LlamaCompatConfig
|
||||
|
||||
provider_data_api_key_field: str = "llama_api_key"
|
||||
"""
|
||||
Llama API Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: The Llama API base URL
|
||||
"""
|
||||
return self.config.openai_compat_api_base
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -1,190 +0,0 @@
|
|||
# NVIDIA Inference Provider for LlamaStack
|
||||
|
||||
This provider enables running inference using NVIDIA NIM.
|
||||
|
||||
## Features
|
||||
- Endpoints for completions, chat completions, and embeddings for registered models
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- LlamaStack with NVIDIA configuration
|
||||
- Access to NVIDIA NIM deployment
|
||||
- NIM for model to use for inference is deployed
|
||||
|
||||
### Setup
|
||||
|
||||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
||||
#### Initialize the client
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
os.environ["NVIDIA_API_KEY"] = (
|
||||
"" # Required if using hosted NIM endpoint. If self-hosted, not required.
|
||||
)
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nim.test" # NIM URL
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.initialize()
|
||||
```
|
||||
|
||||
### Create Chat Completion
|
||||
|
||||
The following example shows how to create a chat completion for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You must respond to each message with only one word",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Complete the sentence using one word: Roses are red, violets are:",
|
||||
},
|
||||
],
|
||||
stream=False,
|
||||
max_tokens=50,
|
||||
)
|
||||
print(f"Response: {response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Tool Calling Example ###
|
||||
|
||||
The following example shows how to do tool calling for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
|
||||
tool_definition = ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get current weather information for a location",
|
||||
parameters={
|
||||
"location": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The city and state, e.g. San Francisco, CA",
|
||||
required=True,
|
||||
),
|
||||
"unit": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="Temperature unit (celsius or fahrenheit)",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
tool_response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
|
||||
tools=[tool_definition],
|
||||
)
|
||||
|
||||
print(f"Tool Response: {tool_response.choices[0].message.content}")
|
||||
if tool_response.choices[0].message.tool_calls:
|
||||
for tool_call in tool_response.choices[0].message.tool_calls:
|
||||
print(f"Tool Called: {tool_call.tool_name}")
|
||||
print(f"Arguments: {tool_call.arguments}")
|
||||
```
|
||||
|
||||
### Structured Output Example
|
||||
|
||||
The following example shows how to do structured output for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
|
||||
|
||||
person_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"occupation": {"type": "string"},
|
||||
},
|
||||
"required": ["name", "age", "occupation"],
|
||||
}
|
||||
|
||||
response_format = JsonSchemaResponseFormat(
|
||||
type=ResponseFormatType.json_schema, json_schema=person_schema
|
||||
)
|
||||
|
||||
structured_response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
|
||||
}
|
||||
],
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
print(f"Structured Response: {structured_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Create Embeddings
|
||||
|
||||
The following example shows how to create embeddings for an NVIDIA NIM.
|
||||
|
||||
> [!NOTE]
|
||||
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
|
||||
|
||||
```python
|
||||
response = client.inference.embeddings(
|
||||
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
contents=["What is the capital of France?"],
|
||||
task_type="query",
|
||||
)
|
||||
print(f"Embeddings: {response.embeddings}")
|
||||
```
|
||||
|
||||
### Vision Language Models Example
|
||||
|
||||
The following example shows how to run vision inference by using an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
def load_image_as_base64(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
img_bytes = image_file.read()
|
||||
return base64.b64encode(img_bytes).decode("utf-8")
|
||||
|
||||
|
||||
image_path = {path_to_the_image}
|
||||
demo_image_b64 = load_image_as_base64(image_path)
|
||||
|
||||
vlm_response = client.chat.completions.create(
|
||||
model="nvidia/vila",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": {
|
||||
"data": demo_image_b64,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Please describe what you see in this image in detail.",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
||||
```
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import NVIDIAConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
|
||||
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||
from .nvidia import NVIDIAInferenceAdapter
|
||||
|
||||
if not isinstance(config, NVIDIAConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
adapter = NVIDIAInferenceAdapter(config=config)
|
||||
await adapter.initialize()
|
||||
return adapter
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "NVIDIAConfig"]
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIAConfig(RemoteInferenceProviderConfig):
|
||||
"""
|
||||
Configuration for the NVIDIA NIM inference endpoint.
|
||||
|
||||
Attributes:
|
||||
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
|
||||
api_key (str): The access key for the hosted NIM endpoints
|
||||
|
||||
There are two ways to access NVIDIA NIMs -
|
||||
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
||||
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
|
||||
|
||||
By default the configuration is set to use the hosted APIs. This requires
|
||||
an API key which can be obtained from https://ngc.nvidia.com/.
|
||||
|
||||
By default the configuration will attempt to read the NVIDIA_API_KEY environment
|
||||
variable to set the api_key. Please do not put your API key in code.
|
||||
|
||||
If you are using a self-hosted NVIDIA NIM, you can set the url to the
|
||||
URL of your running NVIDIA NIM and do not need to set the api_key.
|
||||
"""
|
||||
|
||||
url: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
|
||||
description="A base url for accessing the NVIDIA NIM",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
append_api_version: bool = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
|
||||
api_key: str = "${env.NVIDIA_API_KEY:=}",
|
||||
append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_key": api_key,
|
||||
"append_api_version": append_api_version,
|
||||
}
|
||||
|
|
@ -1,128 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from . import NVIDIAConfig
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin):
|
||||
config: NVIDIAConfig
|
||||
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability(). It also
|
||||
must come before Inference to ensure that OpenAIMixin methods are available
|
||||
in the Inference interface.
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
|
||||
"""
|
||||
|
||||
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
|
||||
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
|
||||
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
|
||||
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
||||
}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(self.config):
|
||||
if not self.config.auth_credential:
|
||||
raise RuntimeError(
|
||||
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
|
||||
)
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
Get the API key for OpenAI mixin.
|
||||
|
||||
:return: The NVIDIA API key
|
||||
"""
|
||||
if self.config.auth_credential:
|
||||
return self.config.auth_credential.get_secret_value()
|
||||
|
||||
if not _is_nvidia_hosted(self.config):
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
return None
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
OpenAI-compatible embeddings for NVIDIA NIM.
|
||||
|
||||
Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
|
||||
We default this to "query" to ensure requests succeed when using the
|
||||
OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
|
||||
`task_type='document'`.
|
||||
"""
|
||||
extra_body: dict[str, object] = {"input_type": "query"}
|
||||
logger.warning(
|
||||
"NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
|
||||
"For passage embeddings, use the embeddings API with task_type='document'."
|
||||
)
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=await self._get_provider_model_id(model),
|
||||
input=input,
|
||||
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
|
||||
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
|
||||
user=user if user is not None else NOT_GIVEN,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from . import NVIDIAConfig
|
||||
|
||||
|
||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||
return "integrate.api.nvidia.com" in config.url
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import OllamaImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OllamaImplConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
impl = OllamaInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
}
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
from ollama import AsyncClient as AsyncOllamaClient
|
||||
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::ollama")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(OpenAIMixin):
|
||||
config: OllamaImplConfig
|
||||
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"all-minilm:l6-v2": {
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
"nomic-embed-text:latest": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:v1.5": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:137m-v1.5-fp16": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
}
|
||||
|
||||
download_images: bool = True
|
||||
_clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
|
||||
@property
|
||||
def ollama_client(self) -> AsyncOllamaClient:
|
||||
# ollama client attaches itself to the current event loop (sadly?)
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop not in self._clients:
|
||||
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
|
||||
return self._clients[loop]
|
||||
|
||||
def get_api_key(self):
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.config.url.rstrip("/") + "/v1"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
r = await self.health()
|
||||
if r["status"] == HealthStatus.ERROR:
|
||||
logger.warning(
|
||||
f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the Ollama server.
|
||||
This method is used by initialize() and the Provider API to verify that the service is running
|
||||
correctly.
|
||||
Returns:
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
await self.ollama_client.ps()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self._clients.clear()
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if await self.check_model_availability(model.provider_model_id):
|
||||
return model
|
||||
elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
|
||||
model.provider_resource_id = f"{model.provider_model_id}:latest"
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
|
||||
)
|
||||
return model
|
||||
|
||||
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenAIConfig, _deps):
|
||||
from .openai import OpenAIInferenceAdapter
|
||||
|
||||
impl = OpenAIInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class OpenAIProviderDataValidator(BaseModel):
|
||||
openai_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
description="Base URL for OpenAI API",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
api_key: str = "${env.OPENAI_API_KEY:=}",
|
||||
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::openai")
|
||||
|
||||
|
||||
#
|
||||
# This OpenAI adapter implements Inference methods using OpenAIMixin
|
||||
#
|
||||
class OpenAIInferenceAdapter(OpenAIMixin):
|
||||
"""
|
||||
OpenAI Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
config: OpenAIConfig
|
||||
|
||||
provider_data_api_key_field: str = "openai_api_key"
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
|
||||
}
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the OpenAI API base URL.
|
||||
|
||||
Returns the OpenAI API base URL from the configuration.
|
||||
"""
|
||||
return self.config.base_url
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import PassthroughImplConfig
|
||||
|
||||
|
||||
class PassthroughProviderDataValidator(BaseModel):
|
||||
url: str
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: PassthroughImplConfig, _deps):
|
||||
from .passthrough import PassthroughInferenceAdapter
|
||||
|
||||
assert isinstance(config, PassthroughImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = PassthroughInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PassthroughImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the passthrough endpoint",
|
||||
)
|
||||
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="API Key for the passthrouth endpoint",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,205 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
from .config import PassthroughImplConfig
|
||||
|
||||
|
||||
class PassthroughInferenceAdapter(Inference):
|
||||
def __init__(self, config: PassthroughImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
def _get_client(self) -> AsyncLlamaStackClient:
|
||||
passthrough_url = None
|
||||
passthrough_api_key = None
|
||||
provider_data = None
|
||||
|
||||
if self.config.url is not None:
|
||||
passthrough_url = self.config.url
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.passthrough_url:
|
||||
raise ValueError(
|
||||
'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
|
||||
)
|
||||
passthrough_url = provider_data.passthrough_url
|
||||
|
||||
if self.config.api_key is not None:
|
||||
passthrough_api_key = self.config.api_key.get_secret_value()
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.passthrough_api_key:
|
||||
raise ValueError(
|
||||
'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
|
||||
)
|
||||
passthrough_api_key = provider_data.passthrough_api_key
|
||||
|
||||
return AsyncLlamaStackClient(
|
||||
base_url=passthrough_url,
|
||||
api_key=passthrough_api_key,
|
||||
provider_data=provider_data,
|
||||
)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
|
||||
return await client.inference.openai_completion(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return await client.inference.openai_chat_completion(**params)
|
||||
|
||||
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
|
||||
json_params = {}
|
||||
for key, value in request_params.items():
|
||||
json_input = convert_pydantic_to_json_value(value)
|
||||
if isinstance(json_input, dict):
|
||||
json_input = {k: v for k, v in json_input.items() if v is not None}
|
||||
elif isinstance(json_input, list):
|
||||
json_input = [x for x in json_input if x is not None]
|
||||
new_input = []
|
||||
for x in json_input:
|
||||
if isinstance(x, dict):
|
||||
x = {k: v for k, v in x.items() if v is not None}
|
||||
new_input.append(x)
|
||||
json_input = new_input
|
||||
|
||||
json_params[key] = json_input
|
||||
|
||||
return json_params
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import RunpodImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RunpodImplConfig, _deps):
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = RunpodInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RunpodImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the Runpod model serving endpoint",
|
||||
)
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.RUNPOD_URL:=}",
|
||||
"api_token": "${env.RUNPOD_API_TOKEN}",
|
||||
}
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import RunpodImplConfig
|
||||
|
||||
|
||||
class RunpodInferenceAdapter(OpenAIMixin):
|
||||
"""
|
||||
Adapter for RunPod's OpenAI-compatible API endpoints.
|
||||
Supports VLLM for serverless endpoint self-hosted or public endpoints.
|
||||
Can work with any runpod endpoints that support OpenAI-compatible API
|
||||
"""
|
||||
|
||||
config: RunpodImplConfig
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get base URL for OpenAI client."""
|
||||
return self.config.url
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
"""Override to add RunPod-specific stream_options requirement."""
|
||||
if stream and not stream_options:
|
||||
stream_options = {"include_usage": True}
|
||||
|
||||
return await super().openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
|
||||
from .sambanova import SambaNovaInferenceAdapter
|
||||
|
||||
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SambaNovaInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class SambaNovaProviderDataValidator(BaseModel):
|
||||
sambanova_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="Sambanova Cloud API key",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
|
||||
|
||||
class SambaNovaInferenceAdapter(OpenAIMixin):
|
||||
config: SambaNovaImplConfig
|
||||
|
||||
provider_data_api_key_field: str = "sambanova_api_key"
|
||||
download_images: bool = True # SambaNova does not support image downloads server-size, perform them on the client
|
||||
"""
|
||||
SambaNova Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: The SambaNova base URL
|
||||
"""
|
||||
return self.config.url
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: InferenceAPIImplConfig | InferenceEndpointImplConfig | TGIImplConfig,
|
||||
_deps,
|
||||
):
|
||||
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
|
||||
|
||||
if isinstance(config, TGIImplConfig):
|
||||
impl = TGIAdapter()
|
||||
elif isinstance(config, InferenceAPIImplConfig):
|
||||
impl = InferenceAPIAdapter()
|
||||
elif isinstance(config, InferenceEndpointImplConfig):
|
||||
impl = InferenceEndpointAdapter()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
|
||||
)
|
||||
|
||||
await impl.initialize(config)
|
||||
return impl
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.TGI_URL:=}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"url": url,
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceEndpointImplConfig(BaseModel):
|
||||
endpoint_name: str = Field(
|
||||
description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
|
||||
)
|
||||
api_token: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
endpoint_name: str = "${env.INFERENCE_ENDPOINT_NAME}",
|
||||
api_token: str = "${env.HF_API_TOKEN}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"endpoint_name": endpoint_name,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceAPIImplConfig(BaseModel):
|
||||
huggingface_repo: str = Field(
|
||||
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||
)
|
||||
api_token: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
repo: str = "${env.INFERENCE_MODEL}",
|
||||
api_token: str = "${env.HF_API_TOKEN}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"huggingface_repo": repo,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
|
@ -1,86 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
log = get_logger(name=__name__, category="inference::tgi")
|
||||
|
||||
|
||||
class _HfAdapter(OpenAIMixin):
|
||||
url: str
|
||||
api_key: SecretStr
|
||||
|
||||
hf_client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
overwrite_completion_id = True # TGI always returns id=""
|
||||
|
||||
def get_api_key(self):
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [self.model_id]
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
self.url = f"{config.url.rstrip('/')}/v1"
|
||||
self.api_key = SecretStr("NO_KEY")
|
||||
|
||||
|
||||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
# TODO: how do we set url for this?
|
||||
|
||||
|
||||
class InferenceEndpointAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
|
||||
# Get the inference endpoint details
|
||||
api = HfApi(token=config.api_token.get_secret_value())
|
||||
endpoint = api.get_inference_endpoint(config.endpoint_name)
|
||||
# Wait for the endpoint to be ready (if not already)
|
||||
endpoint.wait(timeout=60)
|
||||
|
||||
# Initialize the adapter
|
||||
self.hf_client = endpoint.async_client
|
||||
self.model_id = endpoint.repository
|
||||
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
|
||||
# TODO: how do we set url for this?
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
class TogetherProviderDataValidator(BaseModel):
|
||||
together_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherImplConfig, _deps):
|
||||
from .together import TogetherInferenceAdapter
|
||||
|
||||
assert isinstance(config, TogetherImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = TogetherInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY:=}",
|
||||
}
|
||||
|
|
@ -1,103 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from together import AsyncTogether
|
||||
from together.constants import BASE_URL
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::together")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
|
||||
config: TogetherImplConfig
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
|
||||
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
|
||||
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
|
||||
"Alibaba-NLP/gte-modernbert-base": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
provider_data_api_key_field: str = "together_api_key"
|
||||
|
||||
def get_base_url(self):
|
||||
return BASE_URL
|
||||
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
return AsyncTogether(api_key=together_api_key)
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
|
||||
return [m.id for m in await self._get_client().models.list()]
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
Together's OpenAI-compatible embeddings endpoint is not compatible with
|
||||
the standard OpenAI embeddings endpoint.
|
||||
|
||||
The endpoint -
|
||||
- not all models return usage information
|
||||
- does not support user param, returns 400 Unrecognized request arguments supplied: user
|
||||
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
|
||||
"""
|
||||
# Together support ticket #13332 -> will not fix
|
||||
if user is not None:
|
||||
raise ValueError("Together's embeddings endpoint does not support user param.")
|
||||
# Together support ticket #13333 -> escalated
|
||||
if dimensions is not None:
|
||||
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=await self._get_provider_model_id(model),
|
||||
input=input,
|
||||
encoding_format=encoding_format,
|
||||
)
|
||||
|
||||
response.model = model # return the user the same model id they provided, avoid exposing the provider model id
|
||||
|
||||
# Together support ticket #13330 -> escalated
|
||||
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
|
||||
if not hasattr(response, "usage") or response.usage is None:
|
||||
logger.warning(
|
||||
f"Together's embedding endpoint for {model} did not return usage information, substituting -1s."
|
||||
)
|
||||
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
|
||||
|
||||
return response # type: ignore[no-any-return]
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import VertexAIConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VertexAIConfig, _deps):
|
||||
from .vertexai import VertexAIInferenceAdapter
|
||||
|
||||
impl = VertexAIInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class VertexAIProviderDataValidator(BaseModel):
|
||||
vertex_project: str | None = Field(
|
||||
default=None,
|
||||
description="Google Cloud project ID for Vertex AI",
|
||||
)
|
||||
vertex_location: str | None = Field(
|
||||
default=None,
|
||||
description="Google Cloud location for Vertex AI (e.g., us-central1)",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VertexAIConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
project: str = Field(
|
||||
description="Google Cloud project ID for Vertex AI",
|
||||
)
|
||||
location: str = Field(
|
||||
default="us-central1",
|
||||
description="Google Cloud location for Vertex AI",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
project: str = "${env.VERTEX_AI_PROJECT:=}",
|
||||
location: str = "${env.VERTEX_AI_LOCATION:=us-central1}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"project": project,
|
||||
"location": location,
|
||||
}
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import google.auth.transport.requests
|
||||
from google.auth import default
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import VertexAIConfig
|
||||
|
||||
|
||||
class VertexAIInferenceAdapter(OpenAIMixin):
|
||||
config: VertexAIConfig
|
||||
|
||||
provider_data_api_key_field: str = "vertex_project"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
Get an access token for Vertex AI using Application Default Credentials.
|
||||
|
||||
Vertex AI uses ADC instead of API keys. This method obtains an access token
|
||||
from the default credentials and returns it for use with the OpenAI-compatible client.
|
||||
"""
|
||||
try:
|
||||
# Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
|
||||
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||
credentials.refresh(google.auth.transport.requests.Request())
|
||||
return str(credentials.token)
|
||||
except Exception:
|
||||
# If we can't get credentials, return empty string to let the env work with ADC directly
|
||||
return ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the Vertex AI OpenAI-compatible API base URL.
|
||||
|
||||
Returns the Vertex AI OpenAI-compatible endpoint URL.
|
||||
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
|
||||
"""
|
||||
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
|
||||
class VLLMProviderDataValidator(BaseModel):
|
||||
vllm_api_token: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
|
||||
from .vllm import VLLMInferenceAdapter
|
||||
|
||||
assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = VLLMInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, SecretStr, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the vLLM model serving endpoint",
|
||||
)
|
||||
max_tokens: int = Field(
|
||||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The API token",
|
||||
)
|
||||
tls_verify: bool | str = Field(
|
||||
default=True,
|
||||
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
|
||||
)
|
||||
|
||||
@field_validator("tls_verify")
|
||||
@classmethod
|
||||
def validate_tls_verify(cls, v):
|
||||
if isinstance(v, str):
|
||||
# Otherwise, treat it as a cert path
|
||||
cert_path = Path(v).expanduser().resolve()
|
||||
if not cert_path.exists():
|
||||
raise ValueError(f"TLS certificate file does not exist: {v}")
|
||||
if not cert_path.is_file():
|
||||
raise ValueError(f"TLS certificate path is not a file: {v}")
|
||||
return v
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.VLLM_URL:=}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"url": url,
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
|
||||
"api_token": "${env.VLLM_API_TOKEN:=fake}",
|
||||
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
|
||||
}
|
||||
|
|
@ -1,155 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
log = get_logger(name=__name__, category="inference::vllm")
|
||||
|
||||
|
||||
class VLLMInferenceAdapter(OpenAIMixin):
|
||||
config: VLLMInferenceAdapterConfig
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
provider_data_api_key_field: str = "vllm_api_token"
|
||||
|
||||
def get_api_key(self) -> str | None:
|
||||
if self.config.auth_credential:
|
||||
return self.config.auth_credential.get_secret_value()
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get the base URL from config."""
|
||||
if not self.config.url:
|
||||
raise ValueError("No base URL configured")
|
||||
return self.config.url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.url:
|
||||
raise ValueError(
|
||||
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
|
||||
)
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the remote vLLM server.
|
||||
This method is used by the Provider API to verify
|
||||
that the service is running correctly.
|
||||
Uses the unauthenticated /health endpoint.
|
||||
Returns:
|
||||
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
base_url = self.get_base_url()
|
||||
health_url = urljoin(base_url, "health")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(health_url)
|
||||
response.raise_for_status()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
def get_extra_client_params(self):
|
||||
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Skip the check when running without authentication.
|
||||
"""
|
||||
if not self.config.auth_credential:
|
||||
model_ids = []
|
||||
async for m in self.client.models.list():
|
||||
if m.id == model: # Found exact match
|
||||
return True
|
||||
model_ids.append(m.id)
|
||||
raise ValueError(f"Model '{model}' not found. Available models: {model_ids}")
|
||||
log.warning(f"Not checking model availability for {model} as API token may trigger OAuth workflow")
|
||||
return True
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
max_tokens = max_tokens or self.config.max_tokens
|
||||
|
||||
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
|
||||
# References:
|
||||
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# * https://github.com/vllm-project/vllm/pull/10000
|
||||
if not tools and tool_choice is not None:
|
||||
tool_choice = ToolChoice.none.value
|
||||
|
||||
return await super().openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import WatsonXConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WatsonXConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .watsonx import WatsonXInferenceAdapter
|
||||
|
||||
adapter = WatsonXInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class WatsonXProviderDataValidator(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="forbid",
|
||||
)
|
||||
watsonx_api_key: str | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WatsonXConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="The watsonx.ai project ID",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
|
||||
"api_key": "${env.WATSONX_API_KEY:=}",
|
||||
"project_id": "${env.WATSONX_PROJECT_ID:=}",
|
||||
}
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
|
||||
class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
def __init__(self, config: WatsonXConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="watsonx",
|
||||
api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
|
||||
provider_data_api_key_field="watsonx_api_key",
|
||||
)
|
||||
self.available_models = None
|
||||
self.config = config
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return self.config.url
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
params = await super()._get_params(request)
|
||||
|
||||
# Add watsonx.ai specific parameters
|
||||
params["project_id"] = self.config.project_id
|
||||
params["time_limit"] = self.config.timeout
|
||||
return params
|
||||
|
||||
# Copied from OpenAIMixin
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from the provider's /v1/models.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
return model in self._model_cache
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {}
|
||||
models = []
|
||||
for model_spec in self._get_model_specs():
|
||||
functions = [f["id"] for f in model_spec.get("functions", [])]
|
||||
# Format: {"embedding_dimension": 1536, "context_length": 8192}
|
||||
|
||||
# Example of an embedding model:
|
||||
# {'model_id': 'ibm/granite-embedding-278m-multilingual',
|
||||
# 'label': 'granite-embedding-278m-multilingual',
|
||||
# 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
|
||||
# ...
|
||||
provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
|
||||
if "embedding" in functions:
|
||||
embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
|
||||
context_length = model_spec["model_limits"]["max_sequence_length"]
|
||||
embedding_metadata = {
|
||||
"embedding_dimension": embedding_dimension,
|
||||
"context_length": context_length,
|
||||
}
|
||||
model = Model(
|
||||
identifier=model_spec["model_id"],
|
||||
provider_resource_id=provider_resource_id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata=embedding_metadata,
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
self._model_cache[provider_resource_id] = model
|
||||
models.append(model)
|
||||
if "text_chat" in functions:
|
||||
model = Model(
|
||||
identifier=model_spec["model_id"],
|
||||
provider_resource_id=provider_resource_id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
# In theory, I guess it is possible that a model could be both an embedding model and a text chat model.
|
||||
# In that case, the cache will record the generator Model object, and the list which we return will have
|
||||
# both the generator Model object and the text chat Model object. That's fine because the cache is
|
||||
# only used for check_model_availability() anyway.
|
||||
self._model_cache[provider_resource_id] = model
|
||||
models.append(model)
|
||||
return models
|
||||
|
||||
# LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
|
||||
# So we need to implement our own method to list models by calling the watsonx.ai API.
|
||||
def _get_model_specs(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieves foundation model specifications from the watsonx.ai API.
|
||||
"""
|
||||
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
|
||||
headers = {
|
||||
# Note that there is no authorization header. Listing models does not require authentication.
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
# --- Process the Response ---
|
||||
# Raise an exception for bad status codes (4xx or 5xx)
|
||||
response.raise_for_status()
|
||||
|
||||
# If the request is successful, parse and return the JSON response.
|
||||
# The response should contain a list of model specifications
|
||||
response_data = response.json()
|
||||
if "resources" not in response_data:
|
||||
raise ValueError("Resources not found in response")
|
||||
return response_data["resources"]
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,150 +0,0 @@
|
|||
# NVIDIA Post-Training Provider for LlamaStack
|
||||
|
||||
This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service.
|
||||
|
||||
## Features
|
||||
|
||||
- Supervised fine-tuning of Llama models
|
||||
- LoRA fine-tuning support
|
||||
- Job management and status tracking
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- LlamaStack with NVIDIA configuration
|
||||
- Access to Hosted NVIDIA NeMo Customizer service
|
||||
- Dataset registered in the Hosted NVIDIA NeMo Customizer service
|
||||
- Base model downloaded and available in the Hosted NVIDIA NeMo Customizer service
|
||||
|
||||
### Setup
|
||||
|
||||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
||||
### Create Customization Job
|
||||
|
||||
#### Initialize the client
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
|
||||
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
|
||||
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.initialize()
|
||||
```
|
||||
|
||||
#### Configure fine-tuning parameters
|
||||
|
||||
```python
|
||||
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
||||
TrainingConfig,
|
||||
TrainingConfigDataConfig,
|
||||
TrainingConfigOptimizerConfig,
|
||||
)
|
||||
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
|
||||
```
|
||||
|
||||
#### Set up LoRA configuration
|
||||
|
||||
```python
|
||||
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
|
||||
```
|
||||
|
||||
#### Configure training data
|
||||
|
||||
```python
|
||||
data_config = TrainingConfigDataConfig(
|
||||
dataset_id="your-dataset-id", # Use client.datasets.list() to see available datasets
|
||||
batch_size=16,
|
||||
)
|
||||
```
|
||||
|
||||
#### Configure optimizer
|
||||
|
||||
```python
|
||||
optimizer_config = TrainingConfigOptimizerConfig(
|
||||
lr=0.0001,
|
||||
)
|
||||
```
|
||||
|
||||
#### Set up training configuration
|
||||
|
||||
```python
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
```
|
||||
|
||||
#### Start fine-tuning job
|
||||
|
||||
```python
|
||||
training_job = client.post_training.supervised_fine_tune(
|
||||
job_uuid="unique-job-id",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=training_config,
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
```
|
||||
|
||||
### List all jobs
|
||||
|
||||
```python
|
||||
jobs = client.post_training.job.list()
|
||||
```
|
||||
|
||||
### Check job status
|
||||
|
||||
```python
|
||||
job_status = client.post_training.job.status(job_uuid="your-job-id")
|
||||
```
|
||||
|
||||
### Cancel a job
|
||||
|
||||
```python
|
||||
client.post_training.job.cancel(job_uuid="your-job-id")
|
||||
```
|
||||
|
||||
### Inference with the fine-tuned model
|
||||
|
||||
#### 1. Register the model
|
||||
|
||||
```python
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
|
||||
client.models.register(
|
||||
model_id="test-example-model@v1",
|
||||
provider_id="nvidia",
|
||||
provider_model_id="test-example-model@v1",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
```
|
||||
|
||||
#### 2. Inference with the fine-tuned model
|
||||
|
||||
```python
|
||||
response = client.completions.create(
|
||||
prompt="Complete the sentence using one word: Roses are red, violets are ",
|
||||
stream=False,
|
||||
model="test-example-model@v1",
|
||||
max_tokens=50,
|
||||
)
|
||||
print(response.choices[0].text)
|
||||
```
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: NvidiaPostTrainingConfig,
|
||||
_deps,
|
||||
):
|
||||
from .post_training import NvidiaPostTrainingAdapter
|
||||
|
||||
if not isinstance(config, NvidiaPostTrainingConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
|
||||
impl = NvidiaPostTrainingAdapter(config)
|
||||
return impl
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]
|
||||
|
|
@ -1,113 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# TODO: add default values for all fields
|
||||
|
||||
|
||||
class NvidiaPostTrainingConfig(BaseModel):
|
||||
"""Configuration for NVIDIA Post Training implementation."""
|
||||
|
||||
api_key: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
|
||||
description="The NVIDIA API key.",
|
||||
)
|
||||
|
||||
dataset_namespace: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
|
||||
description="The NVIDIA dataset namespace.",
|
||||
)
|
||||
|
||||
project_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-example-model@v1"),
|
||||
description="The NVIDIA project ID.",
|
||||
)
|
||||
|
||||
# ToDO: validate this, add default value
|
||||
customizer_url: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_CUSTOMIZER_URL"),
|
||||
description="Base URL for the NeMo Customizer API",
|
||||
)
|
||||
|
||||
timeout: int = Field(
|
||||
default=300,
|
||||
description="Timeout for the NVIDIA Post Training API",
|
||||
)
|
||||
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Maximum number of retries for the NVIDIA Post Training API",
|
||||
)
|
||||
|
||||
# ToDo: validate this
|
||||
output_model_dir: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
|
||||
description="Directory to save the output model",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.NVIDIA_API_KEY:=}",
|
||||
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
|
||||
"project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
|
||||
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}",
|
||||
}
|
||||
|
||||
|
||||
class SFTLoRADefaultConfig(BaseModel):
|
||||
"""NVIDIA-specific training configuration with default values."""
|
||||
|
||||
# ToDo: split into SFT and LoRA configs??
|
||||
|
||||
# General training parameters
|
||||
n_epochs: int = 50
|
||||
|
||||
# NeMo customizer specific parameters
|
||||
log_every_n_steps: int | None = None
|
||||
val_check_interval: float = 0.25
|
||||
sequence_packing_enabled: bool = False
|
||||
weight_decay: float = 0.01
|
||||
lr: float = 0.0001
|
||||
|
||||
# SFT specific parameters
|
||||
hidden_dropout: float | None = None
|
||||
attention_dropout: float | None = None
|
||||
ffn_dropout: float | None = None
|
||||
|
||||
# LoRA default parameters
|
||||
lora_adapter_dim: int = 8
|
||||
lora_adapter_dropout: float | None = None
|
||||
lora_alpha: int = 16
|
||||
|
||||
# Data config
|
||||
batch_size: int = 8
|
||||
|
||||
@classmethod
|
||||
def sample_config(cls) -> dict[str, Any]:
|
||||
"""Return a sample configuration for NVIDIA training."""
|
||||
return {
|
||||
"n_epochs": 50,
|
||||
"log_every_n_steps": 10,
|
||||
"val_check_interval": 0.25,
|
||||
"sequence_packing_enabled": False,
|
||||
"weight_decay": 0.01,
|
||||
"hidden_dropout": 0.1,
|
||||
"attention_dropout": 0.1,
|
||||
"lora_adapter_dim": 8,
|
||||
"lora_alpha": 16,
|
||||
"data_config": {
|
||||
"dataset_id": "default",
|
||||
"batch_size": 8,
|
||||
},
|
||||
"optimizer_config": {
|
||||
"lr": 0.0001,
|
||||
},
|
||||
}
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
_MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_model_entries() -> list[ProviderModelEntry]:
|
||||
return _MODEL_ENTRIES
|
||||
|
|
@ -1,430 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
|
||||
from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .models import _MODEL_ENTRIES
|
||||
|
||||
# Map API status to JobStatus enum
|
||||
STATUS_MAPPING = {
|
||||
"running": JobStatus.in_progress.value,
|
||||
"completed": JobStatus.completed.value,
|
||||
"failed": JobStatus.failed.value,
|
||||
"cancelled": JobStatus.cancelled.value,
|
||||
"pending": JobStatus.scheduled.value,
|
||||
"unknown": JobStatus.scheduled.value,
|
||||
}
|
||||
|
||||
|
||||
class NvidiaPostTrainingJob(PostTrainingJob):
|
||||
"""Parse the response from the Customizer API.
|
||||
Inherits job_uuid from PostTrainingJob.
|
||||
Adds status, created_at, updated_at parameters.
|
||||
Passes through all other parameters from data field in the response.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
status: JobStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ListNvidiaPostTrainingJobs(BaseModel):
|
||||
data: list[NvidiaPostTrainingJob]
|
||||
|
||||
|
||||
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||
def __init__(self, config: NvidiaPostTrainingConfig):
|
||||
self.config = config
|
||||
self.headers = {}
|
||||
if config.api_key:
|
||||
self.headers["Authorization"] = f"Bearer {config.api_key}"
|
||||
|
||||
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
||||
# TODO: filter by available models based on /config endpoint
|
||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||
self.session = None
|
||||
|
||||
self.customizer_url = config.customizer_url
|
||||
if not self.customizer_url:
|
||||
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
||||
self.customizer_url = "http://nemo.test"
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
||||
return self.session
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
json: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""Helper method to make HTTP requests to the Customizer API."""
|
||||
url = f"{self.customizer_url}{path}"
|
||||
request_headers = self.headers.copy()
|
||||
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
# Add content-type header for JSON requests
|
||||
if json and "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
session = await self._get_session()
|
||||
for _ in range(self.config.max_retries):
|
||||
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
if response.status >= 400:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"API request failed: {error_data}")
|
||||
return await response.json()
|
||||
|
||||
async def get_training_jobs(
|
||||
self,
|
||||
page: int | None = 1,
|
||||
page_size: int | None = 10,
|
||||
sort: Literal["created_at", "-created_at"] | None = "created_at",
|
||||
) -> ListNvidiaPostTrainingJobs:
|
||||
"""Get all customization jobs.
|
||||
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
|
||||
|
||||
Returns a ListNvidiaPostTrainingJobs object with the following fields:
|
||||
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
|
||||
|
||||
ToDo: Support for schema input for filtering.
|
||||
"""
|
||||
params = {"page": page, "page_size": page_size, "sort": sort}
|
||||
|
||||
response = await self._make_request("GET", "/v1/customization/jobs", params=params)
|
||||
|
||||
jobs = []
|
||||
for job in response.get("data", []):
|
||||
job_id = job.pop("id")
|
||||
job_status = job.pop("status", "scheduled").lower()
|
||||
mapped_status = STATUS_MAPPING.get(job_status, "scheduled")
|
||||
|
||||
# Convert string timestamps to datetime objects
|
||||
created_at = (
|
||||
datetime.fromisoformat(job.pop("created_at"))
|
||||
if "created_at" in job
|
||||
else datetime.now(tz=datetime.timezone.utc)
|
||||
)
|
||||
updated_at = (
|
||||
datetime.fromisoformat(job.pop("updated_at"))
|
||||
if "updated_at" in job
|
||||
else datetime.now(tz=datetime.timezone.utc)
|
||||
)
|
||||
|
||||
# Create NvidiaPostTrainingJob instance
|
||||
jobs.append(
|
||||
NvidiaPostTrainingJob(
|
||||
job_uuid=job_id,
|
||||
status=JobStatus(mapped_status),
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
**job,
|
||||
)
|
||||
)
|
||||
|
||||
return ListNvidiaPostTrainingJobs(data=jobs)
|
||||
|
||||
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
|
||||
"""Get the status of a customization job.
|
||||
Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
|
||||
|
||||
Returns a NvidiaPostTrainingJob object with the following fields:
|
||||
- job_uuid: str - Unique identifier for the job
|
||||
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
|
||||
- created_at: datetime - The time when the job was created
|
||||
- updated_at: datetime - The last time the job status was updated
|
||||
|
||||
Additional fields that may be included:
|
||||
- steps_completed: Optional[int] - Number of training steps completed
|
||||
- epochs_completed: Optional[int] - Number of epochs completed
|
||||
- percentage_done: Optional[float] - Percentage of training completed (0-100)
|
||||
- best_epoch: Optional[int] - The epoch with the best performance
|
||||
- train_loss: Optional[float] - Training loss of the best checkpoint
|
||||
- val_loss: Optional[float] - Validation loss of the best checkpoint
|
||||
- metrics: Optional[Dict] - Additional training metrics
|
||||
- status_logs: Optional[List] - Detailed logs of status changes
|
||||
"""
|
||||
response = await self._make_request(
|
||||
"GET",
|
||||
f"/v1/customization/jobs/{job_uuid}/status",
|
||||
params={"job_id": job_uuid},
|
||||
)
|
||||
|
||||
api_status = response.pop("status").lower()
|
||||
mapped_status = STATUS_MAPPING.get(api_status, "scheduled")
|
||||
|
||||
return NvidiaPostTrainingJobStatusResponse(
|
||||
status=JobStatus(mapped_status),
|
||||
job_uuid=job_uuid,
|
||||
started_at=datetime.fromisoformat(response.pop("created_at")),
|
||||
updated_at=datetime.fromisoformat(response.pop("updated_at")),
|
||||
**response,
|
||||
)
|
||||
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
await self._make_request(
|
||||
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
|
||||
)
|
||||
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
raise NotImplementedError("Job artifacts are not implemented yet")
|
||||
|
||||
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
raise NotImplementedError("Job artifacts are not implemented yet")
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
training_config: dict[str, Any],
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: str | None,
|
||||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> NvidiaPostTrainingJob:
|
||||
"""
|
||||
Fine-tunes a model on a dataset.
|
||||
Currently only supports Lora finetuning for standlone docker container.
|
||||
Assumptions:
|
||||
- nemo microservice is running and endpoint is set in config.customizer_url
|
||||
- dataset is registered separately in nemo datastore
|
||||
- model checkpoint is downloaded as per nemo customizer requirements
|
||||
|
||||
Parameters:
|
||||
training_config: TrainingConfig - Configuration for training
|
||||
model: str - NeMo Customizer configuration name
|
||||
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
|
||||
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
|
||||
job_uuid: str - Unique identifier for the job, ignored atm
|
||||
hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search, ignored atm
|
||||
logger_config: Dict[str, Any] - Configuration for logging, ignored atm
|
||||
|
||||
Environment Variables:
|
||||
- NVIDIA_API_KEY: str - API key for the NVIDIA API
|
||||
Default: None
|
||||
- NVIDIA_DATASET_NAMESPACE: str - Namespace of the dataset
|
||||
Default: "default"
|
||||
- NVIDIA_CUSTOMIZER_URL: str - URL of the NeMo Customizer API
|
||||
Default: "http://nemo.test"
|
||||
- NVIDIA_PROJECT_ID: str - ID of the project
|
||||
Default: "test-project"
|
||||
- NVIDIA_OUTPUT_MODEL_DIR: str - Directory to save the output model
|
||||
Default: "test-example-model@v1"
|
||||
|
||||
Supported models:
|
||||
- meta/llama-3.1-8b-instruct
|
||||
- meta/llama-3.2-1b-instruct
|
||||
|
||||
Supported algorithm configs:
|
||||
- LoRA, SFT
|
||||
|
||||
Supported Parameters:
|
||||
- TrainingConfig:
|
||||
- n_epochs: int - Number of epochs to train
|
||||
Default: 50
|
||||
- data_config: DataConfig - Configuration for the dataset
|
||||
- optimizer_config: OptimizerConfig - Configuration for the optimizer
|
||||
- dtype: str - Data type for training
|
||||
not supported (users are informed via warnings)
|
||||
- efficiency_config: EfficiencyConfig - Configuration for efficiency
|
||||
not supported
|
||||
- max_steps_per_epoch: int - Maximum number of steps per epoch
|
||||
Default: 1000
|
||||
## NeMo customizer specific parameters
|
||||
- log_every_n_steps: int - Log every n steps
|
||||
Default: None
|
||||
- val_check_interval: float - Validation check interval
|
||||
Default: 0.25
|
||||
- sequence_packing_enabled: bool - Sequence packing enabled
|
||||
Default: False
|
||||
## NeMo customizer specific SFT parameters
|
||||
- hidden_dropout: float - Hidden dropout
|
||||
Default: None (0.0-1.0)
|
||||
- attention_dropout: float - Attention dropout
|
||||
Default: None (0.0-1.0)
|
||||
- ffn_dropout: float - FFN dropout
|
||||
Default: None (0.0-1.0)
|
||||
|
||||
- DataConfig:
|
||||
- dataset_id: str - Dataset ID
|
||||
- batch_size: int - Batch size
|
||||
Default: 8
|
||||
|
||||
- OptimizerConfig:
|
||||
- lr: float - Learning rate
|
||||
Default: 0.0001
|
||||
## NeMo customizer specific parameter
|
||||
- weight_decay: float - Weight decay
|
||||
Default: 0.01
|
||||
|
||||
- LoRA config:
|
||||
## NeMo customizer specific LoRA parameters
|
||||
- alpha: int - Scaling factor for the LoRA update
|
||||
Default: 16
|
||||
Note:
|
||||
- checkpoint_dir, hyperparam_search_config, logger_config are not supported (users are informed via warnings)
|
||||
- Some parameters from TrainingConfig, DataConfig, OptimizerConfig are not supported (users are informed via warnings)
|
||||
|
||||
User is informed about unsupported parameters via warnings.
|
||||
"""
|
||||
|
||||
# Check for unsupported method parameters
|
||||
unsupported_method_params = []
|
||||
if checkpoint_dir:
|
||||
unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
|
||||
if hyperparam_search_config:
|
||||
unsupported_method_params.append("hyperparam_search_config")
|
||||
if logger_config:
|
||||
unsupported_method_params.append("logger_config")
|
||||
|
||||
if unsupported_method_params:
|
||||
warnings.warn(
|
||||
f"Parameters: {', '.join(unsupported_method_params)} are not supported and will be ignored",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Define all supported parameters
|
||||
supported_params = {
|
||||
"training_config": {
|
||||
"n_epochs",
|
||||
"data_config",
|
||||
"optimizer_config",
|
||||
"log_every_n_steps",
|
||||
"val_check_interval",
|
||||
"sequence_packing_enabled",
|
||||
"hidden_dropout",
|
||||
"attention_dropout",
|
||||
"ffn_dropout",
|
||||
},
|
||||
"data_config": {"dataset_id", "batch_size"},
|
||||
"optimizer_config": {"lr", "weight_decay"},
|
||||
"lora_config": {"type", "alpha"},
|
||||
}
|
||||
|
||||
# Validate all parameters at once
|
||||
warn_unsupported_params(training_config, supported_params["training_config"], "TrainingConfig")
|
||||
warn_unsupported_params(training_config["data_config"], supported_params["data_config"], "DataConfig")
|
||||
warn_unsupported_params(
|
||||
training_config["optimizer_config"], supported_params["optimizer_config"], "OptimizerConfig"
|
||||
)
|
||||
|
||||
output_model = self.config.output_model_dir
|
||||
|
||||
# Prepare base job configuration
|
||||
job_config = {
|
||||
"config": model,
|
||||
"dataset": {
|
||||
"name": training_config["data_config"]["dataset_id"],
|
||||
"namespace": self.config.dataset_namespace,
|
||||
},
|
||||
"hyperparameters": {
|
||||
"training_type": "sft",
|
||||
"finetuning_type": "lora",
|
||||
**{
|
||||
k: v
|
||||
for k, v in {
|
||||
"epochs": training_config.get("n_epochs"),
|
||||
"batch_size": training_config["data_config"].get("batch_size"),
|
||||
"learning_rate": training_config["optimizer_config"].get("lr"),
|
||||
"weight_decay": training_config["optimizer_config"].get("weight_decay"),
|
||||
"log_every_n_steps": training_config.get("log_every_n_steps"),
|
||||
"val_check_interval": training_config.get("val_check_interval"),
|
||||
"sequence_packing_enabled": training_config.get("sequence_packing_enabled"),
|
||||
}.items()
|
||||
if v is not None
|
||||
},
|
||||
},
|
||||
"project": self.config.project_id,
|
||||
# TODO: ignored ownership, add it later
|
||||
# "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
|
||||
"output_model": output_model,
|
||||
}
|
||||
|
||||
# Handle SFT-specific optional parameters
|
||||
job_config["hyperparameters"]["sft"] = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"ffn_dropout": training_config.get("ffn_dropout"),
|
||||
"hidden_dropout": training_config.get("hidden_dropout"),
|
||||
"attention_dropout": training_config.get("attention_dropout"),
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
# Remove the sft dictionary if it's empty
|
||||
if not job_config["hyperparameters"]["sft"]:
|
||||
job_config["hyperparameters"].pop("sft")
|
||||
|
||||
# Handle LoRA-specific configuration
|
||||
if algorithm_config:
|
||||
if algorithm_config.type == "LoRA":
|
||||
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
|
||||
job_config["hyperparameters"]["lora"] = {
|
||||
k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||
|
||||
# Create the customization job
|
||||
response = await self._make_request(
|
||||
method="POST",
|
||||
path="/v1/customization/jobs",
|
||||
headers={"Accept": "application/json"},
|
||||
json=job_config,
|
||||
)
|
||||
|
||||
job_uuid = response["id"]
|
||||
response.pop("status")
|
||||
created_at = datetime.fromisoformat(response.pop("created_at"))
|
||||
updated_at = datetime.fromisoformat(response.pop("updated_at"))
|
||||
|
||||
return NvidiaPostTrainingJob(
|
||||
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
|
||||
)
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
finetuned_model: str,
|
||||
algorithm_config: DPOAlignmentConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
"""Optimize a model based on preference data."""
|
||||
raise NotImplementedError("Preference optimization is not implemented yet")
|
||||
|
||||
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||
raise NotImplementedError("Job logs are not implemented yet")
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.post_training import TrainingConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
|
||||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="post_training::nvidia")
|
||||
|
||||
|
||||
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
||||
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
|
||||
unsupported_params = [k for k in keys if k not in supported_keys]
|
||||
if unsupported_params:
|
||||
warnings.warn(
|
||||
f"Parameters: {unsupported_params} in `{config_name}` not supported and will be ignored.", stacklevel=2
|
||||
)
|
||||
|
||||
|
||||
def validate_training_params(
|
||||
training_config: dict[str, Any], supported_keys: set[str], config_name: str = "TrainingConfig"
|
||||
) -> None:
|
||||
"""
|
||||
Validates training parameters against supported keys.
|
||||
|
||||
Args:
|
||||
training_config: Dictionary containing training configuration parameters
|
||||
supported_keys: Set of supported parameter keys
|
||||
config_name: Name of the configuration for warning messages
|
||||
"""
|
||||
sft_lora_fields = set(SFTLoRADefaultConfig.__annotations__.keys())
|
||||
training_config_fields = set(TrainingConfig.__annotations__.keys())
|
||||
|
||||
# Check for not supported parameters:
|
||||
# - not in either of configs
|
||||
# - in TrainingConfig but not in SFTLoRADefaultConfig
|
||||
unsupported_params = []
|
||||
for key in training_config:
|
||||
if isinstance(key, str) and key not in (supported_keys.union(sft_lora_fields)):
|
||||
if key in (not sft_lora_fields or training_config_fields):
|
||||
unsupported_params.append(key)
|
||||
|
||||
if unsupported_params:
|
||||
warnings.warn(
|
||||
f"Parameters: {unsupported_params} in `{config_name}` are not supported and will be ignored.", stacklevel=2
|
||||
)
|
||||
|
||||
|
||||
# ToDo: implement post health checks for customizer are enabled
|
||||
async def _get_health(url: str) -> tuple[bool, bool]: ...
|
||||
|
||||
|
||||
async def check_health(config: NvidiaPostTrainingConfig) -> None: ...
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BedrockSafetyConfig, _deps) -> Any:
|
||||
from .bedrock import BedrockSafetyAdapter
|
||||
|
||||
impl = BedrockSafetyAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,111 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="safety::bedrock")
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||
self.config = config
|
||||
self.registered_shields = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self.bedrock_runtime_client = create_bedrock_client(self.config)
|
||||
self.bedrock_client = create_bedrock_client(self.config, "bedrock")
|
||||
except Exception as e:
|
||||
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
response = self.bedrock_client.list_guardrails(
|
||||
guardrailIdentifier=shield.provider_resource_id,
|
||||
)
|
||||
if (
|
||||
not response["guardrails"]
|
||||
or len(response["guardrails"]) == 0
|
||||
or response["guardrails"][0]["version"] != shield.params["guardrailVersion"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
|
||||
)
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
"""
|
||||
This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
{
|
||||
"text": {
|
||||
"text": "Is the AB503 Product a better investment than the S&P 500?"
|
||||
}
|
||||
}
|
||||
]```
|
||||
Incoming messages contain content, role . For now we will extract the content and
|
||||
default the "qualifiers": ["query"]
|
||||
"""
|
||||
|
||||
shield_params = shield.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
|
||||
response = self.bedrock_runtime_client.apply_guardrail(
|
||||
guardrailIdentifier=shield.provider_resource_id,
|
||||
guardrailVersion=shield_params["guardrailVersion"],
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return RunShieldResponse()
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BedrockSafetyConfig(BedrockBaseConfig):
|
||||
pass
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
# NVIDIA Safety Provider for LlamaStack
|
||||
|
||||
This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.
|
||||
|
||||
## Features
|
||||
|
||||
- Run safety checks for messages
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- LlamaStack with NVIDIA configuration
|
||||
- Access to NVIDIA NeMo Guardrails service
|
||||
- NIM for model to use for safety check is deployed
|
||||
|
||||
### Setup
|
||||
|
||||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
||||
#### Initialize the client
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.initialize()
|
||||
```
|
||||
|
||||
#### Create a safety shield
|
||||
|
||||
```python
|
||||
from llama_stack.apis.safety import Shield
|
||||
from llama_stack.apis.inference import Message
|
||||
|
||||
# Create a safety shield
|
||||
shield = Shield(
|
||||
shield_id="your-shield-id",
|
||||
provider_resource_id="safety-model-id", # The model to use for safety checks
|
||||
description="Safety checks for content moderation",
|
||||
)
|
||||
|
||||
# Register the shield
|
||||
await client.safety.register_shield(shield)
|
||||
```
|
||||
|
||||
#### Run safety checks
|
||||
|
||||
```python
|
||||
# Messages to check
|
||||
messages = [Message(role="user", content="Your message to check")]
|
||||
|
||||
# Run safety check
|
||||
response = await client.safety.run_shield(
|
||||
shield_id="your-shield-id",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Check for violations
|
||||
if response.violation:
|
||||
print(f"Safety violation detected: {response.violation.user_message}")
|
||||
print(f"Violation level: {response.violation.violation_level}")
|
||||
print(f"Metadata: {response.violation.metadata}")
|
||||
else:
|
||||
print("No safety violations detected")
|
||||
```
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
|
||||
from .nvidia import NVIDIASafetyAdapter
|
||||
|
||||
impl = NVIDIASafetyAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIASafetyConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the NVIDIA Guardrail microservice endpoint.
|
||||
|
||||
Attributes:
|
||||
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
|
||||
config_id (str): The ID of the guardrails configuration to use from the configuration store
|
||||
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
|
||||
|
||||
"""
|
||||
|
||||
guardrails_service_url: str = Field(
|
||||
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
|
||||
description="The url for accessing the Guardrails service",
|
||||
)
|
||||
config_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_GUARDRAILS_CONFIG_ID", "self-check"),
|
||||
description="Guardrails configuration ID to use from the Guardrails configuration store",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}",
|
||||
"config_id": "${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}",
|
||||
}
|
||||
|
|
@ -1,163 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
|
||||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="safety::nvidia")
|
||||
|
||||
|
||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
||||
"""
|
||||
Initialize the NVIDIASafetyAdapter with a given safety configuration.
|
||||
|
||||
Args:
|
||||
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if not shield.provider_resource_id:
|
||||
raise ValueError("Shield model not provided.")
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||
) -> RunShieldResponse:
|
||||
"""
|
||||
Run a safety shield check against the provided messages.
|
||||
|
||||
Args:
|
||||
shield_id (str): The unique identifier for the shield to be used.
|
||||
messages (List[Message]): A list of Message objects representing the conversation history.
|
||||
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
|
||||
|
||||
Returns:
|
||||
RunShieldResponse: The response containing safety violation details if any.
|
||||
|
||||
Raises:
|
||||
ValueError: If the shield with the provided shield_id is not found.
|
||||
"""
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
||||
return await self.shield.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
|
||||
|
||||
|
||||
class NeMoGuardrails:
|
||||
"""
|
||||
A class that encapsulates NVIDIA's guardrails safety logic.
|
||||
|
||||
Sends messages to the guardrails service and interprets the response to determine
|
||||
if a safety violation has occurred.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NVIDIASafetyConfig,
|
||||
model: str,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize a NeMoGuardrails instance with the provided parameters.
|
||||
|
||||
Args:
|
||||
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
|
||||
model (str): The identifier or name of the model to be used for safety checks.
|
||||
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
|
||||
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
|
||||
|
||||
Raises:
|
||||
ValueError: If temperature is less than or equal to 0.
|
||||
AssertionError: If config_id is not provided in the configuration.
|
||||
"""
|
||||
self.config_id = config.config_id
|
||||
self.model = model
|
||||
assert self.config_id is not None, "Must provide config id"
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be greater than 0")
|
||||
|
||||
self.temperature = temperature
|
||||
self.threshold = threshold
|
||||
self.guardrails_service_url = config.guardrails_service_url
|
||||
|
||||
async def _guardrails_post(self, path: str, data: Any | None):
|
||||
"""Helper for making POST requests to the guardrails service."""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
response = requests.post(url=f"{self.guardrails_service_url}{path}", headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def run(self, messages: list[Message]) -> RunShieldResponse:
|
||||
"""
|
||||
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
|
||||
|
||||
Args:
|
||||
messages (List[Message]): A list of Message objects to be checked for safety violations.
|
||||
|
||||
Returns:
|
||||
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
|
||||
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
|
||||
|
||||
Raises:
|
||||
requests.HTTPError: If the POST request fails.
|
||||
"""
|
||||
request_messages = [await convert_message_to_openai_dict_new(message) for message in messages]
|
||||
request_data = {
|
||||
"model": self.model,
|
||||
"messages": request_messages,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": self.config_id,
|
||||
},
|
||||
}
|
||||
response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data)
|
||||
|
||||
if response["status"] == "blocked":
|
||||
user_message = "Sorry I cannot do this."
|
||||
metadata = response["rails_status"]
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=None)
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SambaNovaSafetyConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any:
|
||||
from .sambanova import SambaNovaSafetyAdapter
|
||||
|
||||
impl = SambaNovaSafetyAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class SambaNovaProviderDataValidator(BaseModel):
|
||||
sambanova_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="Sambanova Cloud API key",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SambaNovaSafetyConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The SambaNova cloud API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue