forked from phoenix-oss/llama-stack-mirror
Fix fp8 implementation, which had bit-rotted a bit
I only tested with "on-the-fly" bf16 -> fp8 conversion, not the "load from fp8" codepath.

YAML I tested with:

```
providers:
  - provider_id: quantized
    provider_type: meta-reference-quantized
    config:
      model: Llama3.1-8B-Instruct
      quantization:
        type: fp8
```
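For context, the "on-the-fly" path takes the bf16 weights and quantizes them to fp8 at load time instead of reading pre-quantized checkpoints. Below is a minimal sketch of the per-tensor scale-and-cast idea in plain PyTorch; it is not the llama-stack loader itself and assumes a PyTorch build with `torch.float8_e4m3fn` support.

```python
import torch

# Illustrative per-tensor bf16 -> fp8 (e4m3) quantization: derive a scale from
# the weight's absolute max, then cast the scaled weight down to fp8 storage.
def quantize_fp8_per_tensor(w_bf16: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max            # ~448 for e4m3
    scale = w_bf16.abs().amax().float().clamp(min=1e-12) / fp8_max
    w_fp8 = (w_bf16.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_fp8, scale                                        # fp8 weight + dequant scale

# Usage: quantize a fake bf16 weight and check the round-trip error.
w = torch.randn(256, 256, dtype=torch.bfloat16)
w_fp8, scale = quantize_fp8_per_tensor(w)
roundtrip = w_fp8.to(torch.float32) * scale
print((roundtrip - w.float()).abs().max())
```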
This commit is contained in:
parent 80ada04f76
commit 09b793c4d6
2 changed files with 10 additions and 7 deletions
@@ -138,7 +138,7 @@ class Llama:
             else:
                 model = Transformer(model_args)
                 model.load_state_dict(state_dict, strict=False)
-                model = convert_to_quantized_model(model, config)
+                model = convert_to_quantized_model(model, config, ckpt_dir)
         else:
             if torch.cuda.is_bf16_supported():
                 torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
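The only functional change in this hunk is threading `ckpt_dir` into `convert_to_quantized_model`, which gives the loader a place to look for pre-quantized fp8 artifacts (the "load from fp8" codepath the commit message notes is untested). A hypothetical sketch of how a loader might use that argument follows; the file name `fp8_scales.pt` and the branching logic are assumptions, not llama-stack's actual implementation.

```python
import os
import torch

def convert_to_quantized_model_sketch(model: torch.nn.Module, ckpt_dir: str) -> torch.nn.Module:
    # Hypothetical outline only; the real logic lives in the provider's
    # quantization loader and may differ.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales_path = os.path.join(ckpt_dir, "fp8_scales.pt")      # assumed file name
    # "Load from fp8": reuse precomputed per-layer scales if present,
    # otherwise fall back to deriving scales on the fly from bf16 weights.
    precomputed = torch.load(scales_path) if os.path.isfile(scales_path) else {}

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            w = module.weight.data
            scale = precomputed.get(name, w.abs().amax().float().clamp(min=1e-12) / fp8_max)
            module.weight.data = (w.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
            module.register_buffer("weight_scale", scale)
    return model
```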
@@ -228,8 +228,7 @@ class Llama:
                 ignore_index=pad_id,
             )
 
-        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
-
+        stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(
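The second hunk constructs `stop_tokens` directly on the GPU. The sampled tokens already live on the GPU, and comparing a CUDA tensor against a CPU tensor, for example with `torch.isin` (a common way to detect stop tokens, not necessarily the exact call in this file), raises a device-mismatch error. A small illustration, assuming a CUDA-enabled PyTorch build:

```python
import torch

if torch.cuda.is_available():
    next_token = torch.tensor([128009], device="cuda")       # sampled on the GPU; example token ID
    stop_tokens_cpu = torch.tensor([128001, 128009])          # pre-fix: CPU tensor
    stop_tokens_gpu = stop_tokens_cpu.to("cuda")               # post-fix: same device as next_token

    print(torch.isin(next_token, stop_tokens_gpu))            # tensor([True], device='cuda:0')
    try:
        torch.isin(next_token, stop_tokens_cpu)               # mixes CPU and CUDA operands
    except RuntimeError as err:
        print(f"device mismatch: {err}")
```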