Merge branch 'main' into evals_5

Xi Yan 2024-10-24 09:18:15 -07:00
commit 071dba8871
14 changed files with 399 additions and 27 deletions

.gitignore (vendored, 1 change)

@@ -13,6 +13,7 @@ xcuserdata/
Package.resolved
*.pte
*.ipynb_checkpoints*
.idea
.venv/
+.idea
_build


@@ -7,6 +7,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
| **Distribution** | **Llama Stack Docker** | Start This Distribution | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|:----------------: |:------------------------------------------: |:-----------------------: |:------------------: |:------------------: |:------------------: |:------------------: |:------------------: |
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](./meta-reference-gpu/) | meta-reference | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
+| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](./meta-reference-quantized-gpu/) | meta-reference-quantized | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](./ollama/) | remote::ollama | meta-reference | remote::pgvector; remote::chromadb | remote::ollama | meta-reference |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](./tgi/) | remote::tgi | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](./together/) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |


@@ -1,5 +1,6 @@
name: meta-reference-gpu
distribution_spec:
+docker_image: pytorch/pytorch
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference


@@ -0,0 +1,34 @@
# Meta Reference Quantized Distribution
The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations.
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|----------------- |------------------------ |---------------- |-------------------------------------------------- |---------------- |---------------- |
| **Provider(s)** | meta-reference-quantized | meta-reference | meta-reference, remote::pgvector, remote::chroma | meta-reference | meta-reference |
The only difference from the `meta-reference-gpu` distribution is that it supports more efficient inference through quantization (fp8, int4, etc.).
### Start the Distribution (Single Node GPU)
> [!NOTE]
> This assumes you have access to a GPU, since the local server runs inference on it.
> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.
To download and run a pre-built Docker container, use the following command:
```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama \
-v ./run.yaml:/root/my-run.yaml \
--gpus=all \
  llamastack/distribution-meta-reference-quantized-gpu \
--yaml_config /root/my-run.yaml
```
### Alternative (Build and start distribution locally via conda)
- See the [Getting Started](../../docs/getting_started.md) guide for details on building locally via conda and starting up the distribution.
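
Once the container is up, the `/inference/chat_completion` route added later in this patch is reachable on port 5000. A minimal smoke-test sketch (the JSON payload shape is an assumption, not part of this patch; `httpx` is already in requirements.txt):

```
import httpx

# Hypothetical request against the server started above; the body shape
# mirrors the chat_completion parameters (model, messages, stream) but is
# an assumption, not taken from this patch.
response = httpx.post(
    "http://localhost:5000/inference/chat_completion",
    json={
        "model": "Llama3.2-3B-Instruct",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
    timeout=60,
)
print(response.json())
```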


@@ -0,0 +1,14 @@
name: meta-reference-quantized-gpu
distribution_spec:
docker_image: pytorch/pytorch:2.5.0-cuda12.4-cudnn9-runtime
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference-quantized
memory:
- meta-reference
- remote::chromadb
- remote::pgvector
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: docker


@@ -0,0 +1,51 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
inference:
- provider_id: meta0
provider_type: meta-reference-quantized
config:
model: Llama3.2-3B-Instruct
quantization:
type: fp8
torch_seed: null
max_seq_len: 2048
max_batch_size: 1
safety:
- provider_id: meta0
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: meta-reference
config: {}
agents:
- provider_id: meta0
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta0
provider_type: meta-reference
config: {}
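
For reference, a short sketch of how this run config can be inspected, e.g. to confirm which quantization setting the inference provider will load (PyYAML usage and the local file name are assumptions):

```
import yaml

# Load the run config above and read the inference provider's
# quantization block (here: {'type': 'fp8'}).
with open("run.yaml") as f:
    run_config = yaml.safe_load(f)

inference = run_config["providers"]["inference"][0]
assert inference["provider_type"] == "meta-reference-quantized"
print(inference["config"]["quantization"])
```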


@@ -172,7 +172,7 @@ async def run_mm_main(
],
)
cprint(f"User>{message.content}", "green")
-    iterator = client.chat_completion(
+    iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,
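
The added `await` matters because `chat_completion` is now a true coroutine: depending on `stream`, the awaited call hands back either an async generator of chunks or a single response (per the comment removed from the API file below). A hedged sketch of the calling convention, with the client object and message type assumed from the surrounding script:

```
# Sketch of the calling convention implied by the change above.
async def run_chat(client, model, message, stream=True):
    iterator = await client.chat_completion(
        model=model,
        messages=[message],
        stream=stream,
    )
    if stream:
        async for chunk in iterator:  # streaming: async generator of chunks
            print(chunk)
    else:
        print(iterator)  # non-streaming: a single response object
```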


@@ -25,6 +25,7 @@ class LogProbConfig(BaseModel):
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
+    int4 = "int4"
@json_schema_type
@@ -37,8 +38,14 @@ class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
+@json_schema_type
+class Int4QuantizationConfig(BaseModel):
+    type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
+    scheme: Optional[str] = None
QuantizationConfig = Annotated[
-    Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
+    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
Field(discriminator="type"),
]
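
Because the union is discriminated on `type`, pydantic can resolve a plain dict (such as the `quantization:` block in a run config) to the right config class. A minimal sketch using pydantic v2's `TypeAdapter` (pydantic>=2 per requirements.txt; the classes are the ones defined above, and the scheme string is the one this patch requires later):

```
from pydantic import TypeAdapter

# Resolve a raw dict to a concrete config class via the "type" discriminator.
cfg = TypeAdapter(QuantizationConfig).validate_python(
    {"type": "int4", "scheme": "int4_weight_int8_dynamic_activation"}
)
assert isinstance(cfg, Int4QuantizationConfig)
assert cfg.scheme == "int4_weight_int8_dynamic_activation"
```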
@@ -219,8 +226,6 @@ class Inference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
-    # This method is not `async def` because it can result in either an
-    # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
async def chat_completion(
self,


@@ -97,7 +97,7 @@ if [ -n "$pip_dependencies" ]; then
fi
if [ -n "$special_pip_deps" ]; then
-  IFS='#' read -ra parts <<< "$special_pip_deps"
+  IFS='#' read -ra parts <<<"$special_pip_deps"
for part in "${parts[@]}"; do
add_to_docker "RUN pip install $part"
done
@@ -127,7 +127,7 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
-if command -v selinuxenabled &> /dev/null && selinuxenabled; then
+if command -v selinuxenabled &>/dev/null && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
DOCKER_OPTS="$DOCKER_OPTS --security-opt label=disable"
fi
@@ -139,4 +139,4 @@ $DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REP
rm -rf $REPO_CONFIGS_DIR
set +x
echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"
echo "Success!"


@@ -30,7 +30,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
-from pydantic import BaseModel
from termcolor import cprint
@@ -43,7 +42,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .config import (
+    Fp8QuantizationConfig,
+    Int4QuantizationConfig,
+    MetaReferenceInferenceConfig,
+    MetaReferenceQuantizedInferenceConfig,
+)
def model_checkpoint_dir(model) -> str:
@@ -131,18 +135,34 @@ class Llama:
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
-            from .quantization.loader import convert_to_quantized_model
-            # load on CPU in bf16 so that fp8 conversion does not find an
-            # unexpected (fp32, e.g.) datatype
-            torch.set_default_tensor_type(torch.BFloat16Tensor)
-            if model_args.vision_chunk_size > 0:
-                model = CrossAttentionTransformer(model_args)
-                model.setup_cache(model_args.max_batch_size, torch.bfloat16)
-            else:
+            if isinstance(config.quantization, Fp8QuantizationConfig):
+                from .quantization.loader import convert_to_fp8_quantized_model
+                # load on CPU in bf16 so that fp8 conversion does not find an
+                # unexpected (fp32, e.g.) datatype
+                torch.set_default_tensor_type(torch.BFloat16Tensor)
+                if model_args.vision_chunk_size > 0:
+                    model = CrossAttentionTransformer(model_args)
+                    model.setup_cache(model_args.max_batch_size, torch.bfloat16)
+                else:
+                    model = Transformer(model_args)
+                model.load_state_dict(state_dict, strict=False)
+                model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
+            elif isinstance(config.quantization, Int4QuantizationConfig):
+                from .quantization.loader import convert_to_int4_quantized_model
+                assert (
+                    config.quantization.scheme is not None
+                ), "Please specify a quantization scheme."
                model = Transformer(model_args)
-            model.load_state_dict(state_dict, strict=False)
-            model = convert_to_quantized_model(model, config, ckpt_dir)
+                model = convert_to_int4_quantized_model(model, model_args, config)
+                model.load_state_dict(state_dict, strict=True)
+            else:
+                raise NotImplementedError(
+                    "Currently int4 and fp8 are the only supported quantization methods."
+                )
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
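
The `torch.set_default_tensor_type(torch.BFloat16Tensor)` call is what guarantees the checkpoint materializes in bf16 before the fp8 conversion runs. A tiny standalone sketch of the effect (illustrative, not from the patch):

```
import torch

# After this call, freshly constructed tensors (and module parameters)
# default to bf16 on CPU, so the fp8 converter never sees fp32 weights.
torch.set_default_tensor_type(torch.BFloat16Tensor)
x = torch.empty(2, 2)
assert x.dtype == torch.bfloat16
```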


@@ -8,19 +8,25 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
-from typing import Optional
+from typing import Any, Dict, List, Optional
import torch
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.datatypes import CheckpointQuantizationFormat
-from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
+from llama_models.datatypes import CheckpointQuantizationFormat
+from llama_models.llama3.api.args import ModelArgs
+from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from termcolor import cprint
-from torch import Tensor
+from torch import nn, Tensor
+from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
+from llama_stack.apis.inference.inference import Int4QuantizationConfig
from llama_stack.providers.impls.meta_reference.inference.config import (
MetaReferenceQuantizedInferenceConfig,
@@ -37,7 +43,7 @@ def swiglu_wrapper(
return reduce_from_model_parallel_region(out)
-def convert_to_fp8_quantized_model(
+def convert_to_fp8_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
@@ -99,3 +105,241 @@ def convert_to_quantized_model(
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
return model
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
"""
Int8DynActInt4WeightLinear with LoRA adaptor.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
device: Device to use.
group_size: Group size for quantization.
precision: Precision of quantization.
scales_precision: Precision of scales.
lora_rank: Rank of LoRA adaptor.
lora_scale: Scale of LoRA adaptor.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias=False,
device=None,
# quantization parameters
group_size: int = 256,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
# LoRA parameters
lora_rank: Optional[int] = None,
lora_scale: Optional[float] = None,
) -> None:
super().__init__(
in_features,
out_features,
bias=bias,
device=device,
groupsize=group_size,
precision=precision,
scales_precision=scales_precision,
)
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
self.adaptor = nn.Sequential()
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
self.lora_scale = lora_scale
else:
self.adaptor = None
self.lora_scale = None
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict:
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
assert prefix + "scales" in state_dict
state_dict[prefix + "zeros"] = torch.zeros_like(
state_dict[prefix + "scales"]
)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
module_out = super().forward(input_)
if self.adaptor is not None:
adaptor_out = self.adaptor(input_) * self.lora_scale
return module_out + adaptor_out
return module_out
class Int8WeightEmbedding(torch.nn.Embedding):
"""An embedding layer to load int8 weights.
Args:
num_embeddings: Number of embeddings.
embedding_dim: Embedding dimension.
padding_idx: Padding index.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
device=None,
) -> None:
super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
class Int8WeightLinear(torch.nn.Linear):
"""A linear layer to load int8 weights.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
"""
def __init__(
self, in_features: int, out_features: int, bias: bool = True, device=None
) -> None:
super().__init__(in_features, out_features, bias, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module,
group_size: int,
lora_rank: Optional[int],
lora_scale: Optional[float],
):
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
Note that the weights of embedding and output layers are quantized to int8.
"""
device = None
for module_name, module in model.named_children():
if module_name == "output":
quantized_module = Int8WeightLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif module_name == "tok_embeddings":
quantized_module = Int8WeightEmbedding(
num_embeddings=module.num_embeddings,
embedding_dim=module.embedding_dim,
padding_idx=module.padding_idx,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features,
out_features=module.out_features,
bias=False,
group_size=group_size,
lora_rank=lora_rank,
lora_scale=lora_scale,
device=device,
)
del module
setattr(model, module_name, quantized_module)
else:
_prepare_model_int4_weight_int8_dynamic_activation(
module, group_size, lora_rank, lora_scale
)
return model
def convert_to_int4_quantized_model(
model: Transformer,
model_args: ModelArgs,
config: MetaReferenceQuantizedInferenceConfig,
) -> Transformer:
"""Convert the model to int4 quantized model."""
quant_config = config.quantization
if not isinstance(quant_config, Int4QuantizationConfig):
raise ValueError("Only int4 quantization is supported")
if quant_config.type != QuantizationType.int4.value:
raise ValueError("Only int4 quantization is supported")
if quant_config.scheme != "int4_weight_int8_dynamic_activation":
raise NotImplementedError(
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
)
if model_args.quantization_args is None:
raise ValueError("'quantization_args' cannot be None. Please specify it.")
group_size = model_args.quantization_args.group_size
if group_size is None:
raise ValueError(
"'group_size' cannot be None in 'quantization_args'. Please specify it."
)
if model_args.lora_args is None:
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
lora_rank = None
lora_scale = None
else:
lora_rank = model_args.lora_args.rank
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(
model, group_size, lora_rank, lora_scale
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return model.to(device)
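
For intuition, the LoRA path in `Int8DynActInt4WeightLinearLoRA.forward` reduces to `base(x) + lora_scale * B(A(x))`. A standalone float sketch of that composition, with a plain `nn.Linear` standing in for the quantized base layer (dimensions are illustrative):

```
import torch
from torch import nn

# Float stand-in for the quantized-linear + LoRA adaptor composition above.
in_features, out_features, lora_rank, lora_scale = 16, 8, 4, 0.5
base = nn.Linear(in_features, out_features, bias=False)
adaptor = nn.Sequential(
    nn.Linear(in_features, lora_rank, bias=False),   # "A"
    nn.Linear(lora_rank, out_features, bias=False),  # "B"
)

x = torch.randn(2, in_features)
y = base(x) + lora_scale * adaptor(x)  # mirrors module_out + adaptor_out
print(y.shape)  # torch.Size([2, 8])
```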


@@ -36,7 +36,8 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=(
META_REFERENCE_DEPS
+ [
"fbgemm-gpu==0.8.0",
"fbgemm-gpu",
"torchao==0.5.0",
]
),
module="llama_stack.providers.impls.meta_reference.inference",


@@ -2,7 +2,7 @@ blobfile
fire
httpx
huggingface-hub
-llama-models>=0.0.43
+llama-models>=0.0.44
prompt-toolkit
python-dotenv
pydantic>=2


@@ -16,7 +16,7 @@ def read_requirements():
setup(
name="llama_stack",
version="0.0.43",
version="0.0.44",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama Stack",