This commit is contained in:
Ashwin Bharambe 2024-09-28 15:32:03 -07:00
parent 23028e26ff
commit e61c4954d5
3 changed files with 6 additions and 9 deletions

View file

@ -185,11 +185,11 @@ class Llama:
) -> Generator:
params = self.model.params
input_tokens = [
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
# input_tokens = [
# self.formatter.vision_token if t == 128256 else t
# for t in model_input.tokens
# ]
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
@ -207,7 +207,6 @@ class Llama:
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = isinstance(self.model, CrossAttentionTransformer)
print(f"{is_vision=}")
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []