From 7ea14ae62eb2279ae3a7f7b2a484cc14b0bffa8d Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Fri, 31 Jan 2025 12:11:49 -0800
Subject: [PATCH] feat: enable xpu support for meta-reference stack (#558)

This commit adds support for XPU and CPU devices to the meta-reference
stack for text models. On creation, the stack automatically identifies
which device to use by checking the available accelerator backends in
the following order: CUDA, then XPU, finally CPU. This behavior can be
overridden with the `DEVICE` environment variable, in which case the
explicitly specified device is used (see the device-selection sketch
after the patch).

Tested with:
```
torchrun pytest llama_stack/providers/tests/inference/test_text_inference.py -k meta_reference
```

Results:
* Tested on: a system with a single CUDA device, a system with a single
  XPU device, and a pure CPU system
* Results: all tests pass except `test_completion_logprobs`
* `test_completion_logprobs` fails in the same way as on the baseline,
  i.e. unrelated to this change: `AssertionError: Unexpected top_k=3`

Requires: https://github.com/meta-llama/llama-models/pull/233

Signed-off-by: Dmitry Rogozhkin
---
 .../inference/meta_reference/generation.py | 54 ++++++++++++++-----
 1 file changed, 41 insertions(+), 13 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index a96409cab..fd18dd72d 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -96,9 +96,27 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
+        if "DEVICE" in os.environ:
+            device = os.environ.get("DEVICE")
+            if device == "cuda":
+                assert torch.cuda.is_available(), "PyTorch CUDA backend not available"
+            if device == "xpu":
+                assert torch.xpu.is_available(), "PyTorch XPU backend not available"
+        else:
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.xpu.is_available():
+                device = "xpu"
+            else:
+                device = "cpu"
+        log.info(f"Using {device} device")
+
         llama_model_id = llama_model.core_model_id.value
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            if device == "cuda":
+                torch.distributed.init_process_group("nccl")
+            else:
+                torch.distributed.init_process_group("gloo")
 
         model_parallel_size = llama_model.pth_file_count
 
@@ -106,7 +124,10 @@ class Llama:
             initialize_model_parallel(model_parallel_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        if device == "cuda":
+            torch.cuda.set_device(local_rank)
+        elif device == "xpu":
+            torch.xpu.set_device(local_rank)
 
         # seed must be the same in all processes
         if config.torch_seed is not None:
@@ -189,10 +210,17 @@ class Llama:
                     "Currently int4 and fp8 are the only supported quantization methods."
                 )
         else:
-            if torch.cuda.is_bf16_supported():
-                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+            if device == "cuda":
+                if torch.cuda.is_bf16_supported():
+                    torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+                else:
+                    torch.set_default_tensor_type(torch.cuda.HalfTensor)
             else:
-                torch.set_default_tensor_type(torch.cuda.HalfTensor)
+                torch.set_default_device(device)
+                if device == "xpu" and torch.xpu.is_bf16_supported():
+                    torch.set_default_dtype(torch.bfloat16)
+                else:
+                    torch.set_default_dtype(torch.half)
             if model_args.vision_chunk_size > 0:
                 model = CrossAttentionTransformer(model_args)
                 model.setup_cache(model_args.max_batch_size, torch.bfloat16)
@@ -200,6 +228,8 @@ class Llama:
                 model = Transformer(model_args)
             model.load_state_dict(state_dict, strict=False)
 
+        model.to(device)
+
         log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
         return Llama(model, tokenizer, model_args, llama_model_id)
 
@@ -266,14 +296,14 @@ class Llama:
         )
 
         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
         if logprobs:
-            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+            token_logprobs = torch.zeros_like(tokens)
 
         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz)
         input_text_mask = tokens != pad_id
         if min_prompt_len == total_len:
             # TODO(ashwin): unify this branch with the one below and figure out multimodal crap
@@ -285,12 +315,10 @@ class Llama:
                 ignore_index=pad_id,
             )
 
-        stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
+        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
-                position_ids = torch.arange(
-                    prev_pos, cur_pos, dtype=torch.long, device="cuda"
-                )
+                position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
                 logits = self.model.forward(
                     position_ids,
                     tokens,
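
As referenced in the commit message, below is a minimal, standalone sketch of the device-selection order the patch implements (explicit `DEVICE` override first, otherwise CUDA, then XPU, then CPU). It assumes a PyTorch build that exposes `torch.xpu`, as the patched code does; the helper name `select_device` is illustrative and not part of the patch.

```python
import os

import torch


def select_device() -> str:
    """Mirror the patch's selection: DEVICE override, else CUDA > XPU > CPU."""
    if "DEVICE" in os.environ:
        device = os.environ["DEVICE"]
        # An explicitly requested accelerator must actually be usable.
        if device == "cuda":
            assert torch.cuda.is_available(), "PyTorch CUDA backend not available"
        if device == "xpu":
            assert torch.xpu.is_available(), "PyTorch XPU backend not available"
        return device
    if torch.cuda.is_available():
        return "cuda"
    if torch.xpu.is_available():
        return "xpu"
    return "cpu"


if __name__ == "__main__":
    print(f"Using {select_device()} device")
```

With the patch applied, the override can be exercised the same way from the shell, e.g. `DEVICE=xpu torchrun pytest llama_stack/providers/tests/inference/test_text_inference.py -k meta_reference` (the test command from the commit message with the environment variable prepended).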
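
The non-CUDA branch also swaps the CUDA-specific `torch.set_default_tensor_type(torch.cuda.*Tensor)` calls for the device-agnostic `torch.set_default_device` / `torch.set_default_dtype` pair, which pairs with dropping the explicit `device="cuda"` arguments from the tensor factory calls later in the diff. A small sketch of that behavior under the assumption of a CPU-only run (on an XPU system the bf16 branch would apply):

```python
import torch

device = "cpu"  # "xpu" on an Intel GPU system, per the selection logic above

# Newly created tensors land on `device` without per-call device= arguments.
torch.set_default_device(device)

# Prefer bfloat16 on XPU when supported, otherwise fall back to fp16,
# mirroring the patch's dtype choice for the non-CUDA path.
if device == "xpu" and torch.xpu.is_bf16_supported():
    torch.set_default_dtype(torch.bfloat16)
else:
    torch.set_default_dtype(torch.half)

x = torch.empty(2, 2)
print(x.device, x.dtype)  # e.g. cpu torch.float16
```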