Remove additional_pip_packages; move deps to providers

Ashwin Bharambe 2024-08-08 10:19:46 -07:00
parent 6de36b6a15
commit 8d7ecf0c47
5 changed files with 18 additions and 32 deletions

@@ -85,11 +85,6 @@ class DistributionSpec(BaseModel):
         description="Provider specifications for each of the APIs provided by this distribution",
     )
-    additional_pip_packages: List[str] = Field(
-        default_factory=list,
-        description="Additional pip packages beyond those required by the providers",
-    )
 
 @json_schema_type
 class DistributionConfig(BaseModel):
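For orientation, a minimal sketch of the slimmed-down model after this hunk. The provider_specs field is taken from the diff context; the stub ProviderSpec, the str keys, and the imports are assumptions, since the real definitions in .datatypes are not part of this diff:

from typing import Dict, List

from pydantic import BaseModel, Field


class ProviderSpec(BaseModel):  # stand-in; the real class lives in .datatypes
    pip_packages: List[str] = Field(default_factory=list)


class DistributionSpec(BaseModel):
    provider_specs: Dict[str, ProviderSpec] = Field(
        default_factory=dict,
        description="Provider specifications for each of the APIs provided by this distribution",
    )
    # additional_pip_packages is gone: per-provider pip_packages plus the
    # server's own SERVER_DEPENDENCIES now account for every dependency.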

@@ -22,6 +22,14 @@ from .datatypes import (
     ProviderSpec,
 )
 
+# These are the dependencies needed by the distribution server.
+# `llama-toolchain` is automatically installed by the installation script.
+SERVER_DEPENDENCIES = [
+    "fastapi",
+    "python-dotenv",
+    "uvicorn",
+]
+
 
 def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
     # only consider InlineProviderSpecs when calculating dependencies
@@ -30,7 +38,7 @@ def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
         for provider_spec in distribution.provider_specs.values()
         if isinstance(provider_spec, InlineProviderSpec)
         for dep in provider_spec.pip_packages
-    ] + distribution.additional_pip_packages
+    ] + SERVER_DEPENDENCIES
 
 def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
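Putting both hunks together, the dependency calculation now reads roughly as below. This is a self-contained sketch: the dataclass stand-ins replace the real pydantic models, and only the names visible in the diff (SERVER_DEPENDENCIES, InlineProviderSpec, provider_specs, pip_packages) are taken from the source:

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class InlineProviderSpec:  # stand-in for the real spec in .datatypes
    pip_packages: List[str] = field(default_factory=list)


@dataclass
class DistributionSpec:  # stand-in; keys are an Api enum in the real code
    provider_specs: Dict[str, object] = field(default_factory=dict)


SERVER_DEPENDENCIES = ["fastapi", "python-dotenv", "uvicorn"]


def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
    # Only inline providers contribute packages; remote providers run
    # elsewhere, so nothing needs to be installed locally for them.
    return [
        dep
        for provider_spec in distribution.provider_specs.values()
        if isinstance(provider_spec, InlineProviderSpec)
        for dep in provider_spec.pip_packages
    ] + SERVER_DEPENDENCIES


spec = DistributionSpec(
    provider_specs={"inference": InlineProviderSpec(pip_packages=["torch", "zmq"])}
)
print(distribution_dependencies(spec))
# -> ['torch', 'zmq', 'fastapi', 'python-dotenv', 'uvicorn']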

@@ -6,6 +6,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
+LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
+
 set -euo pipefail
 
 # Define color codes
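A note on the two added lines: ${VAR:-} expands to an empty string when VAR is unset, so defining both variables before `set -u` takes effect means later references to them cannot trip the unset-variable check. For readers more at home in Python than in shell parameter expansion, the equivalent pattern is (a sketch, not part of the commit):

import os

# Like LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-} in the script: default to ""
# when the variable is unset, so later references never fail.
llama_models_dir = os.environ.get("LLAMA_MODELS_DIR", "")
llama_toolchain_dir = os.environ.get("LLAMA_TOOLCHAIN_DIR", "")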

@@ -10,14 +10,6 @@ from typing import List, Optional
 from .datatypes import Api, DistributionSpec, RemoteProviderSpec
 from .distribution import api_providers
 
-# These are the dependencies needed by the distribution server.
-# `llama-toolchain` is automatically installed by the installation script.
-COMMON_DEPENDENCIES = [
-    "fastapi",
-    "python-dotenv",
-    "uvicorn",
-]
-
 
 def client_module(api: Api) -> str:
     return f"llama_toolchain.{api.value}.client"
@@ -38,22 +30,6 @@ def available_distribution_specs() -> List[DistributionSpec]:
         DistributionSpec(
             spec_id="inline",
             description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
-            additional_pip_packages=(
-                COMMON_DEPENDENCIES
-                # why do we need any of these? they should be completely covered
-                # by the provider dependencies themselves
-                + [
-                    "accelerate",
-                    "blobfile",
-                    "codeshield",
-                    "fairscale",
-                    "pandas",
-                    "Pillow",
-                    "torch",
-                    "transformers",
-                    "fbgemm-gpu==0.8.0",
-                ]
-            ),
             provider_specs={
                 Api.inference: providers[Api.inference]["meta-reference"],
                 Api.safety: providers[Api.safety]["meta-reference"],
@@ -63,13 +39,11 @@ def available_distribution_specs() -> List[DistributionSpec]:
         DistributionSpec(
             spec_id="remote",
             description="Point to remote services for all llama stack APIs",
-            additional_pip_packages=COMMON_DEPENDENCIES,
             provider_specs={x: remote_spec(x) for x in providers},
         ),
         DistributionSpec(
             spec_id="ollama-inline",
             description="Like local-source, but use ollama for running LLM inference",
-            additional_pip_packages=COMMON_DEPENDENCIES,
             provider_specs={
                 Api.inference: providers[Api.inference]["meta-ollama"],
                 Api.safety: providers[Api.safety]["meta-reference"],

@@ -15,7 +15,13 @@ def available_inference_providers() -> List[ProviderSpec]:
             api=Api.inference,
             provider_id="meta-reference",
             pip_packages=[
+                "accelerate",
+                "blobfile",
+                "codeshield",
+                "fairscale",
+                "fbgemm-gpu==0.8.0",
                 "torch",
+                "transformers",
                 "zmq",
             ],
             module="llama_toolchain.inference.meta_reference",