feat: enable xpu support for meta-reference stack (#558)
This commit adds support for XPU and CPU devices to the meta-reference stack for text models. On creation, the stack automatically identifies which device to use by checking the available accelerator capabilities in the following order: CUDA, then XPU, finally CPU. This behaviour can be overridden with the `DEVICE` environment variable, in which case the explicitly specified device is used.

Tested with:

```
torchrun pytest llama_stack/providers/tests/inference/test_text_inference.py -k meta_reference
```

Results:

* Tested on: a system with a single CUDA device, a system with a single XPU device, and a CPU-only system
* All tests pass except `test_completion_logprobs`
* `test_completion_logprobs` fails in the same way as on the baseline, i.e. the failure is unrelated to this change: `AssertionError: Unexpected top_k=3`

Requires: https://github.com/meta-llama/llama-models/pull/233

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Parent: 15dcc4ea5e
Commit: 7ea14ae62e
1 changed file with 41 additions and 13 deletions
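As a minimal standalone sketch of the selection order described above (not part of the patch; the helper name `pick_device` is hypothetical, and `torch.xpu` is assumed to be exposed by the installed PyTorch build):

```python
import os

import torch


def pick_device() -> str:
    # Explicit override: DEVICE=cuda|xpu|cpu takes precedence over auto-detection.
    if "DEVICE" in os.environ:
        device = os.environ["DEVICE"]
        if device == "cuda":
            assert torch.cuda.is_available(), "PyTorch CUDA backend not available"
        if device == "xpu":
            assert torch.xpu.is_available(), "PyTorch XPU backend not available"
        return device
    # Auto-detection order: CUDA, then XPU, finally CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.xpu.is_available():
        return "xpu"
    return "cpu"
```

For example, prefixing the test command with `DEVICE=xpu` forces the XPU path regardless of which other accelerators are present.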
```diff
@@ -96,9 +96,27 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
+        if "DEVICE" in os.environ:
+            device = os.environ.get("DEVICE")
+            if device == "cuda":
+                assert torch.cuda.is_available(), "PyTorch CUDA backend not available"
+            if device == "xpu":
+                assert torch.xpu.is_available(), "PyTorch XPU backend not available"
+        else:
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.xpu.is_available():
+                device = "xpu"
+            else:
+                device = "cpu"
+        log.info(f"Using {device} device")
+
         llama_model_id = llama_model.core_model_id.value
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            if device == "cuda":
+                torch.distributed.init_process_group("nccl")
+            else:
+                torch.distributed.init_process_group("gloo")
 
         model_parallel_size = llama_model.pth_file_count
 
```
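NCCL only provides collectives for CUDA tensors, so non-CUDA runs (XPU or CPU) fall back to the Gloo backend in the hunk above. A hedged sketch of the same choice using public `torch.distributed` calls (the `device` variable is assumed to come from the detection logic above, and the process is assumed to be launched with `torchrun`, which sets the rendezvous environment variables):

```python
import torch.distributed as dist

# NCCL implements collectives only for CUDA tensors; everything else uses Gloo.
if device == "cuda" and dist.is_nccl_available():
    dist.init_process_group("nccl")
else:
    dist.init_process_group("gloo")
```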
```diff
@@ -106,7 +124,10 @@ class Llama:
             initialize_model_parallel(model_parallel_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        if device == "cuda":
+            torch.cuda.set_device(local_rank)
+        elif device == "xpu":
+            torch.xpu.set_device(local_rank)
 
         # seed must be the same in all processes
         if config.torch_seed is not None:
```
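`LOCAL_RANK` is exported per worker by `torchrun`; each rank is pinned to the accelerator with the matching index, while CPU runs need no pinning. An illustrative sketch (`device` assumed as above):

```python
import os

import torch

# torchrun exports LOCAL_RANK for every worker; default to 0 for single-process runs.
local_rank = int(os.environ.get("LOCAL_RANK", 0))

if device == "cuda":
    torch.cuda.set_device(local_rank)  # rank N uses cuda:N
elif device == "xpu":
    torch.xpu.set_device(local_rank)   # rank N uses xpu:N (assumes an XPU-enabled build)
```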
```diff
@@ -189,10 +210,17 @@ class Llama:
                     "Currently int4 and fp8 are the only supported quantization methods."
                 )
         else:
-            if torch.cuda.is_bf16_supported():
-                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
-            else:
-                torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            if device == "cuda":
+                if torch.cuda.is_bf16_supported():
+                    torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+                else:
+                    torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            else:
+                torch.set_default_device(device)
+                if device == "xpu" and torch.xpu.is_bf16_supported():
+                    torch.set_default_dtype(torch.bfloat16)
+                else:
+                    torch.set_default_dtype(torch.half)
         if model_args.vision_chunk_size > 0:
             model = CrossAttentionTransformer(model_args)
             model.setup_cache(model_args.max_batch_size, torch.bfloat16)
```
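The non-CUDA branch relies on `torch.set_default_device` and `torch.set_default_dtype` (available in recent PyTorch releases) so that later tensor factory calls no longer need an explicit `device=` argument. A hedged illustration of that effect, using CPU as a stand-in for whichever device was selected:

```python
import torch

# Factory calls made after these two lines produce tensors on the chosen
# device and in the chosen dtype unless they override them explicitly.
torch.set_default_device("cpu")        # would be "xpu" on an XPU-enabled build
torch.set_default_dtype(torch.half)

x = torch.zeros(2, 3)
print(x.device, x.dtype)  # cpu torch.float16
```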
```diff
@@ -200,6 +228,8 @@ class Llama:
             model = Transformer(model_args)
         model.load_state_dict(state_dict, strict=False)
 
+        model.to(device)
+
         log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
         return Llama(model, tokenizer, model_args, llama_model_id)
 
```
```diff
@@ -266,14 +296,14 @@ class Llama:
         )
 
         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
         if logprobs:
-            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+            token_logprobs = torch.zeros_like(tokens)
 
         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz)
         input_text_mask = tokens != pad_id
         if min_prompt_len == total_len:
             # TODO(ashwin): unify this branch with the one below and figure out multimodal crap
```
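Because the default device is set during initialization, the generation path can drop the hard-coded `device="cuda"` arguments: factory calls such as `torch.full` and `torch.tensor` now allocate on whichever device was selected. A small hedged check of that behaviour (CPU used as the stand-in default):

```python
import torch

torch.set_default_device("cpu")  # stand-in for the device picked at load time

pad_id = 0
tokens = torch.full((2, 8), pad_id, dtype=torch.long)
eos_reached = torch.tensor([False] * 2)

# Both tensors land on the default device; no per-call device= is needed.
assert tokens.device == eos_reached.device
```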
```diff
@@ -285,12 +315,10 @@ class Llama:
                 ignore_index=pad_id,
             )
 
-        stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
+        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
-                position_ids = torch.arange(
-                    prev_pos, cur_pos, dtype=torch.long, device="cuda"
-                )
+                position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
                 logits = self.model.forward(
                     position_ids,
                     tokens,
```