diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 0efcc3cee..4af7b8e31 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -42,7 +42,7 @@ def available_providers() -> list[ProviderSpec]: provider_type="inline::sentence-transformers", # CrossEncoder depends on torchao.quantization pip_packages=[ - "torch torchvision torchao --index-url https://download.pytorch.org/whl/cpu", + "torch torchvision torchao>=0.12.0 --index-url https://download.pytorch.org/whl/cpu", "sentence-transformers --no-deps", ], module="llama_stack.providers.inline.inference.sentence_transformers", diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index 9ce36f18b..5c93a2e22 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images. torchtune_def = dict( api=Api.post_training, - pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"], + pip_packages=["numpy"], module="llama_stack.providers.inline.post_training.torchtune", config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig", api_dependencies=[ @@ -27,21 +27,21 @@ torchtune_def = dict( def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( - **{ + **{ # type: ignore **torchtune_def, "provider_type": "inline::torchtune-cpu", "pip_packages": ( cast(list[str], torchtune_def["pip_packages"]) - + ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"] + + ["torch torchtune>=0.5.0 torchao>=0.12.0 --index-url https://download.pytorch.org/whl/cpu"] ), }, ), InlineProviderSpec( - **{ + **{ # type: ignore **torchtune_def, "provider_type": "inline::torchtune-gpu", "pip_packages": ( - cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"] + cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune>=0.5.0 torchao>=0.12.0"] ), }, ),