Split off meta-reference-quantized provider

Ashwin Bharambe 2024-10-10 15:54:08 -07:00
parent 7ff5800dea
commit 1ff0476002
10 changed files with 54 additions and 58 deletions

@@ -11,9 +11,8 @@ import json
 import os
 import sys
 import time
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional
+from typing import Generator, List, Optional, Union
 import torch
 import torch.nn.functional as F
@@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import (
 )
 from llama_models.sku_list import resolve_model
-from llama_stack.apis.inference import QuantizationType
-from llama_stack.distribution.utils.model_utils import model_local_dir
 from pydantic import BaseModel
 from termcolor import cprint
-from .config import MetaReferenceImplConfig
+from llama_stack.distribution.utils.model_utils import model_local_dir
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 def model_checkpoint_dir(model) -> str:
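
Illustrative aside, not part of the commit: the import hunk above swaps the single MetaReferenceImplConfig for two classes exported by .config. A minimal sketch of how such a split could look, with hypothetical field names (the real definitions in llama_stack may differ):

# Sketch only; field names are assumptions, not the actual .config contents.
from typing import Optional

from pydantic import BaseModel


class FP8QuantizationConfig(BaseModel):  # hypothetical stand-in
    type: str = "fp8"


class MetaReferenceInferenceConfig(BaseModel):
    model: str                # model descriptor resolved via resolve_model()
    max_seq_len: int = 4096   # assumed field
    max_batch_size: int = 1   # assumed field


class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
    # Only the quantized provider carries quantization settings; subclassing is
    # what makes the isinstance() dispatch in the last hunk possible.
    quantization: Optional[FP8QuantizationConfig] = None
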
@@ -68,7 +65,11 @@ class TokenResult(BaseModel):
 class Llama:
     @staticmethod
-    def build(config: MetaReferenceImplConfig):
+    def build(
+        config: Union[
+            MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+        ]
+    ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
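
Illustrative aside, not part of the commit: the widened signature lets one build path serve both providers, with quantized behavior selected by the concrete config type rather than by inspecting config.quantization. A toy sketch of that dispatch, reusing the hypothetical classes from the previous aside (build_stub stands in for Llama.build):

from typing import Union


def build_stub(
    config: Union[
        MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
    ],
) -> str:
    # The subclass check replaces the old
    # `config.quantization.type == QuantizationType.fp8.value` test.
    if isinstance(config, MetaReferenceQuantizedInferenceConfig):
        return "bf16 load on CPU, then convert_to_quantized_model(...)"
    return "plain checkpoint load"


print(build_stub(MetaReferenceInferenceConfig(model="example-model")))
print(build_stub(MetaReferenceQuantizedInferenceConfig(model="example-model")))
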
@@ -78,15 +79,6 @@ class Llama:
         """
         model = resolve_model(config.model)
-        if (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        ):
-            from .quantization.loader import is_fbgemm_available
-            if not is_fbgemm_available():
-                raise ImportError("fbgemm-gpu is required for FP8 quantization")
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
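
Illustrative aside, not part of the commit: the deleted block gated FP8 on is_fbgemm_available() before doing any work; after the split, that guard presumably travels with the quantized provider. A hedged sketch of how such an optional-dependency check is commonly written (the real helper lives in .quantization.loader and may differ):

import importlib.util


def is_fbgemm_available() -> bool:
    # True if the fbgemm_gpu module is importable, without importing it yet.
    return importlib.util.find_spec("fbgemm_gpu") is not None


if not is_fbgemm_available():
    raise ImportError("fbgemm-gpu is required for FP8 quantization")
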
@@ -134,12 +126,7 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
-        fp8 = (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        )
-        if fp8:
+        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             from .quantization.loader import convert_to_quantized_model
             # load on CPU in bf16 so that fp8 conversion does not find an