add nvidia distribution

Authored by Ubuntu on 2025-03-06 18:26:53 +00:00; committed by raspawar
parent 63e380400a
commit c71e2a0d87
7 changed files with 67 additions and 12 deletions

View file

@@ -6,7 +6,7 @@
 from typing import List

-from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec


 def available_providers() -> List[ProviderSpec]:
@@ -22,15 +22,13 @@ def available_providers() -> List[ProviderSpec]:
                 Api.datasets,
             ],
         ),
-        InlineProviderSpec(
+        remote_provider_spec(
             api=Api.post_training,
-            provider_type="remote::nvidia",
-            pip_packages=["torch", "numpy"],
-            module="llama_stack.providers.remote.post_training.nvidia",
-            config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
-            api_dependencies=[
-                Api.datasetio,
-                Api.datasets,
-            ],
+            adapter=AdapterSpec(
+                adapter_type="nvidia",
+                pip_packages=["requests"],
+                module="llama_stack.providers.remote.post_training.nvidia",
+                config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
+            ),
         ),
     ]
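
For orientation, here is a hedged sketch of the relationship between remote_provider_spec and AdapterSpec. This is not the actual llama_stack source; it only illustrates the usual convention, where provider_type is derived from adapter_type (which is why the explicit provider_type="remote::nvidia" argument becomes redundant):

# Hedged sketch, not llama_stack's real datatypes module.
from dataclasses import dataclass, field
from typing import List


@dataclass
class AdapterSpec:
    adapter_type: str
    module: str
    config_class: str
    pip_packages: List[str] = field(default_factory=list)


@dataclass
class RemoteProviderSpec:
    api: str
    provider_type: str
    adapter: AdapterSpec


def remote_provider_spec(api: str, adapter: AdapterSpec) -> RemoteProviderSpec:
    # "nvidia" becomes "remote::nvidia", matching the provider_type
    # that was previously spelled out by hand.
    return RemoteProviderSpec(
        api=api,
        provider_type=f"remote::{adapter.adapter_type}",
        adapter=adapter,
    )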

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

View file

@@ -13,7 +13,7 @@ from .config import NvidiaPostTrainingConfig


 # post_training api and the torchtune provider is still experimental and under heavy development
-async def get_provider_impl(
+async def get_adapter_impl(
     config: NvidiaPostTrainingConfig,
     deps: Dict[Api, ProviderSpec],
 ):
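
The rename is load-bearing: in llama_stack, remote adapters are resolved through a get_adapter_impl entry point, while get_provider_impl is the hook for inline providers, so moving this provider to remote::nvidia requires the new name. A hedged sketch of the shape such an entry point typically has (ExamplePostTrainingAdapter is a hypothetical stand-in; the diff does not show the real adapter class):

from typing import Any, Dict


class ExamplePostTrainingAdapter:
    # Hypothetical stand-in for the real NVIDIA adapter implementation.
    def __init__(self, config: Any) -> None:
        self.config = config

    async def initialize(self) -> None:
        # A real adapter would set up HTTP sessions and check the remote
        # NeMo Customizer endpoint here.
        pass


async def get_adapter_impl(config: Any, deps: Dict[Any, Any]) -> Any:
    # Build the adapter from its config, initialize it, and hand it back.
    impl = ExamplePostTrainingAdapter(config)
    await impl.initialize()
    return impl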

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import os
-from typing import Optional
+from typing import Any, Dict, Optional

 from pydantic import BaseModel, Field
@@ -58,3 +58,15 @@ class NvidiaPostTrainingConfig(BaseModel):
         default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
         description="Directory to save the output model",
     )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+        return {
+            "api_key": "${env.NVIDIA_API_KEY:}",
+            "user_id": "${env.NVIDIA_USER_ID:llama-stack-user}",
+            "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
+            "access_policies": "${env.NVIDIA_ACCESS_POLICIES:}",
+            "project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
+            "customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:}",
+            "output_model_dir": "${env.NVIDIA_OUTPUT_MODEL_DIR:test-example-model@v1}",
+        }
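
The ${env.VAR:default} strings emitted by sample_run_config are placeholders resolved when the run config is loaded: the environment variable's value if it is set, otherwise the text after the colon (an empty default such as ${env.NVIDIA_API_KEY:} resolves to an empty string). A small illustration of those semantics, independent of llama_stack's own resolver:

import os
import re

# Illustrative resolver only; llama_stack ships its own implementation.
_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}")


def resolve(value: str) -> str:
    # Use the environment variable when set, else the default after the colon.
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)


os.environ.pop("NVIDIA_PROJECT_ID", None)
print(resolve("${env.NVIDIA_PROJECT_ID:test-project}"))  # test-project

os.environ["NVIDIA_PROJECT_ID"] = "my-project"
print(resolve("${env.NVIDIA_PROJECT_ID:test-project}"))  # my-project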

View file

@@ -20,6 +20,8 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    post_training:
+    - remote::nvidia
     tool_runtime:
     - inline::rag-runtime
 image_type: conda

View file

@@ -9,6 +9,7 @@ from pathlib import Path

 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
+from llama_stack.providers.remote.post_training.nvidia import NvidiaPostTrainingConfig
 from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
@@ -18,6 +19,7 @@ def get_distribution_template() -> DistributionTemplate:
         "inference": ["remote::nvidia"],
         "vector_io": ["inline::faiss"],
         "safety": ["remote::nvidia"],
+        "post_training": ["remote::nvidia"],
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],
         "eval": ["inline::meta-reference"],
@@ -31,6 +33,12 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::nvidia",
         config=NVIDIAConfig.sample_run_config(),
     )
+    post_training_provider = Provider(
+        provider_id="nvidia",
+        provider_type="remote::nvidia",
+        config=NvidiaPostTrainingConfig.sample_run_config(),
+    )
     safety_provider = Provider(
         provider_id="nvidia",
         provider_type="remote::nvidia",
@@ -89,6 +97,31 @@ def get_distribution_template() -> DistributionTemplate:
             "",
             "NVIDIA API Key",
         ),
+        ## Nemo Customizer related variables
+        "NVIDIA_USER_ID": (
+            "llama-stack-user",
+            "NVIDIA User ID",
+        ),
+        "NVIDIA_DATASET_NAMESPACE": (
+            "default",
+            "NVIDIA Dataset Namespace",
+        ),
+        "NVIDIA_ACCESS_POLICIES": (
+            "{}",
+            "NVIDIA Access Policies",
+        ),
+        "NVIDIA_PROJECT_ID": (
+            "test-project",
+            "NVIDIA Project ID",
+        ),
+        "NVIDIA_CUSTOMIZER_URL": (
+            "https://customizer.api.nvidia.com",
+            "NVIDIA Customizer URL",
+        ),
+        "NVIDIA_OUTPUT_MODEL_DIR": (
+            "test-example-model@v1",
+            "NVIDIA Output Model Directory",
+        ),
         "GUARDRAILS_SERVICE_URL": (
             "http://0.0.0.0:7331",
             "URL for the NeMo Guardrails Service",

View file

@@ -8,6 +8,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- post_training
 - tool_runtime
 - vector_io
 providers:
@@ -73,6 +74,10 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  post_training:
+  - provider_id: nvidia-customizer
+    provider_type: remote::nvidia
+    config: {}
   tool_runtime:
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
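
As a closing usage sketch (hedged: it assumes the llama_stack_client Python package and a server built from this distribution listening on the default port 8321; verify method names against your installed client version), the new post_training provider should now appear in the provider listing:

from llama_stack_client import LlamaStackClient

# Assumed endpoint; point base_url at wherever the nvidia distribution runs.
client = LlamaStackClient(base_url="http://localhost:8321")

# Expect a post_training entry with provider_type remote::nvidia
# (registered above as provider_id "nvidia-customizer").
for provider in client.providers.list():
    print(provider.api, provider.provider_id, provider.provider_type)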