Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes it after the move to
ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to
a `ruff.toml` file as well as fix and ignore some additional checks.
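
The reformatting in the diff below is consistent with a longer line limit and a
rule set configured in `ruff.toml`. The file itself is not shown in this excerpt,
so the following is only a minimal sketch: the 120-character line length and the
specific rule selections are assumptions, not the PR's actual contents.

    # ruff.toml -- illustrative sketch only; values in the PR may differ
    line-length = 120

    [lint]
    select = ["E", "F", "I"]  # assumed: pycodestyle, pyflakes, isort
    ignore = ["E721"]         # assumed placeholder for the "additional checks" being ignored

On the pre-commit side, the check would point at the `ruff` (and optionally
`ruff-format`) hooks from astral-sh/ruff-pre-commit; the exact revision pinned
by the repo is not shown here.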

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Authored by Yuan Tang on 2025-02-02 09:46:45 -05:00; committed by GitHub
commit 34ab7a3b6c (parent 4773092dd1)
217 changed files with 981 additions and 2681 deletions

@@ -40,9 +40,7 @@ class MetaReferenceInferenceConfig(BaseModel):
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
return model
@classmethod

@@ -83,9 +83,7 @@ class TokenResult(BaseModel):
class Llama:
@staticmethod
def build(
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
],
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
model_id: str,
llama_model: Model,
):
@@ -150,9 +148,9 @@ class Llama:
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
assert model_parallel_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
@@ -168,9 +166,9 @@ class Llama:
)
tokenizer = Tokenizer.get_instance()
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
assert model_args.vocab_size == tokenizer.n_words, (
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
)
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
@@ -193,10 +191,7 @@ class Llama:
model = convert_to_int4_quantized_model(model, model_args, config)
model.load_state_dict(state_dict, strict=True)
if (
model_args.quantization_args is not None
and model_args.quantization_args.spinquant
):
if model_args.quantization_args is not None and model_args.quantization_args.spinquant:
# Add a wrapper for adding hadamard transform for spinquant.
# This needs to be done after loading the state dict otherwise an error will be raised while
# loading the state dict.
@@ -206,9 +201,7 @@ class Llama:
add_hadamard_transform_for_spinquant(model)
else:
raise NotImplementedError(
"Currently int4 and fp8 are the only supported quantization methods."
)
raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")
else:
if device == "cuda":
if torch.cuda.is_bf16_supported():
@@ -262,10 +255,7 @@ class Llama:
params = self.model.params
if print_input_tokens:
input_tokens = [
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
prompt_tokens = [model_input.tokens]
@@ -287,12 +277,10 @@ class Llama:
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
pad_id = self.tokenizer.pad_id
@@ -340,9 +328,7 @@ class Llama:
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
@@ -365,17 +351,11 @@ class Llama:
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
if logprobs
else None
),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
prev_pos = cur_pos
@@ -388,11 +368,7 @@ class Llama:
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(request.content)
@@ -417,11 +393,7 @@ class Llama:
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
@@ -473,9 +445,7 @@ class LogitsProcessor:
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def process_logits(
self, tokens: torch.Tensor, scores: torch.Tensor
) -> torch.Tensor:
def process_logits(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
@@ -510,9 +480,7 @@ def get_logits_processor(
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(
tokenizer: Tokenizer, vocab_size: int
) -> List[Tuple[int, str, bool]]:
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []

@@ -80,9 +80,7 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
self.config, model_id, llama_model
)
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
self.generator.start()
else:
self.generator = Llama.build(self.config, model_id, llama_model)
@@ -100,9 +98,7 @@ class MetaReferenceInferenceImpl(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model_id:
raise RuntimeError(
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
)
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def unregister_model(self, model_id: str) -> None:
pass
@@ -184,13 +180,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
]
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
yield CompletionResponseStreamChunk(
delta=text,
@@ -212,9 +202,7 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
def impl():
tokens = []
logprobs = []
@@ -231,13 +219,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
@@ -289,9 +271,7 @@ class MetaReferenceInferenceImpl(
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.llama_model.core_model_id.value
)
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
@@ -304,9 +284,7 @@ class MetaReferenceInferenceImpl(
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
def impl():
tokens = []
logprobs = []
@@ -323,20 +301,12 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
@@ -352,9 +322,7 @@ class MetaReferenceInferenceImpl(
else:
return impl()
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@@ -405,13 +373,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -424,9 +386,7 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:

@@ -91,9 +91,7 @@ class LlamaModelParallelGenerator:
self.group = ModelParallelProcessGroup(
model_parallel_size,
init_model_cb=partial(
init_model_cb, self.config, self.model_id, self.llama_model
),
init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
)
self.group.start()
return self

@@ -55,47 +55,33 @@ class ProcessingMessageName(str, Enum):
class ReadyRequest(BaseModel):
type: Literal[ProcessingMessageName.ready_request] = (
ProcessingMessageName.ready_request
)
type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
class ReadyResponse(BaseModel):
type: Literal[ProcessingMessageName.ready_response] = (
ProcessingMessageName.ready_response
)
type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
class EndSentinel(BaseModel):
type: Literal[ProcessingMessageName.end_sentinel] = (
ProcessingMessageName.end_sentinel
)
type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
class CancelSentinel(BaseModel):
type: Literal[ProcessingMessageName.cancel_sentinel] = (
ProcessingMessageName.cancel_sentinel
)
type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = (
ProcessingMessageName.task_response
)
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: TokenResult
class ExceptionResponse(BaseModel):
type: Literal[ProcessingMessageName.exception_response] = (
ProcessingMessageName.exception_response
)
type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
error: str
@@ -189,9 +175,7 @@ def retrieve_requests(reply_socket_url: str):
group=get_model_parallel_group(),
)
if isinstance(updates[0], CancelSentinel):
log.info(
"quitting generation loop because request was cancelled"
)
log.info("quitting generation loop because request was cancelled")
break
if mp_rank_0():
@@ -350,9 +334,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: Union[
CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
],
req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
) -> Generator:
assert not self.running, "inference already running"

@@ -19,9 +19,7 @@ try:
log.info("Using efficient FP8 operators in FBGEMM.")
except ImportError:
log.error(
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
)
log.error("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
raise
import torch
@@ -60,14 +58,8 @@ def ffn_swiglu(
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (
isinstance(w1, Fp8ScaledWeights)
and isinstance(w3, Fp8ScaledWeights)
and isinstance(w2, Fp8ScaledWeights)
):
return ffn_swiglu_fp8_dynamic(
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
)
if isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights):
return ffn_swiglu_fp8_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
@@ -146,12 +138,8 @@ def fc_fp8_dynamic(
Single w8a8 fc layer with dynamic row-wise scaling.
"""
if isinstance(w, Fp8RowwiseWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x, num_tokens, activation_scale_ub
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
xq, w.weight, x_scale, w.scale, use_fast_accum=True
)
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
del xq
return y

@@ -17,8 +17,7 @@ from torch import Tensor
@unittest.skipIf(
not torch.cuda.is_available()
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"Skip when H100 is not available",
)
class FP8Tests(unittest.TestCase):

@@ -57,9 +57,7 @@ class HadamardModule(torch.nn.Module):
return x
def add_hadamard_transform_for_spinquant(
model: torch.nn.Module, prefix: str = ""
) -> None:
def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = "") -> None:
"""
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
This function recursively traverses the model's children and looks for layers that match the pattern
@@ -81,12 +79,8 @@ def add_hadamard_transform_for_spinquant(
for module_name, module in model.named_children():
child_full_name = prefix + "." + module_name
if re.search(pattern_last_linear_ffn, child_full_name):
new_module = nn.Sequential(
HadamardModule(group_size=module.in_features), module
)
new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module)
del module
setattr(model, module_name, new_module)
else:
add_hadamard_transform_for_spinquant(
module, (prefix + "." if prefix else prefix) + module_name
)
add_hadamard_transform_for_spinquant(module, (prefix + "." if prefix else prefix) + module_name)

@@ -63,12 +63,8 @@ def convert_to_fp8_quantized_model(
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
@@ -81,9 +77,7 @@ def convert_to_fp8_quantized_model(
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
],
fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
fp8_activation_scale_ub,
)
else:
@@ -172,9 +166,7 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
if prefix + "zeros" not in state_dict:
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
assert prefix + "scales" in state_dict
state_dict[prefix + "zeros"] = torch.zeros_like(
state_dict[prefix + "scales"]
)
state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
def forward(self, input_: torch.Tensor) -> torch.Tensor:
module_out = super().forward(input_)
@@ -229,9 +221,7 @@ class Int8WeightLinear(torch.nn.Linear):
bias: Whether to use bias.
"""
def __init__(
self, in_features: int, out_features: int, bias: bool = True, device=None
) -> None:
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
super().__init__(in_features, out_features, bias, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
@@ -295,9 +285,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
del module
setattr(model, module_name, quantized_module)
else:
_prepare_model_int4_weight_int8_dynamic_activation(
module, group_size, lora_rank, lora_scale
)
_prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
return model
@@ -321,9 +309,7 @@ def convert_to_int4_quantized_model(
group_size = model_args.quantization_args.group_size
if group_size is None:
raise ValueError(
"'group_size' cannot be None in 'quantization_args'. Please specify it."
)
raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
if model_args.lora_args is None:
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
@@ -333,8 +319,6 @@ def convert_to_int4_quantized_model(
lora_rank = model_args.lora_args.rank
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(
model, group_size, lora_rank, lora_scale
)
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return model.to(device)

@@ -76,9 +76,9 @@ def main(
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
assert model_parallel_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
@@ -90,9 +90,9 @@ def main(
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
assert model_args.vocab_size == tokenizer.n_words, (
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
)
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
@@ -106,9 +106,7 @@ def main(
torch.set_default_tensor_type(torch.cuda.HalfTensor)
log.info(ckpt_path)
assert (
quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None"
assert quantized_ckpt_dir is not None, "QUantized checkpoint directory should not be None"
fp8_scales = {}
for block in model.layers:
if isinstance(block, TransformerBlock):
@@ -122,9 +120,7 @@ def main(
)
with torch.inference_mode():
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales[f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
@@ -133,9 +129,7 @@ def main(
)
with torch.inference_mode():
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales[f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
@@ -144,13 +138,9 @@ def main(
)
with torch.inference_mode():
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales[f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_scales_path = os.path.join(
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
fp8_scales_path = os.path.join(quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
torch.save(fp8_scales, fp8_scales_path)
ckpt_path = os.path.join(

@@ -10,7 +10,6 @@ from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel):
@classmethod
def sample_run_config(cls) -> Dict[str, Any]:
return {}

@@ -53,7 +53,5 @@ class VLLMConfig(BaseModel):
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
return model

@@ -176,13 +176,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = await chat_completion_request_to_prompt(
request, self.config.model, self.formatter
)
prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
)
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
if stream:
return self._stream_chat_completion(request, results_generator)
else:
@@ -230,12 +226,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def embeddings(
self, model_id: str, contents: List[InterleavedContent]
) -> EmbeddingsResponse:
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
raise NotImplementedError()