fold in meta-reference-quantized

This commit is contained in:
Ashwin Bharambe 2025-04-07 11:15:27 -07:00
parent cfaf9e0e8b
commit ff6c47d4e5
9 changed files with 24 additions and 439 deletions

View file

@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Union
from typing import Any, Dict
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .config import MetaReferenceInferenceConfig
async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
config: MetaReferenceInferenceConfig,
_deps: Dict[str, Any],
):
from .inference import MetaReferenceInferenceImpl

View file

@ -31,6 +31,8 @@ class MetaReferenceInferenceConfig(BaseModel):
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None
quantization: Optional[QuantizationConfig] = None
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
@ -47,27 +49,14 @@ class MetaReferenceInferenceConfig(BaseModel):
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
"max_seq_len": 4096,
"checkpoint_dir": checkpoint_dir,
"quantization": {
"type": quantization_type,
},
}
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
quantization: QuantizationConfig
@classmethod
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
**kwargs,
) -> Dict[str, Any]:
config = super().sample_run_config(model, checkpoint_dir, **kwargs)
config["quantization"] = {
"type": "fp8",
}
return config

View file

@ -11,9 +11,7 @@ import torch
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
GreedySamplingStrategy,
Int4QuantizationConfig,
JsonSchemaResponseFormat,
ResponseFormat,
SamplingParams,
@ -32,7 +30,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .config import MetaReferenceInferenceConfig
from .inference import resolve_model
Tokenizer = Llama4Tokenizer | Llama3Tokenizer
@ -118,7 +116,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
class Llama4Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
@ -133,11 +131,13 @@ class Llama4Generator:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
if config.quantization:
if config.quantization.type == "fp8":
quantization_mode = QuantizationMode.fp8_mixed
elif isinstance(config.quantization, Int4QuantizationConfig):
elif config.quantization.type == "int4":
quantization_mode = QuantizationMode.int4_mixed
elif config.quantization.type == "bf16":
quantization_mode = None
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
@ -207,7 +207,7 @@ class Llama4Generator:
class Llama3Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
@ -222,11 +222,13 @@ class Llama3Generator:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
if config.quantization:
if config.quantization.type == "fp8":
quantization_mode = QuantizationMode.fp8_mixed
elif isinstance(config.quantization, Int4QuantizationConfig):
elif config.quantization.type == "int4":
quantization_mode = QuantizationMode.int4_mixed
elif config.quantization.type == "bf16":
quantization_mode = None
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:

View file

@ -24,6 +24,8 @@ META_REFERENCE_DEPS = [
"zmq",
"lm-format-enforcer",
"sentence-transformers",
"torchao==0.5.0",
"fbgemm-gpu-genai==1.1.2",
]
@ -36,13 +38,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.inference.meta_reference",
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="inline::meta-reference-quantized",
pip_packages=META_REFERENCE_DEPS + ["fbgemm-gpu", "torchao==0.5.0"],
module="llama_stack.providers.inline.inference.meta_reference",
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="inline::vllm",

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .meta_reference import get_distribution_template # noqa: F401

View file

@ -1,32 +0,0 @@
version: '2'
distribution_spec:
description: Use Meta Reference with fp8, int4 quantization for running LLM inference
providers:
inference:
- inline::meta-reference-quantized
vector_io:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda

View file

@ -1,113 +0,0 @@
---
orphan: true
---
# Meta Reference Quantized Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations:
{{ providers_table }}
The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
## Prerequisite: Downloading Models
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```
### Via Conda
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
```bash
llama stack build --template {{ name }} --image-type conda
llama stack run distributions/{{ name }}/run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
llama stack run distributions/{{ name }}/run-with-safety.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```

View file

@ -1,115 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceQuantizedInferenceConfig,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["inline::meta-reference-quantized"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
name = "meta-reference-quantized-gpu"
inference_provider = Provider(
provider_id="meta-reference-inference",
provider_type="inline::meta-reference-quantized",
config=MetaReferenceQuantizedInferenceConfig.sample_run_config(
model="${env.INFERENCE_MODEL}",
checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="meta-reference-inference",
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Use Meta Reference with fp8, int4 quantization for running LLM inference",
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"vector_io": [vector_io_provider],
},
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"8321",
"Port for the Llama Stack distribution server",
),
"INFERENCE_MODEL": (
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the Meta Reference server",
),
"INFERENCE_CHECKPOINT_DIR": (
"null",
"Directory containing the Meta Reference model checkpoint",
),
},
)

View file

@ -1,134 +0,0 @@
version: '2'
image_name: meta-reference-quantized-gpu
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: meta-reference-inference
provider_type: inline::meta-reference-quantized
config:
model: ${env.INFERENCE_MODEL}
max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: fp8
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-quantized-gpu/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321