mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 11:19:47 +00:00
try upgrading torchao / torchtune deps
This commit is contained in:
parent
104c66f099
commit
967ac345e0
2 changed files with 6 additions and 6 deletions
|
|
@ -42,7 +42,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
provider_type="inline::sentence-transformers",
|
||||
# CrossEncoder depends on torchao.quantization
|
||||
pip_packages=[
|
||||
"torch torchvision torchao --index-url https://download.pytorch.org/whl/cpu",
|
||||
"torch torchvision torchao>=0.12.0 --index-url https://download.pytorch.org/whl/cpu",
|
||||
"sentence-transformers --no-deps",
|
||||
],
|
||||
module="llama_stack.providers.inline.inference.sentence_transformers",
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec
|
|||
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
|
||||
torchtune_def = dict(
|
||||
api=Api.post_training,
|
||||
pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"],
|
||||
pip_packages=["numpy"],
|
||||
module="llama_stack.providers.inline.post_training.torchtune",
|
||||
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
|
|
@ -27,21 +27,21 @@ torchtune_def = dict(
|
|||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
**{
|
||||
**{ # type: ignore
|
||||
**torchtune_def,
|
||||
"provider_type": "inline::torchtune-cpu",
|
||||
"pip_packages": (
|
||||
cast(list[str], torchtune_def["pip_packages"])
|
||||
+ ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"]
|
||||
+ ["torch torchtune>=0.5.0 torchao>=0.12.0 --index-url https://download.pytorch.org/whl/cpu"]
|
||||
),
|
||||
},
|
||||
),
|
||||
InlineProviderSpec(
|
||||
**{
|
||||
**{ # type: ignore
|
||||
**torchtune_def,
|
||||
"provider_type": "inline::torchtune-gpu",
|
||||
"pip_packages": (
|
||||
cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"]
|
||||
cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune>=0.5.0 torchao>=0.12.0"]
|
||||
),
|
||||
},
|
||||
),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue