fix build

This commit is contained in:
dltn 2024-11-22 13:12:06 -08:00
parent 2137b0af40
commit 2eaab52db9
2 changed files with 19 additions and 15 deletions

View file

@ -16,9 +16,9 @@ from pathlib import Path
import pkg_resources
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.dynamic import instantiate_class_type
TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates"
@ -81,13 +81,13 @@ class StackBuild(Subcommand):
import textwrap
import yaml
from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.distribution.distribution import get_provider_registry
if args.list_templates:
self._run_template_list_cmd(args)
return
@ -192,9 +192,9 @@ class StackBuild(Subcommand):
import json
import yaml
from termcolor import cprint
from llama_stack.distribution.build import ImageType
from termcolor import cprint
apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig(
@ -223,6 +223,10 @@ class StackBuild(Subcommand):
for i, provider_type in enumerate(provider_types):
pid = provider_type.split("::")[-1]
p = provider_registry[Api(api)][provider_type]
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
config_type = instantiate_class_type(
provider_registry[Api(api)][provider_type].config_class
)
@ -260,10 +264,10 @@ class StackBuild(Subcommand):
import re
import yaml
from termcolor import cprint
from llama_stack.distribution.build import build_image
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from termcolor import cprint
# save build.yaml spec for building same distribution again
build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"

View file

@ -17,6 +17,16 @@ from llama_stack.distribution.datatypes import (
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.safety,
provider_type="inline::prompt-guard",
pip_packages=[
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.inline.safety.prompt_guard",
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
),
InlineProviderSpec(
api=Api.safety,
provider_type="inline::meta-reference",
@ -48,16 +58,6 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest
Api.inference,
],
),
InlineProviderSpec(
api=Api.safety,
provider_type="inline::prompt-guard",
pip_packages=[
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.inline.safety.prompt_guard",
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
),
InlineProviderSpec(
api=Api.safety,
provider_type="inline::code-scanner",