From a6c206ea66146b374704a74321271156b8d04c04 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 30 Dec 2024 16:40:36 -0800 Subject: [PATCH] [bugfix] fix prompt_adapter interleaved_content_convert_to_raw (#696) # What does this PR do? - fix interleaved_content_convert_to_raw in prompt_adapter to correctly convert ImageContentItem to RawMediaItem with raw data bytes ## Test Plan ``` torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py ``` **Before** image **After** image ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- .../utils/inference/prompt_adapter.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index f7d2cd84e..ed0cabe1c 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -40,7 +40,6 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, TextContentItem, - URL, ) from llama_stack.apis.inference import ( @@ -117,27 +116,31 @@ async def interleaved_content_convert_to_raw( elif isinstance(c, TextContentItem): return RawTextItem(text=c.text) elif isinstance(c, ImageContentItem): - # load image and return PIL version - img = c.data - if isinstance(img, URL): - if img.uri.startswith("data"): - match = re.match(r"data:image/(\w+);base64,(.+)", img.uri) + if c.url: + # Load image bytes from URL + if c.url.uri.startswith("data"): + match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri) if not match: - raise ValueError("Invalid data URL format") + raise ValueError( + f"Invalid data URL format, {c.url.uri[:40]}..." + ) _, image_data = match.groups() data = base64.b64decode(image_data) - elif img.uri.startswith("file://"): - path = img.uri[len("file://") :] + elif c.url.uri.startswith("file://"): + path = c.url.uri[len("file://") :] with open(path, "rb") as f: data = f.read() # type: ignore - elif img.uri.startswith("http"): + elif c.url.uri.startswith("http"): async with httpx.AsyncClient() as client: - response = await client.get(img.uri) + response = await client.get(c.url.uri) data = response.content else: raise ValueError("Unsupported URL type") - else: + elif c.data: data = c.data + else: + raise ValueError("No data or URL provided") + return RawMediaItem(data=data) else: raise ValueError(f"Unsupported content type: {type(c)}")