[inference api] modify content types so they follow a more standard structure (#841)

Some small updates to the inference content types to make them more standard.

Specifically:
- image data is now located in an "image" subkey
- similarly, tool call data is now located in a "tool_call" subkey

The pattern followed is `dict(type="foo", foo=<...>)`
This commit is contained in:
Ashwin Bharambe 2025-01-22 12:16:18 -08:00 committed by GitHub
parent caa8387dd2
commit 07b87365ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 104 additions and 76 deletions

View file

@ -113,28 +113,29 @@ async def interleaved_content_convert_to_raw(
elif isinstance(c, TextContentItem):
return RawTextItem(text=c.text)
elif isinstance(c, ImageContentItem):
if c.url:
image = c.image
if image.url:
# Load image bytes from URL
if c.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri)
if image.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
if not match:
raise ValueError(
f"Invalid data URL format, {c.url.uri[:40]}..."
f"Invalid data URL format, {image.url.uri[:40]}..."
)
_, image_data = match.groups()
data = base64.b64decode(image_data)
elif c.url.uri.startswith("file://"):
path = c.url.uri[len("file://") :]
elif image.url.uri.startswith("file://"):
path = image.url.uri[len("file://") :]
with open(path, "rb") as f:
data = f.read() # type: ignore
elif c.url.uri.startswith("http"):
elif image.url.uri.startswith("http"):
async with httpx.AsyncClient() as client:
response = await client.get(c.url.uri)
response = await client.get(image.url.uri)
data = response.content
else:
raise ValueError("Unsupported URL type")
elif c.data:
data = c.data
elif image.data:
data = image.data
else:
raise ValueError("No data or URL provided")
@ -170,26 +171,29 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
if media.url and media.url.uri.startswith("http"):
image = media.image
if image.url and image.url.uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(media.url.uri)
r = await client.get(image.url.uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
format = content_type.split("/")[-1]
else:
format = "png"
return content, format
else:
image = PIL_Image.open(io.BytesIO(media.data))
return media.data, image.format
pil_image = PIL_Image.open(io.BytesIO(image.data))
return image.data, pil_image.format
async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
if media.url and (not download or media.url.uri.startswith("data")):
return media.url.uri
image = media.image
if image.url and (not download or image.url.uri.startswith("data")):
return image.url.uri
content, format = await localize_image_content(media)
if include_format: