fix types

This commit is contained in:
Xi Yan 2024-12-30 14:32:21 -08:00
parent 332283500a
commit 50b1487244

View file

@ -40,7 +40,6 @@ from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
InterleavedContentItem, InterleavedContentItem,
TextContentItem, TextContentItem,
URL,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -117,27 +116,49 @@ async def interleaved_content_convert_to_raw(
elif isinstance(c, TextContentItem): elif isinstance(c, TextContentItem):
return RawTextItem(text=c.text) return RawTextItem(text=c.text)
elif isinstance(c, ImageContentItem): elif isinstance(c, ImageContentItem):
# load image and return PIL version if c.url:
img = c.data if c.url.uri.startswith("data"):
if isinstance(img, URL):
if img.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", img.uri) match = re.match(r"data:image/(\w+);base64,(.+)", img.uri)
if not match: if not match:
raise ValueError("Invalid data URL format") raise ValueError("Invalid data URL format")
_, image_data = match.groups() _, image_data = match.groups()
data = base64.b64decode(image_data) data = base64.b64decode(image_data)
elif img.uri.startswith("file://"): elif c.url.uri.startswith("file://"):
path = img.uri[len("file://") :] path = c.url.uri[len("file://") :]
with open(path, "rb") as f: with open(path, "rb") as f:
data = f.read() # type: ignore data = f.read() # type: ignore
elif img.uri.startswith("http"): elif c.url.uri.startswith("http"):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get(img.uri) response = await client.get(c.url.uri)
data = response.content data = response.content
else: else:
raise ValueError("Unsupported URL type") raise ValueError("Unsupported URL type")
else: elif c.data:
data = c.data data = c.data
else:
raise ValueError("No data or URL provided")
# # load image and return PIL version
# img = c.data
# if isinstance(img, URL):
# if img.uri.startswith("data"):
# match = re.match(r"data:image/(\w+);base64,(.+)", img.uri)
# if not match:
# raise ValueError("Invalid data URL format")
# _, image_data = match.groups()
# data = base64.b64decode(image_data)
# elif img.uri.startswith("file://"):
# path = img.uri[len("file://") :]
# with open(path, "rb") as f:
# data = f.read() # type: ignore
# elif img.uri.startswith("http"):
# async with httpx.AsyncClient() as client:
# response = await client.get(img.uri)
# data = response.content
# else:
# raise ValueError("Unsupported URL type")
# else:
# data = c.data
print("type of data", type(data)) print("type of data", type(data))
return RawMediaItem(data=data) return RawMediaItem(data=data)