commit d1f0d17644
parent 768ed09dec
Author: Xi Yan
Date:   2024-09-14 13:35:09 -07:00

    update build cli

6 changed files with 109 additions and 117 deletions

@@ -72,65 +72,20 @@ class StackBuild(Subcommand):
         from llama_toolchain.core.package import ApiInput, build_package, ImageType
         from termcolor import cprint
 
+        # expect build to take in a distribution spec file
         api_inputs = []
-        if build_config.distribution == "adhoc":
-            if not build_config.api_providers:
-                self.parser.error(
-                    "You must specify API providers with (api=provider,...) for building an adhoc distribution"
-                )
-                return
-
-            parsed = parse_api_provider_tuples(build_config.api_providers, self.parser)
-            for api, provider_spec in parsed.items():
-                for dep in provider_spec.api_dependencies:
-                    if dep not in parsed:
-                        self.parser.error(
-                            f"API {api} needs dependency {dep} provided also"
-                        )
-                        return
-
-                api_inputs.append(
-                    ApiInput(
-                        api=api,
-                        provider=provider_spec.provider_type,
-                    )
-                )
-            docker_image = None
-        else:
-            if build_config.api_providers:
-                self.parser.error(
-                    "You cannot specify API providers for pre-registered distributions"
-                )
-                return
-
-            dist = resolve_distribution_spec(build_config.distribution)
-            if dist is None:
-                self.parser.error(
-                    f"Could not find distribution {build_config.distribution}"
-                )
-                return
-
-            for api, provider_type in dist.providers.items():
-                api_inputs.append(
-                    ApiInput(
-                        api=api,
-                        provider=provider_type,
-                    )
-                )
-            docker_image = dist.docker_image
-
-        build_package(
-            api_inputs,
-            image_type=ImageType(build_config.image_type),
-            name=build_config.name,
-            distribution_type=build_config.distribution,
-            docker_image=docker_image,
-        )
+        for api, provider_type in build_config.distribution_spec.providers.items():
+            api_inputs.append(
+                ApiInput(
+                    api=Api(api),
+                    provider=provider_type,
+                )
+            )
+
+        build_package(build_config)
 
         # save build.yaml spec for building same distribution again
-        build_dir = (
-            DISTRIBS_BASE_DIR / build_config.distribution / build_config.image_type
-        )
+        build_dir = DISTRIBS_BASE_DIR / build_config.image_type
         os.makedirs(build_dir, exist_ok=True)
         build_file_path = build_dir / f"{build_config.name}-build.yaml"

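The adhoc/pre-registered split is gone: everything the command needs now
travels on the BuildConfig itself. A minimal sketch of the new flow,
assuming the BuildConfig model shown later in this diff (the datatypes
import path and the YAML handling are illustrative, not the CLI's exact
code):

    import yaml

    from llama_toolchain.core.datatypes import BuildConfig  # path assumed
    from llama_toolchain.core.package import build_package

    # parse a build spec file into a BuildConfig, then hand the whole
    # config to build_package() in one call
    with open("my-build.yaml") as f:
        build_config = BuildConfig(**yaml.safe_load(f))

    build_package(build_config)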

@@ -19,17 +19,15 @@ fi
 set -euo pipefail
 
-if [ "$#" -ne 4 ]; then
+if [ "$#" -ne 2 ]; then
   echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
   echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
   exit 1
 fi
 
-distribution_type="$1"
-build_name="$2"
+build_name="$1"
 env_name="llamastack-$build_name"
-config_file="$3"
-pip_dependencies="$4"
+pip_dependencies="$2"
 
 # Define color codes
 RED='\033[0;31m'

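The conda script now takes exactly two positional arguments, dropping
distribution_type and config_file. A sketch of the call build_package()
makes for conda builds, with subprocess standing in for run_with_pty and
the script path assumed:

    import subprocess

    # the real path is resolved via pkg_resources.resource_filename(
    #     "llama_toolchain", "core/build_conda_env.sh")
    args = [
        "build_conda_env.sh",   # assumed local path, for illustration
        "mybuild",              # build_name -> conda env "llamastack-mybuild"
        "numpy pandas scipy",   # pip_dependencies, one space-separated string
    ]
    subprocess.run(args, check=True)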

@@ -151,12 +151,23 @@ def remote_provider_spec(
 @json_schema_type
 class DistributionSpec(BaseModel):
-    distribution_type: str
-    description: str
+    distribution_type: str = Field(
+        default="local",
+        description="Name of the distribution type. This can be used to identify the distribution",
+    )
+    description: str = Field(
+        default="Use code from `llama_toolchain` itself to serve all llama stack APIs",
+        description="Description of the distribution",
+    )
     docker_image: Optional[str] = None
-    providers: Dict[Api, str] = Field(
-        default_factory=dict,
+    providers: Dict[str, str] = Field(
+        default={
+            "inference": "meta-reference",
+            "memory": "meta-reference-faiss",
+            "safety": "meta-reference",
+            "agentic_system": "meta-reference",
+            "telemetry": "console",
+        },
         description="Provider Types for each of the APIs provided by this distribution",
     )
@@ -194,12 +205,8 @@ the dependencies of these providers as well.
 @json_schema_type
 class BuildConfig(BaseModel):
     name: str
-    distribution: str = Field(
-        default="local", description="Type of distribution to build (adhoc | {})"
-    )
-    api_providers: Optional[str] = Field(
-        default_factory=list,
-        description="List of API provider names to build",
-    )
+    distribution_spec: DistributionSpec = Field(
+        description="The distribution spec to build including API providers. "
+    )
     image_type: str = Field(
         default="conda",

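Every DistributionSpec field now has a default, so a bare DistributionSpec()
describes the local distribution and BuildConfig only requires a name. A
quick sketch, assuming the models above (import path illustrative):

    from llama_toolchain.core.datatypes import BuildConfig, DistributionSpec

    config = BuildConfig(
        name="my-local-stack",
        distribution_spec=DistributionSpec(),  # "local" + meta-reference providers
    )
    # providers is now keyed by plain strings rather than Api enum members
    assert config.distribution_spec.providers["inference"] == "meta-reference"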

@@ -31,14 +31,14 @@ SERVER_DEPENDENCIES = [
 ]
 
-def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
-    # only consider InlineProviderSpecs when calculating dependencies
-    return [
-        dep
-        for provider_spec in distribution.provider_specs.values()
-        if isinstance(provider_spec, InlineProviderSpec)
-        for dep in provider_spec.pip_packages
-    ] + SERVER_DEPENDENCIES
+# def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
+#     # only consider InlineProviderSpecs when calculating dependencies
+#     return [
+#         dep
+#         for provider_spec in distribution.provider_specs.values()
+#         if isinstance(provider_spec, InlineProviderSpec)
+#         for dep in provider_spec.pip_packages
+#     ] + SERVER_DEPENDENCIES
 
 
 def stack_apis() -> List[Api]:

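The retired helper's job moves into build_package() (last hunk below),
which starts from SERVER_DEPENDENCIES and extends with each selected
provider's pip_packages. A hypothetical equivalent under the new
string-keyed spec, for illustration only:

    from typing import Dict, List

    def spec_pip_packages(providers: Dict[str, str]) -> List[str]:
        # hypothetical helper; the real logic is inlined in build_package(),
        # which also validates that each provider exists for its API
        deps = list(SERVER_DEPENDENCIES)
        all_providers = api_providers()  # Api -> {provider_type: ProviderSpec}
        for api_str, provider in providers.items():
            deps.extend(all_providers[Api(api_str)][provider].pip_packages)
        return deps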

@@ -17,62 +17,43 @@ def available_distribution_specs() -> List[DistributionSpec]:
             distribution_type="local",
             description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
             providers={
-                Api.inference: "meta-reference",
-                Api.memory: "meta-reference-faiss",
-                Api.safety: "meta-reference",
-                Api.agentic_system: "meta-reference",
-                Api.telemetry: "console",
-            },
-        ),
-        DistributionSpec(
-            distribution_type="remote",
-            description="Point to remote services for all llama stack APIs",
-            providers={
-                **{x: "remote" for x in Api},
-                Api.telemetry: "console",
+                "inference": "meta-reference",
+                "memory": "meta-reference-faiss",
+                "safety": "meta-reference",
+                "agentic_system": "meta-reference",
+                "telemetry": "console",
             },
         ),
         DistributionSpec(
             distribution_type="local-ollama",
             description="Like local, but use ollama for running LLM inference",
             providers={
-                Api.inference: remote_provider_type("ollama"),
-                Api.safety: "meta-reference",
-                Api.agentic_system: "meta-reference",
-                Api.memory: "meta-reference-faiss",
-                Api.telemetry: "console",
+                "inference": remote_provider_type("ollama"),
+                "safety": "meta-reference",
+                "agentic_system": "meta-reference",
+                "memory": "meta-reference-faiss",
+                "telemetry": "console",
             },
         ),
         DistributionSpec(
             distribution_type="local-plus-fireworks-inference",
             description="Use Fireworks.ai for running LLM inference",
             providers={
-                Api.inference: remote_provider_type("fireworks"),
-                Api.safety: "meta-reference",
-                Api.agentic_system: "meta-reference",
-                Api.memory: "meta-reference-faiss",
-                Api.telemetry: "console",
-            },
-        ),
-        DistributionSpec(
-            distribution_type="local-plus-together-inference",
-            description="Use Together.ai for running LLM inference",
-            providers={
-                Api.inference: remote_provider_type("together"),
-                Api.safety: "meta-reference",
-                Api.agentic_system: "meta-reference",
-                Api.memory: "meta-reference-faiss",
-                Api.telemetry: "console",
+                "inference": remote_provider_type("fireworks"),
+                "safety": "meta-reference",
+                "agentic_system": "meta-reference",
+                "memory": "meta-reference-faiss",
+                "telemetry": "console",
             },
         ),
         DistributionSpec(
             distribution_type="local-plus-tgi-inference",
             description="Use TGI for running LLM inference",
             providers={
-                Api.inference: remote_provider_type("tgi"),
-                Api.safety: "meta-reference",
-                Api.agentic_system: "meta-reference",
-                Api.memory: "meta-reference-faiss",
+                "inference": remote_provider_type("tgi"),
+                "safety": "meta-reference",
+                "agentic_system": "meta-reference",
+                "memory": "meta-reference-faiss",
             },
         ),
     ]

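The CLI's old code path resolved one of these specs by name via
resolve_distribution_spec(); the lookup amounts to a scan on
distribution_type. A hedged sketch of that lookup:

    from typing import Optional

    def resolve(distribution_type: str) -> Optional[DistributionSpec]:
        # illustrative stand-in for resolve_distribution_spec()
        for spec in available_distribution_specs():
            if spec.distribution_type == distribution_type:
                return spec
        return None

    spec = resolve("local-ollama")
    assert spec is not None
    assert spec.providers["inference"] == remote_provider_type("ollama")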

@@ -39,7 +39,58 @@ class ApiInput(BaseModel):
     provider: str
 
-def build_package(
+def build_package(build_config: BuildConfig):
+    package_deps = Dependencies(
+        docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
+        pip_packages=SERVER_DEPENDENCIES,
+    )
+
+    # extend package dependencies based on providers spec
+    all_providers = api_providers()
+    for api_str, provider in build_config.distribution_spec.providers.items():
+        providers_for_api = all_providers[Api(api_str)]
+        if provider not in providers_for_api:
+            raise ValueError(
+                f"Provider `{provider}` is not available for API `{api_str}`"
+            )
+
+        provider_spec = providers_for_api[provider]
+        package_deps.pip_packages.extend(provider_spec.pip_packages)
+        if provider_spec.docker_image:
+            raise ValueError("A stack's dependencies cannot have a docker image")
+
+    if build_config.image_type == ImageType.docker.value:
+        script = pkg_resources.resource_filename(
+            "llama_toolchain", "core/build_container.sh"
+        )
+        args = [
+            script,
+            distribution_type,
+            package_name,
+            package_deps.docker_image,
+            str(package_file),
+            " ".join(package_deps.pip_packages),
+        ]
+    else:
+        script = pkg_resources.resource_filename(
+            "llama_toolchain", "core/build_conda_env.sh"
+        )
+        args = [
+            script,
+            build_config.name,
+            " ".join(package_deps.pip_packages),
+        ]
+
+    return_code = run_with_pty(args)
+    if return_code != 0:
+        cprint(
+            f"Failed to build target {build_config.name} with return code {return_code}",
+            color="red",
+        )
+        return
+
+
+def build_package_deprecated(
     api_inputs: List[ApiInput],
     image_type: ImageType,
     name: str,