Mirror of https://github.com/meta-llama/llama-stack.git
Split off meta-reference-quantized provider
Commit 1ff0476002 (parent 7ff5800dea)
10 changed files with 54 additions and 58 deletions
@@ -11,9 +11,8 @@ import json
 import os
 import sys
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional
+from typing import Generator, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import (
 )
 from llama_models.sku_list import resolve_model
 
-from llama_stack.apis.inference import QuantizationType
-
-from llama_stack.distribution.utils.model_utils import model_local_dir
-
 from pydantic import BaseModel
 from termcolor import cprint
 
-from .config import MetaReferenceImplConfig
+from llama_stack.distribution.utils.model_utils import model_local_dir
+
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
 
 def model_checkpoint_dir(model) -> str:
@@ -68,7 +65,11 @@ class TokenResult(BaseModel):
 
 class Llama:
     @staticmethod
-    def build(config: MetaReferenceImplConfig):
+    def build(
+        config: Union[
+            MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+        ]
+    ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
 
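A note on the hunk above: build() now accepts either provider config through typing.Union rather than the single MetaReferenceImplConfig. Below is a minimal, runnable sketch of that calling pattern; the dataclasses are simplified stand-ins, not the actual llama_stack config models.

from dataclasses import dataclass
from typing import Union


@dataclass
class PlainConfig:  # stand-in for MetaReferenceInferenceConfig (assumed shape)
    model: str


@dataclass
class QuantizedConfig:  # stand-in for MetaReferenceQuantizedInferenceConfig (assumed shape)
    model: str
    quantization_type: str = "fp8"


def build(config: Union[PlainConfig, QuantizedConfig]) -> str:
    # Either variant is a valid argument; static checkers enforce the Union.
    return f"{type(config).__name__} for {config.model}"


print(build(PlainConfig(model="demo-model")))
print(build(QuantizedConfig(model="demo-model")))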
@@ -78,15 +79,6 @@ class Llama:
         """
         model = resolve_model(config.model)
 
-        if (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        ):
-            from .quantization.loader import is_fbgemm_available
-
-            if not is_fbgemm_available():
-                raise ImportError("fbgemm-gpu is required for FP8 quantization")
-
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
 
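The block deleted above had guarded FP8 quantization behind an fbgemm-gpu availability check; with the split, that concern moves out of the plain meta-reference path. For illustration only, an optional-dependency guard of this kind can be written as below. This is an assumption about the general technique, not the repo's actual is_fbgemm_available implementation, and it assumes the importable module is named fbgemm_gpu.

import importlib.util


def fbgemm_is_importable() -> bool:
    # Reports whether the fbgemm_gpu package can be imported in this
    # environment (module name assumed; the real helper lives in
    # .quantization.loader).
    return importlib.util.find_spec("fbgemm_gpu") is not None


if __name__ == "__main__":
    if not fbgemm_is_importable():
        print("fbgemm-gpu is required for FP8 quantization")
    else:
        print("fbgemm-gpu found")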
@@ -134,12 +126,7 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
-        fp8 = (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        )
-
-        if fp8:
+        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             from .quantization.loader import convert_to_quantized_model
 
             # load on CPU in bf16 so that fp8 conversion does not find an
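The final hunk is the heart of the split: instead of computing an fp8 flag from config.quantization, the code branches on the concrete config class. A small runnable sketch of the before/after dispatch, again with simplified stand-in classes rather than the actual llama_stack types:

from dataclasses import dataclass, field
from typing import Optional, Union


@dataclass
class Quantization:  # stand-in for the nested quantization field the old check read
    type: str = "fp8"


@dataclass
class PlainConfig:  # stand-in for MetaReferenceInferenceConfig
    quantization: Optional[Quantization] = None


@dataclass
class QuantizedConfig:  # stand-in for MetaReferenceQuantizedInferenceConfig
    quantization: Quantization = field(default_factory=Quantization)


def load(config: Union[PlainConfig, QuantizedConfig]) -> str:
    # Old style: fp8 = config.quantization and config.quantization.type == "fp8"
    # New style: the config class itself selects the quantized path.
    if isinstance(config, QuantizedConfig):
        return "quantized load path"
    return "regular load path"


print(load(PlainConfig()))       # regular load path
print(load(QuantizedConfig()))   # quantized load path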