diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py
index 70f582c1c..4b80c549a 100644
--- a/llama_toolchain/distribution/datatypes.py
+++ b/llama_toolchain/distribution/datatypes.py
@@ -85,11 +85,6 @@ class DistributionSpec(BaseModel):
         description="Provider specifications for each of the APIs provided by this distribution",
     )
 
-    additional_pip_packages: List[str] = Field(
-        default_factory=list,
-        description="Additional pip packages beyond those required by the providers",
-    )
-
 
 @json_schema_type
 class DistributionConfig(BaseModel):
diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py
index 853092f38..f92547ba7 100644
--- a/llama_toolchain/distribution/distribution.py
+++ b/llama_toolchain/distribution/distribution.py
@@ -22,6 +22,14 @@ from .datatypes import (
     ProviderSpec,
 )
 
+# These are the dependencies needed by the distribution server.
+# `llama-toolchain` is automatically installed by the installation script.
+SERVER_DEPENDENCIES = [
+    "fastapi",
+    "python-dotenv",
+    "uvicorn",
+]
+
 
 def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
     # only consider InlineProviderSpecs when calculating dependencies
@@ -30,7 +38,7 @@ def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
         for provider_spec in distribution.provider_specs.values()
         if isinstance(provider_spec, InlineProviderSpec)
         for dep in provider_spec.pip_packages
-    ] + distribution.additional_pip_packages
+    ] + SERVER_DEPENDENCIES
 
 
 def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
diff --git a/llama_toolchain/distribution/install_distribution.sh b/llama_toolchain/distribution/install_distribution.sh
index 60b128e1a..80727725d 100755
--- a/llama_toolchain/distribution/install_distribution.sh
+++ b/llama_toolchain/distribution/install_distribution.sh
@@ -6,6 +6,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
+LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
+
 set -euo pipefail
 
 # Define color codes
diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index b5c617a45..a60b3cd4f 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -10,14 +10,6 @@ from typing import List, Optional
 from .datatypes import Api, DistributionSpec, RemoteProviderSpec
 from .distribution import api_providers
 
-# These are the dependencies needed by the distribution server.
-# `llama-toolchain` is automatically installed by the installation script.
-COMMON_DEPENDENCIES = [
-    "fastapi",
-    "python-dotenv",
-    "uvicorn",
-]
-
 
 def client_module(api: Api) -> str:
     return f"llama_toolchain.{api.value}.client"
@@ -38,22 +30,6 @@ def available_distribution_specs() -> List[DistributionSpec]:
         DistributionSpec(
             spec_id="inline",
             description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
-            additional_pip_packages=(
-                COMMON_DEPENDENCIES
-                # why do we need any of these? they should be completely covered
-                # by the provider dependencies themselves
-                + [
-                    "accelerate",
-                    "blobfile",
-                    "codeshield",
-                    "fairscale",
-                    "pandas",
-                    "Pillow",
-                    "torch",
-                    "transformers",
-                    "fbgemm-gpu==0.8.0",
-                ]
-            ),
             provider_specs={
                 Api.inference: providers[Api.inference]["meta-reference"],
                 Api.safety: providers[Api.safety]["meta-reference"],
@@ -63,13 +39,11 @@ def available_distribution_specs() -> List[DistributionSpec]:
         DistributionSpec(
             spec_id="remote",
             description="Point to remote services for all llama stack APIs",
-            additional_pip_packages=COMMON_DEPENDENCIES,
             provider_specs={x: remote_spec(x) for x in providers},
         ),
         DistributionSpec(
             spec_id="ollama-inline",
             description="Like local-source, but use ollama for running LLM inference",
-            additional_pip_packages=COMMON_DEPENDENCIES,
             provider_specs={
                 Api.inference: providers[Api.inference]["meta-ollama"],
                 Api.safety: providers[Api.safety]["meta-reference"],
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index 80428c069..1b1eb05a4 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -15,7 +15,13 @@ def available_inference_providers() -> List[ProviderSpec]:
             api=Api.inference,
             provider_id="meta-reference",
             pip_packages=[
+                "accelerate",
+                "blobfile",
+                "codeshield",
+                "fairscale",
+                "fbgemm-gpu==0.8.0",
                 "torch",
+                "transformers",
                 "zmq",
             ],
             module="llama_toolchain.inference.meta_reference",
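
For context, a minimal runnable sketch of the dependency resolution this patch leaves in place: inline providers contribute their own pip_packages, and the shared SERVER_DEPENDENCIES are appended once, with no per-spec additional_pip_packages. This is not part of the patch; the dataclasses below are simplified stand-ins for the real pydantic models, and the "inference"/"telemetry" keys are hypothetical.

from dataclasses import dataclass, field
from typing import Dict, List

SERVER_DEPENDENCIES = ["fastapi", "python-dotenv", "uvicorn"]


@dataclass
class InlineProviderSpec:
    # stand-in for the real pydantic InlineProviderSpec
    pip_packages: List[str] = field(default_factory=list)


@dataclass
class RemoteProviderSpec:
    # remote providers run elsewhere; their packages are not installed locally
    pip_packages: List[str] = field(default_factory=list)


@dataclass
class DistributionSpec:
    provider_specs: Dict[str, object] = field(default_factory=dict)


def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
    # only inline providers contribute pip packages, then the server's
    # own dependencies are appended exactly once
    return [
        dep
        for provider_spec in distribution.provider_specs.values()
        if isinstance(provider_spec, InlineProviderSpec)
        for dep in provider_spec.pip_packages
    ] + SERVER_DEPENDENCIES


spec = DistributionSpec(
    provider_specs={
        "inference": InlineProviderSpec(pip_packages=["torch", "transformers"]),
        "telemetry": RemoteProviderSpec(),
    }
)
print(distribution_dependencies(spec))
# ['torch', 'transformers', 'fastapi', 'python-dotenv', 'uvicorn']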