# What does this PR do?

Adds an inline HF SFTTrainer provider. Alongside torchtune, this is a very popular option for running training jobs. The config allows a user to specify key fields such as a model, chat_template, device, etc.

The provider comes with one recipe, `finetune_single_device`, which works both with and without LoRA. Any model that is a valid HF identifier can be given, and the model will be pulled. This has been tested so far with the CPU and MPS device types, but should be compatible with CUDA out of the box.

The provider processes the given dataset into the proper format, establishes the steps per epoch, steps per save, and steps per eval, sets a sane SFTConfig, and runs `n_epochs` of training. If `checkpoint_dir` is none, no model is saved. If there is a checkpoint dir, a model is saved every `save_steps` and at the end of training.

## Test Plan

Re-enabled the post_training integration test suite with a single test that loads the simpleqa dataset (https://huggingface.co/datasets/llamastack/simpleqa) and a tiny granite model (https://huggingface.co/ibm-granite/granite-3.3-2b-instruct). The test now uses the llama stack client and the proper post_training API, and runs one step with a batch_size of 1. This test runs on CPU on the Ubuntu runner, so it needs a small batch and a single step; a sketch of the client-side flow appears below.

[//]: # (## Documentation)

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
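A minimal sketch of the flow the integration test exercises (not the test's verbatim code: the base URL, job UUID, and dataset id below are illustrative assumptions, and the config field names follow the post_training API but should be treated as illustrative):

```python
# Sketch only: base_url, job_uuid, and the "simpleqa" dataset id are
# illustrative assumptions; see the integration test for the real values.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

client.post_training.supervised_fine_tune(
    job_uuid="sft-granite-demo",
    model="ibm-granite/granite-3.3-2b-instruct",  # any valid HF identifier can be given
    checkpoint_dir=None,    # None => no checkpoints are written during training
    algorithm_config=None,  # or a LoRA config; the recipe supports both
    training_config={
        "n_epochs": 1,
        "max_steps_per_epoch": 1,  # a single step, as in the CPU CI run
        "gradient_accumulation_steps": 1,
        "data_config": {
            "dataset_id": "simpleqa",
            "batch_size": 1,  # small batch so it fits on the Ubuntu runner
            "shuffle": False,
            "data_format": "instruct",
        },
    },
    hyperparam_search_config={},
    logger_config={},
)
```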
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "llama_stack"
version = "0.2.7"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
license = { "text" = "MIT" }
classifiers = [
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Operating System :: OS Independent",
    "Intended Audience :: Developers",
    "Intended Audience :: Information Technology",
    "Intended Audience :: Science/Research",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Scientific/Engineering :: Information Analysis",
]
dependencies = [
    "blobfile",
    "fire",
    "httpx",
    "huggingface-hub",
    "jinja2>=3.1.6",
    "jsonschema",
    "llama-stack-client>=0.2.7",
    "openai>=1.66",
    "prompt-toolkit",
    "python-dotenv",
    "pydantic>=2",
    "requests",
    "rich",
    "setuptools",
    "termcolor",
    "tiktoken",
    "pillow",
    "h11>=0.16.0",
    "kubernetes",
]

[project.optional-dependencies]
dev = [
    "pytest",
    "pytest-timeout",
    "pytest-asyncio",
    "pytest-cov",
    "pytest-html",
    "pytest-json-report",
    "nbval", # For notebook testing
    "black",
    "ruff",
    "types-requests",
    "types-setuptools",
    "pre-commit",
    "uvicorn",
    "fastapi",
    "ruamel.yaml", # needed for openapi generator
]
# These are the dependencies required for running unit tests.
unit = [
    "sqlite-vec",
    "openai",
    "aiosqlite",
    "aiohttp",
    "pypdf",
    "chardet",
    "qdrant-client",
    "opentelemetry-exporter-otlp-proto-http",
]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
# dependencies.
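# For example (an illustrative invocation; the extra packages and test path are
# assumptions, not a prescribed command):
#   uv run --with trl --with peft pytest tests/integration/post_training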
test = [
    "openai",
    "aiosqlite",
    "aiohttp",
    "torch>=2.6.0",
    "torchvision>=0.21.0",
    "opentelemetry-sdk",
    "opentelemetry-exporter-otlp-proto-http",
    "chardet",
    "pypdf",
    "mcp",
    "datasets",
    "autoevals",
    "transformers",
]
docs = [
    "sphinx-autobuild",
    "myst-parser",
    "sphinx-rtd-theme",
    "sphinx_rtd_dark_mode",
    "sphinx-copybutton",
    "sphinx-tabs",
    "sphinx-design",
    "sphinxcontrib.redoc",
    "sphinxcontrib.video",
    "sphinxcontrib.mermaid",
    "tomli",
]
codegen = ["rich", "pydantic", "jinja2>=3.1.6"]
ui = [
    "streamlit",
    "pandas",
    "llama-stack-client>=0.2.7",
    "streamlit-option-menu",
]

[project.urls]
Homepage = "https://github.com/meta-llama/llama-stack"

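# Console entry points: installing the package exposes the `llama` CLI
# (e.g. `llama stack build`), wired to llama_stack.cli.llama:main below.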
[project.scripts]
llama = "llama_stack.cli.llama:main"
install-wheel-from-presigned = "llama_stack.cli.scripts.run:install_wheel_from_presigned"

[tool.setuptools]
packages = { find = {} }
license-files = []

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

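# Resolve torch/torchvision through the CPU-only index above so that CPU-only
# environments (e.g. the Ubuntu CI runner used for the post_training
# integration test) do not pull CUDA wheels.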
[tool.uv.sources]
torch = [{ index = "pytorch-cpu" }]
torchvision = [{ index = "pytorch-cpu" }]

[tool.ruff]
line-length = 120
exclude = [
    "./.git",
    "./docs/*",
    "./build",
    "./venv",
    "*.pyi",
    ".pre-commit-config.yaml",
    "*.md",
    ".flake8",
]

[tool.ruff.lint]
select = [
    "UP", # pyupgrade
    "B", # flake8-bugbear
    "B9", # flake8-bugbear subset
    "C", # comprehensions
    "E", # pycodestyle
    "F", # Pyflakes
    "N", # Naming
    "W", # Warnings
    "DTZ", # datetime rules
    "I", # isort (imports order)
    "RUF001", # Checks for ambiguous Unicode characters in strings
    "RUF002", # Checks for ambiguous Unicode characters in docstrings
    "RUF003", # Checks for ambiguous Unicode characters in comments
    "PLC2401", # Checks for the use of non-ASCII characters in variable names
    "PLC2403", # Checks for the use of non-ASCII characters in import statements
    "PLE2510", # Checks for strings that contain the control character BS.
    "PLE2512", # Checks for strings that contain the raw control character SUB.
    "PLE2513", # Checks for strings that contain the raw control character ESC.
    "PLE2514", # Checks for strings that contain the raw control character NUL (0 byte).
    "PLE2515", # Checks for strings that contain the zero width space character.
]
ignore = [
    # The following ignores are desired by the project maintainers.
    "E402", # Module level import not at top of file
    "E501", # Line too long
    "F405", # Maybe undefined or defined from star import
    "C408", # Ignored because we like the dict keyword argument syntax
    "N812", # Ignored because import torch.nn.functional as F is PyTorch convention

    # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
    "C901", # Complexity of the function is too high
]
unfixable = [
    "PLE2515",
] # Do not fix this automatically since ruff will replace the zero-width space with \u200b - let's do it manually

# Ignore the following errors for the following files
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["DTZ"] # Ignore datetime rules for tests
"llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py" = ["RUF001"]
"llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py" = [
    "RUF001",
    "PLE2515",
]

[tool.mypy]
mypy_path = ["llama_stack"]
packages = ["llama_stack"]
plugins = ['pydantic.mypy']
disable_error_code = []
warn_return_any = true
# honor excludes by not following them through imports
follow_imports = "silent"
# Note: some entries are directories, not files. This is because mypy doesn't
# respect __init__.py excludes, so the only way to suppress these right now is
# to exclude the entire directory.
exclude = [
    # As we fix more and more of these, we should remove them from the list
    "^llama_stack/apis/common/training_types\\.py$",
    "^llama_stack/cli/download\\.py$",
    "^llama_stack/cli/stack/_build\\.py$",
    "^llama_stack/distribution/build\\.py$",
    "^llama_stack/distribution/client\\.py$",
    "^llama_stack/distribution/request_headers\\.py$",
    "^llama_stack/distribution/routers/",
    "^llama_stack/distribution/server/endpoints\\.py$",
    "^llama_stack/distribution/server/server\\.py$",
    "^llama_stack/distribution/stack\\.py$",
    "^llama_stack/distribution/store/registry\\.py$",
    "^llama_stack/distribution/utils/exec\\.py$",
    "^llama_stack/distribution/utils/prompt_for_config\\.py$",
    "^llama_stack/models/llama/llama3/chat_format\\.py$",
    "^llama_stack/models/llama/llama3/interface\\.py$",
    "^llama_stack/models/llama/llama3/tokenizer\\.py$",
    "^llama_stack/models/llama/llama3/tool_utils\\.py$",
    "^llama_stack/models/llama/llama3_3/prompts\\.py$",
    "^llama_stack/providers/inline/agents/meta_reference/",
    "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
    "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
    "^llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
    "^llama_stack/providers/inline/datasetio/localfs/",
    "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
    "^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
    "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
    "^llama_stack/models/llama/llama3/generation\\.py$",
    "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
    "^llama_stack/models/llama/llama4/",
    "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
    "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
    "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
    "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
    "^llama_stack/providers/inline/inference/vllm/",
    "^llama_stack/providers/inline/post_training/common/validator\\.py$",
    "^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
    "^llama_stack/providers/inline/safety/code_scanner/",
    "^llama_stack/providers/inline/safety/llama_guard/",
    "^llama_stack/providers/inline/safety/prompt_guard/",
    "^llama_stack/providers/inline/scoring/basic/",
    "^llama_stack/providers/inline/scoring/braintrust/",
    "^llama_stack/providers/inline/scoring/llm_as_judge/",
    "^llama_stack/providers/remote/agents/sample/",
    "^llama_stack/providers/remote/datasetio/huggingface/",
    "^llama_stack/providers/remote/datasetio/nvidia/",
    "^llama_stack/providers/remote/inference/anthropic/",
    "^llama_stack/providers/remote/inference/bedrock/",
    "^llama_stack/providers/remote/inference/cerebras/",
    "^llama_stack/providers/remote/inference/databricks/",
    "^llama_stack/providers/remote/inference/fireworks/",
    "^llama_stack/providers/remote/inference/gemini/",
    "^llama_stack/providers/remote/inference/groq/",
    "^llama_stack/providers/remote/inference/nvidia/",
    "^llama_stack/providers/remote/inference/openai/",
    "^llama_stack/providers/remote/inference/passthrough/",
    "^llama_stack/providers/remote/inference/runpod/",
    "^llama_stack/providers/remote/inference/sambanova/",
    "^llama_stack/providers/remote/inference/sample/",
    "^llama_stack/providers/remote/inference/tgi/",
    "^llama_stack/providers/remote/inference/together/",
    "^llama_stack/providers/remote/inference/watsonx/",
    "^llama_stack/providers/remote/safety/bedrock/",
    "^llama_stack/providers/remote/safety/nvidia/",
    "^llama_stack/providers/remote/safety/sample/",
    "^llama_stack/providers/remote/tool_runtime/bing_search/",
    "^llama_stack/providers/remote/tool_runtime/brave_search/",
    "^llama_stack/providers/remote/tool_runtime/model_context_protocol/",
    "^llama_stack/providers/remote/tool_runtime/tavily_search/",
    "^llama_stack/providers/remote/tool_runtime/wolfram_alpha/",
    "^llama_stack/providers/remote/post_training/nvidia/",
    "^llama_stack/providers/remote/vector_io/chroma/",
    "^llama_stack/providers/remote/vector_io/milvus/",
    "^llama_stack/providers/remote/vector_io/pgvector/",
    "^llama_stack/providers/remote/vector_io/qdrant/",
    "^llama_stack/providers/remote/vector_io/sample/",
    "^llama_stack/providers/remote/vector_io/weaviate/",
    "^llama_stack/providers/tests/conftest\\.py$",
    "^llama_stack/providers/utils/bedrock/client\\.py$",
    "^llama_stack/providers/utils/bedrock/refreshable_boto_session\\.py$",
    "^llama_stack/providers/utils/inference/embedding_mixin\\.py$",
    "^llama_stack/providers/utils/inference/litellm_openai_mixin\\.py$",
    "^llama_stack/providers/utils/inference/model_registry\\.py$",
    "^llama_stack/providers/utils/inference/openai_compat\\.py$",
    "^llama_stack/providers/utils/inference/prompt_adapter\\.py$",
    "^llama_stack/providers/utils/kvstore/config\\.py$",
    "^llama_stack/providers/utils/kvstore/kvstore\\.py$",
    "^llama_stack/providers/utils/kvstore/mongodb/mongodb\\.py$",
    "^llama_stack/providers/utils/kvstore/postgres/postgres\\.py$",
    "^llama_stack/providers/utils/kvstore/redis/redis\\.py$",
    "^llama_stack/providers/utils/kvstore/sqlite/sqlite\\.py$",
    "^llama_stack/providers/utils/memory/vector_store\\.py$",
    "^llama_stack/providers/utils/scoring/aggregation_utils\\.py$",
    "^llama_stack/providers/utils/scoring/base_scoring_fn\\.py$",
    "^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$",
    "^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
    "^llama_stack/providers/utils/telemetry/tracing\\.py$",
    "^llama_stack/strong_typing/auxiliary\\.py$",
    "^llama_stack/strong_typing/deserializer\\.py$",
    "^llama_stack/strong_typing/inspection\\.py$",
    "^llama_stack/strong_typing/schema\\.py$",
    "^llama_stack/strong_typing/serializer\\.py$",
    "^llama_stack/templates/groq/groq\\.py$",
    "^llama_stack/templates/llama_api/llama_api\\.py$",
    "^llama_stack/templates/sambanova/sambanova\\.py$",
    "^llama_stack/templates/template\\.py$",
]

[[tool.mypy.overrides]]
# packages that lack typing annotations, do not have stubs, or are unavailable.
module = ["yaml", "fire"]
ignore_missing_imports = true

[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true

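# Treat pydantic's `field_validator` as a classmethod decorator so pep8-naming's
# first-argument checks accept `cls` on validators.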
[tool.ruff.lint.pep8-naming]
classmethod-decorators = ["classmethod", "pydantic.field_validator"]