diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index a96409cab..fd18dd72d 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -96,9 +96,27 @@ class Llama: This method initializes the distributed process group, sets the device to CUDA, and loads the pre-trained model and tokenizer. """ + if "DEVICE" in os.environ: + device = os.environ.get("DEVICE") + if device == "cuda": + assert torch.cuda.is_available(), "PyTorch CUDA backend not available" + if device == "xpu": + assert torch.xpu.is_available(), "PyTorch XPU backend not available" + else: + if torch.cuda.is_available(): + device = "cuda" + elif torch.xpu.is_available(): + device = "xpu" + else: + device = "cpu" + log.info(f"Using {device} device") + llama_model_id = llama_model.core_model_id.value if not torch.distributed.is_initialized(): - torch.distributed.init_process_group("nccl") + if device == "cuda": + torch.distributed.init_process_group("nccl") + else: + torch.distributed.init_process_group("gloo") model_parallel_size = llama_model.pth_file_count @@ -106,7 +124,10 @@ class Llama: initialize_model_parallel(model_parallel_size) local_rank = int(os.environ.get("LOCAL_RANK", 0)) - torch.cuda.set_device(local_rank) + if device == "cuda": + torch.cuda.set_device(local_rank) + elif device == "xpu": + torch.xpu.set_device(local_rank) # seed must be the same in all processes if config.torch_seed is not None: @@ -189,10 +210,17 @@ class Llama: "Currently int4 and fp8 are the only supported quantization methods." ) else: - if torch.cuda.is_bf16_supported(): - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + if device == "cuda": + if torch.cuda.is_bf16_supported(): + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.HalfTensor) else: - torch.set_default_tensor_type(torch.cuda.HalfTensor) + torch.set_default_device(device) + if device == "xpu" and torch.xpu.is_bf16_supported(): + torch.set_default_dtype(torch.bfloat16) + else: + torch.set_default_dtype(torch.half) if model_args.vision_chunk_size > 0: model = CrossAttentionTransformer(model_args) model.setup_cache(model_args.max_batch_size, torch.bfloat16) @@ -200,6 +228,8 @@ class Llama: model = Transformer(model_args) model.load_state_dict(state_dict, strict=False) + model.to(device) + log.info(f"Loaded in {time.time() - start_time:.2f} seconds") return Llama(model, tokenizer, model_args, llama_model_id) @@ -266,14 +296,14 @@ class Llama: ) pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long) if logprobs: - token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + token_logprobs = torch.zeros_like(tokens) prev_pos = 0 - eos_reached = torch.tensor([False] * bsz, device="cuda") + eos_reached = torch.tensor([False] * bsz) input_text_mask = tokens != pad_id if min_prompt_len == total_len: # TODO(ashwin): unify this branch with the one below and figure out multimodal crap @@ -285,12 +315,10 @@ class Llama: ignore_index=pad_id, ) - stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda") + stop_tokens = torch.tensor(self.tokenizer.stop_tokens) for cur_pos in range(min_prompt_len, total_len): if is_vision: - position_ids = torch.arange( - prev_pos, cur_pos, dtype=torch.long, device="cuda" - ) + position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) logits = self.model.forward( position_ids, tokens,