Split off meta-reference-quantized provider

Ashwin Bharambe 2024-10-10 15:54:08 -07:00
parent 7ff5800dea
commit 1ff0476002
10 changed files with 54 additions and 58 deletions

View file

@@ -149,6 +149,7 @@ class StackBuild(Subcommand):
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
import json
from llama_stack.cli.table import print_table
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
@@ -175,6 +176,7 @@ class StackBuild(Subcommand):
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import textwrap
import yaml
from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt
@@ -256,7 +258,9 @@ class StackBuild(Subcommand):
providers = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [
x for x in providers_for_api.keys() if x != "remote"
x
for x in providers_for_api.keys()
if x not in ("remote", "remote::sample")
]
api_provider = prompt(
"> Enter provider for API {}: ".format(api.value),

View file

@@ -16,7 +16,7 @@ from llama_stack.apis.agents import * # noqa: F403
from ..agents import (
AGENT_INSTANCES_BY_ID,
MetaReferenceAgentsImpl,
MetaReferenceImplConfig,
MetaReferenceInferenceConfig,
)
@@ -166,7 +166,7 @@ def mock_memory_api():
@pytest.fixture
async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
impl = MetaReferenceAgentsImpl(
config=MetaReferenceImplConfig(),
config=MetaReferenceInferenceConfig(),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
memory_api=mock_memory_api,

View file

@@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceImplConfig # noqa
from typing import Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps,
):
from .inference import MetaReferenceInferenceImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl
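get_provider_impl now accepts either config class and drops the single-type assert; the provider itself decides which code path to run. A rough, self-contained sketch of the widened async factory, where InferenceConfig, QuantizedInferenceConfig, and InferenceImpl are illustrative stand-ins, not the real llama_stack classes:

import asyncio
from typing import Union


class InferenceConfig:  # stand-in for MetaReferenceInferenceConfig
    pass


class QuantizedInferenceConfig(InferenceConfig):  # stand-in for the quantized variant
    pass


class InferenceImpl:  # stand-in for MetaReferenceInferenceImpl
    def __init__(self, config: Union[InferenceConfig, QuantizedInferenceConfig]) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass  # model loading / process-group setup would happen here


async def get_provider_impl(
    config: Union[InferenceConfig, QuantizedInferenceConfig],
    _deps,
):
    impl = InferenceImpl(config)
    await impl.initialize()
    return impl


impl = asyncio.run(get_provider_impl(QuantizedInferenceConfig(), {}))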

View file

@@ -15,12 +15,11 @@ from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceImplConfig(BaseModel):
class MetaReferenceInferenceConfig(BaseModel):
model: str = Field(
default="Llama3.1-8B-Instruct",
description="Model descriptor from `llama model list`",
)
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
@@ -38,9 +37,9 @@ class MetaReferenceImplConfig(BaseModel):
@property
def model_parallel_size(self) -> int:
# HACK ALERT: this will be fixed when we move inference configuration
# to ModelsRegistry and we can explicitly ask for `model_parallel_size`
# as configuration there
resolved = resolve_model(self.model)
assert resolved is not None
return resolved.pth_file_count
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
quantization: QuantizationConfig
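With the split, quantization settings leave the base config entirely: MetaReferenceQuantizedInferenceConfig subclasses MetaReferenceInferenceConfig and makes quantization a required field. A small pydantic sketch of the same pattern, using stand-in class names and a simplified QuantizationConfig:

from typing import Optional

from pydantic import BaseModel, Field


class QuantizationConfig(BaseModel):  # simplified stand-in for the real quantization config
    type: str = "fp8"


class InferenceConfig(BaseModel):  # stand-in for MetaReferenceInferenceConfig
    model: str = Field(default="Llama3.1-8B-Instruct")
    torch_seed: Optional[int] = None
    max_seq_len: int = 4096
    max_batch_size: int = 1


class QuantizedInferenceConfig(InferenceConfig):  # stand-in for the quantized variant
    quantization: QuantizationConfig  # required here, absent from the base class


InferenceConfig()  # validates with no quantization block at all
QuantizedInferenceConfig(quantization=QuantizationConfig())  # must supply one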

View file

@@ -11,9 +11,8 @@ import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional
from typing import Generator, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import (
)
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import QuantizationType
from llama_stack.distribution.utils.model_utils import model_local_dir
from pydantic import BaseModel
from termcolor import cprint
from .config import MetaReferenceImplConfig
from llama_stack.distribution.utils.model_utils import model_local_dir
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
def model_checkpoint_dir(model) -> str:
@@ -68,7 +65,11 @@ class TokenResult(BaseModel):
class Llama:
@staticmethod
def build(config: MetaReferenceImplConfig):
def build(
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
]
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
@@ -78,15 +79,6 @@ class Llama:
"""
model = resolve_model(config.model)
if (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
):
from .quantization.loader import is_fbgemm_available
if not is_fbgemm_available():
raise ImportError("fbgemm-gpu is required for FP8 quantization")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
@@ -134,12 +126,7 @@ class Llama:
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
fp8 = (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
)
if fp8:
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
from .quantization.loader import convert_to_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an
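In Llama.build, the fp8 path is now selected by the config's type rather than by inspecting an optional quantization field, and the fbgemm-gpu pre-flight check is dropped from this file. A compact sketch of that type-based dispatch, again with stand-in names:

class InferenceConfig:  # stand-in for MetaReferenceInferenceConfig
    pass


class QuantizedInferenceConfig(InferenceConfig):  # stand-in for the quantized variant
    pass


def build(config: InferenceConfig) -> str:
    if isinstance(config, QuantizedInferenceConfig):
        # quantized path: load on CPU in bf16, then convert_to_quantized_model(...)
        return "quantized checkpoint"
    return "bf16/fp16 checkpoint"


assert build(InferenceConfig()) == "bf16/fp16 checkpoint"
assert build(QuantizedInferenceConfig()) == "quantized checkpoint"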

View file

@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceImplConfig
from .config import MetaReferenceInferenceConfig
from .model_parallel import LlamaModelParallelGenerator
# there's a single model parallel process running serving the model. for now,
@@ -26,7 +26,7 @@ SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
def __init__(self, config: MetaReferenceImplConfig) -> None:
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:

View file

@@ -6,7 +6,6 @@
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
@@ -15,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from .config import MetaReferenceImplConfig
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
@@ -36,7 +35,7 @@ class ModelRunner:
)
def init_model_cb(config: MetaReferenceImplConfig):
def init_model_cb(config: MetaReferenceInferenceConfig):
llama = Llama.build(config)
return ModelRunner(llama)
@@ -52,7 +51,7 @@ class LlamaModelParallelGenerator:
clear at the callsite why we need to use a context manager.
"""
def __init__(self, config: MetaReferenceImplConfig):
def __init__(self, config: MetaReferenceInferenceConfig):
self.config = config
self.model = resolve_model(self.config.model)
# this is a hack because Agent's loop uses this to tokenize and check if input is too long

View file

@@ -11,7 +11,7 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Any, Callable, Generator, List, Literal, Optional, Union
from typing import Callable, Generator, List, Literal, Optional, Union
import torch
@@ -317,7 +317,7 @@ def start_model_parallel_process(
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
print(f"Finished model load {response}")
print("Loaded model...")
return request_socket, process

View file

@@ -22,19 +22,10 @@ from torch import Tensor
from llama_stack.apis.inference import QuantizationType
from llama_stack.providers.impls.meta_reference.inference.config import (
MetaReferenceImplConfig,
MetaReferenceQuantizedInferenceConfig,
)
def is_fbgemm_available() -> bool:
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
return True
except ImportError:
return False
def swiglu_wrapper(
self,
x: Tensor,
@@ -47,7 +38,7 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
config: MetaReferenceImplConfig,
config: MetaReferenceQuantizedInferenceConfig,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:

View file

@@ -14,6 +14,21 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference",
pip_packages=[
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",
"transformers",
"zmq",
],
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference-quantized",
pip_packages=[
"accelerate",
"blobfile",
@@ -25,7 +40,7 @@ def available_providers() -> List[ProviderSpec]:
"zmq",
],
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
),
remote_provider_spec(
api=Api.inference,
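
The registry now carries two inline inference entries that share one implementation module but point at different config classes, which is what lets the build tooling offer "meta-reference" and "meta-reference-quantized" as separate choices. A stand-in sketch of that pairing (InlineSpec is illustrative, not the real InlineProviderSpec):

from dataclasses import dataclass


@dataclass(frozen=True)
class InlineSpec:  # simplified stand-in for InlineProviderSpec
    provider_type: str
    module: str
    config_class: str


MODULE = "llama_stack.providers.impls.meta_reference.inference"

SPECS = [
    InlineSpec("meta-reference", MODULE, f"{MODULE}.MetaReferenceInferenceConfig"),
    InlineSpec(
        "meta-reference-quantized",
        MODULE,
        f"{MODULE}.MetaReferenceQuantizedInferenceConfig",
    ),
]

# a resolver only needs the provider_type to find the matching config class
by_type = {spec.provider_type: spec for spec in SPECS}
assert by_type["meta-reference-quantized"].module == by_type["meta-reference"].module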