From 50b14872443cc101fac8c314d54e0f1484daa38b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 30 Dec 2024 14:32:21 -0800 Subject: [PATCH] fix types --- .../utils/inference/prompt_adapter.py | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 7ed6dbdfc..8664e68f2 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -40,7 +40,6 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, TextContentItem, - URL, ) from llama_stack.apis.inference import ( @@ -117,27 +116,49 @@ async def interleaved_content_convert_to_raw( elif isinstance(c, TextContentItem): return RawTextItem(text=c.text) elif isinstance(c, ImageContentItem): - # load image and return PIL version - img = c.data - if isinstance(img, URL): - if img.uri.startswith("data"): + if c.url: + if c.url.uri.startswith("data"): match = re.match(r"data:image/(\w+);base64,(.+)", img.uri) if not match: raise ValueError("Invalid data URL format") _, image_data = match.groups() data = base64.b64decode(image_data) - elif img.uri.startswith("file://"): - path = img.uri[len("file://") :] + elif c.url.uri.startswith("file://"): + path = c.url.uri[len("file://") :] with open(path, "rb") as f: data = f.read() # type: ignore - elif img.uri.startswith("http"): + elif c.url.uri.startswith("http"): async with httpx.AsyncClient() as client: - response = await client.get(img.uri) + response = await client.get(c.url.uri) data = response.content else: raise ValueError("Unsupported URL type") - else: + elif c.data: data = c.data + else: + raise ValueError("No data or URL provided") + + # # load image and return PIL version + # img = c.data + # if isinstance(img, URL): + # if img.uri.startswith("data"): + # match = re.match(r"data:image/(\w+);base64,(.+)", img.uri) + # if not match: + # raise ValueError("Invalid data URL format") + # _, image_data = match.groups() + # data = base64.b64decode(image_data) + # elif img.uri.startswith("file://"): + # path = img.uri[len("file://") :] + # with open(path, "rb") as f: + # data = f.read() # type: ignore + # elif img.uri.startswith("http"): + # async with httpx.AsyncClient() as client: + # response = await client.get(img.uri) + # data = response.content + # else: + # raise ValueError("Unsupported URL type") + # else: + # data = c.data print("type of data", type(data)) return RawMediaItem(data=data)