fold in meta-reference-quantized

commit ff6c47d4e5
parent cfaf9e0e8b

9 changed files with 24 additions and 439 deletions
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Union
+from typing import Any, Dict

-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .config import MetaReferenceInferenceConfig


 async def get_provider_impl(
-    config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
+    config: MetaReferenceInferenceConfig,
     _deps: Dict[str, Any],
 ):
     from .inference import MetaReferenceInferenceImpl
@@ -31,6 +31,8 @@ class MetaReferenceInferenceConfig(BaseModel):
     # can override by specifying the directory explicitly
     checkpoint_dir: Optional[str] = None

+    quantization: Optional[QuantizationConfig] = None
+
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
@@ -47,27 +49,14 @@ class MetaReferenceInferenceConfig(BaseModel):
         cls,
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
+        quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {
             "model": model,
             "max_seq_len": 4096,
             "checkpoint_dir": checkpoint_dir,
+            "quantization": {
+                "type": quantization_type,
+            },
         }
-
-
-class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
-    quantization: QuantizationConfig
-
-    @classmethod
-    def sample_run_config(
-        cls,
-        model: str = "Llama3.2-3B-Instruct",
-        checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
-        **kwargs,
-    ) -> Dict[str, Any]:
-        config = super().sample_run_config(model, checkpoint_dir, **kwargs)
-        config["quantization"] = {
-            "type": "fp8",
-        }
-        return config
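Note (illustrative, not part of this commit): a minimal sketch of the merged config's shape after this change. The QuantizationConfig stand-in below is a simplified assumption; only the field names and the sample_run_config shape come from the hunk above.

# Sketch only: simplified stand-ins for the real classes in llama_stack.
from typing import Any, Dict, Optional

from pydantic import BaseModel


class QuantizationConfig(BaseModel):  # stand-in for llama_stack.apis.inference.QuantizationConfig
    type: str  # "bf16", "fp8", or "int4"


class MetaReferenceInferenceConfig(BaseModel):
    model: str
    max_seq_len: int = 4096
    checkpoint_dir: Optional[str] = None
    # quantization now lives on the single config class; None means unquantized
    quantization: Optional[QuantizationConfig] = None

    @classmethod
    def sample_run_config(
        cls,
        model: str = "Llama3.2-3B-Instruct",
        checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
        quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
        **kwargs,
    ) -> Dict[str, Any]:
        return {
            "model": model,
            "max_seq_len": 4096,
            "checkpoint_dir": checkpoint_dir,
            "quantization": {"type": quantization_type},
        }


# The separate quantized config class is gone; an fp8 setup is now just:
config = MetaReferenceInferenceConfig(
    model="Llama3.2-3B-Instruct",
    quantization=QuantizationConfig(type="fp8"),
)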
@@ -11,9 +11,7 @@ import torch
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

 from llama_stack.apis.inference import (
-    Fp8QuantizationConfig,
     GreedySamplingStrategy,
-    Int4QuantizationConfig,
     JsonSchemaResponseFormat,
     ResponseFormat,
     SamplingParams,
@@ -32,7 +30,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .common import model_checkpoint_dir
-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .config import MetaReferenceInferenceConfig
 from .inference import resolve_model

 Tokenizer = Llama4Tokenizer | Llama3Tokenizer
@@ -118,7 +116,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
 class Llama4Generator:
     def __init__(
         self,
-        config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
+        config: MetaReferenceInferenceConfig,
         model_id: str,
         llama_model: Model,
     ):
@@ -133,11 +131,13 @@ class Llama4Generator:
                 # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
                 ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())

-        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
-            if isinstance(config.quantization, Fp8QuantizationConfig):
+        if config.quantization:
+            if config.quantization.type == "fp8":
                 quantization_mode = QuantizationMode.fp8_mixed
-            elif isinstance(config.quantization, Int4QuantizationConfig):
+            elif config.quantization.type == "int4":
                 quantization_mode = QuantizationMode.int4_mixed
+            elif config.quantization.type == "bf16":
+                quantization_mode = None
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
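Note (illustrative, not part of the diff): the same string-based dispatch is applied to both Llama4Generator and Llama3Generator, replacing isinstance checks on the removed config classes. A self-contained sketch, with stand-ins for QuantizationMode and QuantizationConfig:

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class QuantizationMode(str, Enum):  # stand-in for the QuantizationMode used by generators.py
    fp8_mixed = "fp8_mixed"
    int4_mixed = "int4_mixed"


@dataclass
class QuantizationConfig:  # stand-in for llama_stack's QuantizationConfig
    type: str


def select_quantization_mode(
    quantization: Optional[QuantizationConfig],
) -> Optional[QuantizationMode]:
    # No quantization block and an explicit "bf16" both mean: run unquantized.
    if quantization is None or quantization.type == "bf16":
        return None
    if quantization.type == "fp8":
        return QuantizationMode.fp8_mixed
    if quantization.type == "int4":
        return QuantizationMode.int4_mixed
    raise ValueError(f"Unsupported quantization mode {quantization}")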
@@ -207,7 +207,7 @@ class Llama4Generator:
 class Llama3Generator:
     def __init__(
         self,
-        config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
+        config: MetaReferenceInferenceConfig,
         model_id: str,
         llama_model: Model,
     ):
@@ -222,11 +222,13 @@ class Llama3Generator:
                 # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
                 ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())

-        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
-            if isinstance(config.quantization, Fp8QuantizationConfig):
+        if config.quantization:
+            if config.quantization.type == "fp8":
                 quantization_mode = QuantizationMode.fp8_mixed
-            elif isinstance(config.quantization, Int4QuantizationConfig):
+            elif config.quantization.type == "int4":
                 quantization_mode = QuantizationMode.int4_mixed
+            elif config.quantization.type == "bf16":
+                quantization_mode = None
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
@@ -24,6 +24,8 @@ META_REFERENCE_DEPS = [
     "zmq",
     "lm-format-enforcer",
     "sentence-transformers",
+    "torchao==0.5.0",
+    "fbgemm-gpu-genai==1.1.2",
 ]


@@ -36,13 +38,6 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.inference.meta_reference",
             config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
         ),
-        InlineProviderSpec(
-            api=Api.inference,
-            provider_type="inline::meta-reference-quantized",
-            pip_packages=META_REFERENCE_DEPS + ["fbgemm-gpu", "torchao==0.5.0"],
-            module="llama_stack.providers.inline.inference.meta_reference",
-            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
-        ),
         InlineProviderSpec(
             api=Api.inference,
             provider_type="inline::vllm",
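Note (hedged example, not text from the commit): since "inline::meta-reference-quantized" is removed from the registry, a deployment that used it would instead select "inline::meta-reference" and carry quantization in the provider config. A rough sketch of such a fragment, expressed as a Python dict; the provider_id and env default are illustrative assumptions:

run_config_inference_provider = {
    "provider_id": "meta-reference-inference",  # illustrative name
    "provider_type": "inline::meta-reference",  # was: inline::meta-reference-quantized
    "config": {
        "model": "Llama3.2-3B-Instruct",
        "max_seq_len": 4096,
        "checkpoint_dir": "${env.CHECKPOINT_DIR:null}",
        "quantization": {"type": "${env.QUANTIZATION_TYPE:fp8}"},
    },
}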