mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 03:23:53 +00:00
fix test, fix llama3 generator
This commit is contained in:
parent
a3cee70014
commit
771daa4b91
2 changed files with 12 additions and 47 deletions
|
|
@ -154,7 +154,7 @@ class Llama3:
|
|||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_inputs: List[LLMInput],
|
||||
llm_inputs: List[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
|
|
@ -169,15 +169,15 @@ class Llama3:
|
|||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input:
|
||||
for inp in model_inputs:
|
||||
for inp in llm_inputs:
|
||||
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
|
||||
cprint(
|
||||
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
||||
"red",
|
||||
)
|
||||
prompt_tokens = [inp.tokens for inp in model_inputs]
|
||||
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||
|
||||
bsz = len(model_inputs)
|
||||
bsz = len(llm_inputs)
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
|
|
@ -198,8 +198,8 @@ class Llama3:
|
|||
|
||||
is_vision = not isinstance(self.model, Transformer)
|
||||
if is_vision:
|
||||
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
|
||||
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
|
||||
images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
|
||||
mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
|
||||
|
||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
||||
batch_images=images,
|
||||
|
|
@ -234,7 +234,7 @@ class Llama3:
|
|||
for cur_pos in range(min_prompt_len, total_len):
|
||||
if is_vision:
|
||||
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
||||
text_only_inference = all(inp.vision is None for inp in model_inputs)
|
||||
text_only_inference = all(inp.vision is None for inp in llm_inputs)
|
||||
logits = self.model.forward(
|
||||
position_ids,
|
||||
tokens,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue