update build cli

This commit is contained in:
Xi Yan 2024-09-14 13:35:09 -07:00
parent 768ed09dec
commit d1f0d17644
6 changed files with 109 additions and 117 deletions

View file

@ -72,65 +72,20 @@ class StackBuild(Subcommand):
from llama_toolchain.core.package import ApiInput, build_package, ImageType
from termcolor import cprint
# expect build to take in a distribution spec file
api_inputs = []
if build_config.distribution == "adhoc":
if not build_config.api_providers:
self.parser.error(
"You must specify API providers with (api=provider,...) for building an adhoc distribution"
)
return
parsed = parse_api_provider_tuples(build_config.api_providers, self.parser)
for api, provider_spec in parsed.items():
for dep in provider_spec.api_dependencies:
if dep not in parsed:
self.parser.error(
f"API {api} needs dependency {dep} provided also"
)
return
for api, provider_type in build_config.distribution_spec.providers.items():
api_inputs.append(
ApiInput(
api=api,
provider=provider_spec.provider_type,
)
)
docker_image = None
else:
if build_config.api_providers:
self.parser.error(
"You cannot specify API providers for pre-registered distributions"
)
return
dist = resolve_distribution_spec(build_config.distribution)
if dist is None:
self.parser.error(
f"Could not find distribution {build_config.distribution}"
)
return
for api, provider_type in dist.providers.items():
api_inputs.append(
ApiInput(
api=api,
api=Api(api),
provider=provider_type,
)
)
docker_image = dist.docker_image
build_package(
api_inputs,
image_type=ImageType(build_config.image_type),
name=build_config.name,
distribution_type=build_config.distribution,
docker_image=docker_image,
)
build_package(build_config)
# save build.yaml spec for building same distribution again
build_dir = (
DISTRIBS_BASE_DIR / build_config.distribution / build_config.image_type
)
build_dir = DISTRIBS_BASE_DIR / build_config.image_type
os.makedirs(build_dir, exist_ok=True)
build_file_path = build_dir / f"{build_config.name}-build.yaml"

View file

@ -19,17 +19,15 @@ fi
set -euo pipefail
if [ "$#" -ne 4 ]; then
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
exit 1
fi
distribution_type="$1"
build_name="$2"
build_name="$1"
env_name="llamastack-$build_name"
config_file="$3"
pip_dependencies="$4"
pip_dependencies="$2"
# Define color codes
RED='\033[0;31m'

View file

@ -151,12 +151,23 @@ def remote_provider_spec(
@json_schema_type
class DistributionSpec(BaseModel):
distribution_type: str
description: str
distribution_type: str = Field(
default="local",
description="Name of the distribution type. This can used to identify the distribution",
)
description: str = Field(
default="Use code from `llama_toolchain` itself to serve all llama stack APIs",
description="Description of the distribution",
)
docker_image: Optional[str] = None
providers: Dict[Api, str] = Field(
default_factory=dict,
providers: Dict[str, str] = Field(
default={
"inference": "meta-reference",
"memory": "meta-reference-faiss",
"safety": "meta-reference",
"agentic_system": "meta-reference",
"telemetry": "console",
},
description="Provider Types for each of the APIs provided by this distribution",
)
@ -194,12 +205,8 @@ the dependencies of these providers as well.
@json_schema_type
class BuildConfig(BaseModel):
name: str
distribution: str = Field(
default="local", description="Type of distribution to build (adhoc | {})"
)
api_providers: Optional[str] = Field(
default_factory=list,
description="List of API provider names to build",
distribution_spec: DistributionSpec = Field(
description="The distribution spec to build including API providers. "
)
image_type: str = Field(
default="conda",

View file

@ -31,14 +31,14 @@ SERVER_DEPENDENCIES = [
]
def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
# only consider InlineProviderSpecs when calculating dependencies
return [
dep
for provider_spec in distribution.provider_specs.values()
if isinstance(provider_spec, InlineProviderSpec)
for dep in provider_spec.pip_packages
] + SERVER_DEPENDENCIES
# def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
# # only consider InlineProviderSpecs when calculating dependencies
# return [
# dep
# for provider_spec in distribution.provider_specs.values()
# if isinstance(provider_spec, InlineProviderSpec)
# for dep in provider_spec.pip_packages
# ] + SERVER_DEPENDENCIES
def stack_apis() -> List[Api]:

View file

@ -17,62 +17,43 @@ def available_distribution_specs() -> List[DistributionSpec]:
distribution_type="local",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
providers={
Api.inference: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.telemetry: "console",
},
),
DistributionSpec(
distribution_type="remote",
description="Point to remote services for all llama stack APIs",
providers={
**{x: "remote" for x in Api},
Api.telemetry: "console",
"inference": "meta-reference",
"memory": "meta-reference-faiss",
"safety": "meta-reference",
"agentic_system": "meta-reference",
"telemetry": "console",
},
),
DistributionSpec(
distribution_type="local-ollama",
description="Like local, but use ollama for running LLM inference",
providers={
Api.inference: remote_provider_type("ollama"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.telemetry: "console",
"inference": remote_provider_type("ollama"),
"safety": "meta-reference",
"agentic_system": "meta-reference",
"memory": "meta-reference-faiss",
"telemetry": "console",
},
),
DistributionSpec(
distribution_type="local-plus-fireworks-inference",
description="Use Fireworks.ai for running LLM inference",
providers={
Api.inference: remote_provider_type("fireworks"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.telemetry: "console",
},
),
DistributionSpec(
distribution_type="local-plus-together-inference",
description="Use Together.ai for running LLM inference",
providers={
Api.inference: remote_provider_type("together"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.telemetry: "console",
"inference": remote_provider_type("fireworks"),
"safety": "meta-reference",
"agentic_system": "meta-reference",
"memory": "meta-reference-faiss",
"telemetry": "console",
},
),
DistributionSpec(
distribution_type="local-plus-tgi-inference",
description="Use TGI for running LLM inference",
providers={
Api.inference: remote_provider_type("tgi"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
"inference": remote_provider_type("tgi"),
"safety": "meta-reference",
"agentic_system": "meta-reference",
"memory": "meta-reference-faiss",
},
),
]

View file

@ -39,7 +39,58 @@ class ApiInput(BaseModel):
provider: str
def build_package(
def build_package(build_config: BuildConfig):
package_deps = Dependencies(
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
)
# extend package dependencies based on providers spec
all_providers = api_providers()
for api_str, provider in build_config.distribution_spec.providers.items():
providers_for_api = all_providers[Api(api_str)]
if provider not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
provider_spec = providers_for_api[provider]
package_deps.pip_packages.extend(provider_spec.pip_packages)
if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_container.sh"
)
args = [
script,
distribution_type,
package_name,
package_deps.docker_image,
str(package_file),
" ".join(package_deps.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_conda_env.sh"
)
args = [
script,
build_config.name,
" ".join(package_deps.pip_packages),
]
return_code = run_with_pty(args)
if return_code != 0:
cprint(
f"Failed to build target {build_config.name} with return code {return_code}",
color="red",
)
return
def build_package_deprecated(
api_inputs: List[ApiInput],
image_type: ImageType,
name: str,