diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index ce1ed2747..08a1e5990 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -16,9 +16,9 @@ from pathlib import Path import pkg_resources from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.utils.dynamic import instantiate_class_type - TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates" @@ -81,13 +81,13 @@ class StackBuild(Subcommand): import textwrap import yaml + + from llama_stack.distribution.distribution import get_provider_registry from prompt_toolkit import prompt from prompt_toolkit.completion import WordCompleter from prompt_toolkit.validation import Validator from termcolor import cprint - from llama_stack.distribution.distribution import get_provider_registry - if args.list_templates: self._run_template_list_cmd(args) return @@ -192,9 +192,9 @@ class StackBuild(Subcommand): import json import yaml - from termcolor import cprint from llama_stack.distribution.build import ImageType + from termcolor import cprint apis = list(build_config.distribution_spec.providers.keys()) run_config = StackRunConfig( @@ -223,6 +223,10 @@ class StackBuild(Subcommand): for i, provider_type in enumerate(provider_types): pid = provider_type.split("::")[-1] + p = provider_registry[Api(api)][provider_type] + if p.deprecation_error: + raise InvalidProviderError(p.deprecation_error) + config_type = instantiate_class_type( provider_registry[Api(api)][provider_type].config_class ) @@ -260,10 +264,10 @@ class StackBuild(Subcommand): import re import yaml - from termcolor import cprint from llama_stack.distribution.build import build_image from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR + from termcolor import cprint # save build.yaml spec for building same distribution again build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}" diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 77dd823eb..99b0d2bd8 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -17,6 +17,16 @@ from llama_stack.distribution.datatypes import ( def available_providers() -> List[ProviderSpec]: return [ + InlineProviderSpec( + api=Api.safety, + provider_type="inline::prompt-guard", + pip_packages=[ + "transformers", + "torch --index-url https://download.pytorch.org/whl/cpu", + ], + module="llama_stack.providers.inline.safety.prompt_guard", + config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig", + ), InlineProviderSpec( api=Api.safety, provider_type="inline::meta-reference", @@ -48,16 +58,6 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest Api.inference, ], ), - InlineProviderSpec( - api=Api.safety, - provider_type="inline::prompt-guard", - pip_packages=[ - "transformers", - "torch --index-url https://download.pytorch.org/whl/cpu", - ], - module="llama_stack.providers.inline.safety.prompt_guard", - config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig", - ), InlineProviderSpec( api=Api.safety, provider_type="inline::code-scanner",