several fixes

This commit is contained in:
Ashwin Bharambe 2025-04-07 10:31:20 -07:00
parent e2e2820c9a
commit 53a8086e37
60 changed files with 1006 additions and 1078 deletions

View file

@@ -12,6 +12,7 @@ from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image as PIL_Image
# TODO: either fork these or move them to the common package
from ..datatypes import (
BuiltinTool,
RawContent,
@@ -26,10 +27,7 @@ from ..datatypes import (
from ..llama3.tool_utils import ToolUtils
from .args import VisionArgs
from .datatypes import LLMInput
from .preprocess import (
ResizeNormalizeImageTransform,
VariableSizeImageTransform,
)
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
from .tokenizer import Tokenizer
@@ -50,7 +48,7 @@ class TransformedImage:
aspect_ratio: Tuple[int, int]
def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
@@ -167,7 +165,7 @@ class ChatFormat:
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
image = convert_rgba_to_rgb(image)
image = convert_image_to_rgb(image)
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
if image_tiles.shape[0] > 1:
@@ -212,12 +210,9 @@ class ChatFormat:
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content)
# Tool calls and Tool Response messages should be eom
eom = False
if message.role == "assistant":
eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
elif message.role == "tool":
eom = True
eom = message.stop_reason == StopReason.end_of_message
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
return tokens, images
@@ -252,11 +247,6 @@ class ChatFormat:
if content.startswith(header_str):
content = content[len(header_str) :]
ipython = content.startswith("<|python_start|>")
if ipython:
content = content[len("<|python_start|>") :]
content = content.replace("<|python_end|>", "")
if content.endswith("<|eot|>"):
content = content[: -len("<|eot|>")]
stop_reason = StopReason.end_of_turn
@@ -287,11 +277,6 @@ class ChatFormat:
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = []
if tool_name is not None and tool_arguments is not None: