fold in meta-reference-quantized

Ashwin Bharambe 2025-04-07 11:15:27 -07:00
parent cfaf9e0e8b
commit ff6c47d4e5
9 changed files with 24 additions and 439 deletions

View file

@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Union
+from typing import Any, Dict
-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .config import MetaReferenceInferenceConfig
 async def get_provider_impl(
-    config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
+    config: MetaReferenceInferenceConfig,
     _deps: Dict[str, Any],
 ):
     from .inference import MetaReferenceInferenceImpl

View file

@@ -31,6 +31,8 @@ class MetaReferenceInferenceConfig(BaseModel):
     # can override by specifying the directory explicitly
     checkpoint_dir: Optional[str] = None
+    quantization: Optional[QuantizationConfig] = None
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
@@ -47,27 +49,14 @@ class MetaReferenceInferenceConfig(BaseModel):
         cls,
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
+        quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {
             "model": model,
             "max_seq_len": 4096,
             "checkpoint_dir": checkpoint_dir,
+            "quantization": {
+                "type": quantization_type,
+            },
         }
-class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
-    quantization: QuantizationConfig
-    @classmethod
-    def sample_run_config(
-        cls,
-        model: str = "Llama3.2-3B-Instruct",
-        checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
-        **kwargs,
-    ) -> Dict[str, Any]:
-        config = super().sample_run_config(model, checkpoint_dir, **kwargs)
-        config["quantization"] = {
-            "type": "fp8",
-        }
-        return config
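
Note: with the quantized config class folded in, the surviving MetaReferenceInferenceConfig.sample_run_config now emits the quantization block directly. A minimal sketch of the resulting dict (values taken from the hunk above; the env placeholders are resolved by the stack at run time, not by this snippet):

# Sketch: what sample_run_config() returns after the fold, with the default
# QUANTIZATION_TYPE of bf16 (placeholders left unresolved on purpose).
sample = {
    "model": "Llama3.2-3B-Instruct",
    "max_seq_len": 4096,
    "checkpoint_dir": "${env.CHECKPOINT_DIR:null}",
    "quantization": {
        "type": "${env.QUANTIZATION_TYPE:bf16}",
    },
}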

View file

@@ -11,9 +11,7 @@ import torch
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from llama_stack.apis.inference import (
-    Fp8QuantizationConfig,
     GreedySamplingStrategy,
-    Int4QuantizationConfig,
     JsonSchemaResponseFormat,
     ResponseFormat,
     SamplingParams,
@@ -32,7 +30,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 from .common import model_checkpoint_dir
-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .config import MetaReferenceInferenceConfig
 from .inference import resolve_model
 Tokenizer = Llama4Tokenizer | Llama3Tokenizer
@@ -118,7 +116,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
 class Llama4Generator:
     def __init__(
         self,
-        config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
+        config: MetaReferenceInferenceConfig,
         model_id: str,
         llama_model: Model,
     ):
@@ -133,11 +131,13 @@ class Llama4Generator:
             # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
             ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
-        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
-            if isinstance(config.quantization, Fp8QuantizationConfig):
+        if config.quantization:
+            if config.quantization.type == "fp8":
                 quantization_mode = QuantizationMode.fp8_mixed
-            elif isinstance(config.quantization, Int4QuantizationConfig):
+            elif config.quantization.type == "int4":
                 quantization_mode = QuantizationMode.int4_mixed
+            elif config.quantization.type == "bf16":
+                quantization_mode = None
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
@@ -207,7 +207,7 @@ class Llama4Generator:
 class Llama3Generator:
     def __init__(
         self,
-        config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
+        config: MetaReferenceInferenceConfig,
         model_id: str,
         llama_model: Model,
     ):
@@ -222,11 +222,13 @@ class Llama3Generator:
             # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
             ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
-        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
-            if isinstance(config.quantization, Fp8QuantizationConfig):
+        if config.quantization:
+            if config.quantization.type == "fp8":
                 quantization_mode = QuantizationMode.fp8_mixed
-            elif isinstance(config.quantization, Int4QuantizationConfig):
+            elif config.quantization.type == "int4":
                 quantization_mode = QuantizationMode.int4_mixed
+            elif config.quantization.type == "bf16":
+                quantization_mode = None
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
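
Note: both generators now branch on the string type of the optional quantization field instead of isinstance checks against the removed config classes. A minimal sketch of that mapping; the helper name pick_quantization_mode is hypothetical, and QuantizationMode is assumed to be the enum these generators already import:

# Hypothetical helper mirroring the branching added above: bf16, or no
# quantization config at all, means no quantization mode.
def pick_quantization_mode(quantization):
    if quantization is None:
        return None
    if quantization.type == "fp8":
        return QuantizationMode.fp8_mixed
    if quantization.type == "int4":
        return QuantizationMode.int4_mixed
    if quantization.type == "bf16":
        return None
    raise ValueError(f"Unsupported quantization mode {quantization}")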

View file

@@ -24,6 +24,8 @@ META_REFERENCE_DEPS = [
     "zmq",
     "lm-format-enforcer",
     "sentence-transformers",
+    "torchao==0.5.0",
+    "fbgemm-gpu-genai==1.1.2",
 ]
@@ -36,13 +38,6 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.inference.meta_reference",
             config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
         ),
-        InlineProviderSpec(
-            api=Api.inference,
-            provider_type="inline::meta-reference-quantized",
-            pip_packages=META_REFERENCE_DEPS + ["fbgemm-gpu", "torchao==0.5.0"],
-            module="llama_stack.providers.inline.inference.meta_reference",
-            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
-        ),
         InlineProviderSpec(
             api=Api.inference,
             provider_type="inline::vllm",