updating license for toolchain

Ashwin Bharambe 2024-07-22 20:31:42 -07:00
parent 0e2fc9966a
commit 86fff23a9e
74 changed files with 512 additions and 94 deletions

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.

View file

@@ -1,2 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@@ -1,14 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Optional, Union
from hydra.core.config_store import ConfigStore
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from .datatypes import QuantizationConfig
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
class ImplType(Enum):

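The hydra ConfigStore import above hints at how these configs get wired up: structured configs are registered once under a name and composed by hydra later. A minimal sketch of that pattern follows; the enum members and the dataclass fields are hypothetical, since the diff cuts off at the class header:

from dataclasses import dataclass
from enum import Enum
from hydra.core.config_store import ConfigStore

class ImplType(Enum):
    # hypothetical members, for illustration only; the real body is not shown here
    inline = "inline"
    remote = "remote"

@dataclass
class ExampleInferenceConfig:
    impl_type: ImplType = ImplType.inline

# ConfigStore.instance() is a process-wide registry; hydra resolves the
# name "inference" to this node when composing a config.
cs = ConfigStore.instance()
cs.store(name="inference", node=ExampleInferenceConfig)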
View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import List, Literal, Optional, Union
@@ -26,9 +32,7 @@ class Fp8QuantizationConfig(BaseModel):
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
    type: Literal[QuantizationType.bf16.value] = (
        QuantizationType.bf16.value
    )
    type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
QuantizationConfig = Annotated[

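The type: Literal[...] default seen above is one half of pydantic's tagged-union pattern; the QuantizationConfig = Annotated[ line is cut off here, but the alias presumably unions the config classes on that type field. A self-contained sketch of the pattern, with simplified names:

from typing import Literal, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated

class Bf16Config(BaseModel):
    # the Literal default doubles as the discriminator tag
    type: Literal["bf16"] = "bf16"

class Fp8Config(BaseModel):
    type: Literal["fp8"] = "fp8"

# pydantic inspects the "type" field to pick the right subclass when parsing
QuantizationConfig = Annotated[
    Union[Bf16Config, Fp8Config],
    Field(discriminator="type"),
]

Parsing {"type": "fp8"} through a field of this type then yields an Fp8Config instance rather than a bare union.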
View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
from .datatypes import * # noqa: F403
from typing import Optional, Protocol

View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
from .api.config import ImplType, InferenceConfig

View file

@@ -1,7 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
import asyncio
import json
from typing import AsyncGenerator
from urllib.request import getproxies
import fire
import httpx
@@ -9,12 +17,16 @@ from .api import (
    ChatCompletionRequest,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    InstructModel,
    Inference,
    InstructModel,
    UserMessage,
)
from .event_logger import EventLogger
print(getproxies())
# import sys
# sys.exit(0)
class InferenceClient(Inference):
    def __init__(self, base_url: str):

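print(getproxies()) is a committed debugging aid: urllib.request.getproxies() reports the proxy settings (HTTP_PROXY, HTTPS_PROXY, and so on) that httpx also honors, which is useful when streaming from the server stalls behind a proxy. The imports in this hunk suggest an httpx-based streaming client; a minimal sketch under an assumed endpoint path and SSE-style framing, neither of which is shown in the diff:

import json
from typing import AsyncGenerator
import httpx

async def stream_chat(base_url: str, payload: dict) -> AsyncGenerator[dict, None]:
    # hypothetical route; the real path is defined by the toolchain server
    url = f"{base_url}/inference/chat_completion"
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload, timeout=None) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):  # assumed event framing
                    yield json.loads(line[len("data: "):])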
View file

@@ -1,8 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
from termcolor import cprint
from llama_toolchain.inference.api import (
    ChatCompletionResponseEventType,
)
from termcolor import cprint
from llama_toolchain.inference.api import ChatCompletionResponseEventType
class LogEvent:
@@ -30,4 +34,3 @@ class EventLogger:
            yield LogEvent(event.delta, color="yellow", end="")
        elif event.event_type == ChatCompletionResponseEventType.complete:
            yield LogEvent("")

View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

View file

@@ -1,16 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3_1.api.datatypes import StopReason
from .api.config import (
    CheckpointQuantizationFormat,
    CheckpointType,
    InlineImplConfig,
)
from .api.config import InlineImplConfig
from .api.datatypes import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    QuantizationConfig,
    ToolCallDelta,
    ToolCallParseStatus,
)

View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
from dataclasses import dataclass
from functools import partial

View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
import multiprocessing
import os
import pickle

View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@@ -8,7 +14,7 @@ try:
    import fbgemm_gpu.experimental.gen_ai # noqa: F401
    print("Using efficient FP8 operators in FBGEMM.")
except (ImportError, ModuleNotFoundError):
except ImportError:
    print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
    raise
@@ -57,8 +63,8 @@ def ffn_swiglu(
        x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
    )
    (B, T, D) = x.shape
    (HD_L, D_) = w1.shape
    (B, T, D) = x.shape # noqa: N806
    (HD_L, D_) = w1.shape # noqa: N806
    assert D_ == D
    assert isinstance(w1, Tensor)
@@ -153,8 +159,8 @@ def ffn_swiglu_fp8_dynamic(
    num_tokens: Optional[Tensor] = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    (B, T, D) = x.shape
    HD_L = w1.shape[0]
    (B, T, D) = x.shape # noqa: N806
    HD_L = w1.shape[0] # noqa: N806
    assert HD_L == w3.shape[0]
    x1 = fc_fp8_dynamic(
        x.view(B * T, D),

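Two mechanical cleanups recur in this hunk: the exception tuple is narrowed to except ImportError, which loses nothing because ModuleNotFoundError has been a subclass of ImportError since Python 3.6, and the uppercase tensor-shape names keep their conventional spelling while # noqa: N806 silences the pep8-naming check instead of renaming them. The subclass claim is easy to verify:

# ModuleNotFoundError subclasses ImportError, so a bare `except ImportError`
# already catches a missing module.
assert issubclass(ModuleNotFoundError, ImportError)

try:
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401
except ImportError as e:
    # when fbgemm_gpu is absent this is a ModuleNotFoundError,
    # but it is still caught here
    print(type(e).__name__, e)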
View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@@ -9,13 +15,13 @@ import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
from termcolor import cprint
from llama_toolchain.inference.api.config import (
    CheckpointQuantizationFormat,
    InlineImplConfig,
)
from llama_toolchain.inference.api.datatypes import QuantizationType
from termcolor import cprint
from torch import Tensor
@@ -24,7 +30,7 @@ def is_fbgemm_available() -> bool:
        import fbgemm_gpu.experimental.gen_ai # noqa: F401
        return True
    except (ImportError, ModuleNotFoundError):
    except ImportError:
        return False

View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

View file

@@ -1,5 +1,11 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
set -euo pipefail
set -x

View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@@ -5,7 +11,7 @@ import unittest
import torch
from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
from hypothesis import given, settings, strategies as st
from torch import Tensor
@@ -26,29 +32,25 @@ class FP8Tests(unittest.TestCase):
    )
    def test_fp8_ffn(
        self,
        D: int,
        D: int, # noqa
        HD_L: int,
        B: int,
        T: int,
        UB: float,
    ) -> None:
        x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
        w1 = (
            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
        )
        w3 = (
            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
        )
        w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
        w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
        w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
        x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
        w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
        w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
        w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
        x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
        w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
        w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
        w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
        def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
            (B, T, D) = x.shape
            (HD_L, D_) = w1.shape
            (B, T, D) = x.shape # noqa: N806
            (HD_L, D_) = w1.shape # noqa: N806
            assert D_ == D
            x1 = x.view(B * T, D) @ w1.T

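ref_ffn is cut off right after the first projection; presumably it completes the standard SwiGLU feed-forward that ffn_swiglu_fp8_dynamic reproduces in FP8. A sketch of that reference computation (the gating and down-projection lines are the conventional Llama FFN, assumed rather than shown in this hunk):

import torch
from torch import Tensor

def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
    (B, T, D) = x.shape  # noqa: N806
    (HD_L, D_) = w1.shape  # noqa: N806
    assert D_ == D
    x1 = x.view(B * T, D) @ w1.T  # gate projection
    x2 = x.view(B * T, D) @ w3.T  # up projection
    z = torch.nn.functional.silu(x1) * x2  # SwiGLU gating
    return (z @ w2.T).view(B, T, D)  # down projection back to (B, T, D)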
View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the
# LICENSE file in the root directory of this source tree.
import asyncio
import signal