Update to the ImageContentItem datatype so url + data is not in a confusing union

This commit is contained in:
Ashwin Bharambe 2024-12-17 10:49:47 -08:00
parent 4936794de1
commit cf9fce6b6e
6 changed files with 36 additions and 29 deletions

View file

@ -11,7 +11,7 @@ import pytest
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
from .utils import group_chunks
@ -32,7 +32,7 @@ class TestVisionModelInference:
),
(
ImageContentItem(
data=URL(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),
@ -98,7 +98,7 @@ class TestVisionModelInference:
images = [
ImageContentItem(
data=URL(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),

View file

@ -139,9 +139,9 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
if isinstance(media.data, URL) and media.data.uri.startswith("http"):
if media.url and media.url.uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(media.data.uri)
r = await client.get(media.url.uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
@ -157,8 +157,8 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
if isinstance(media.data, URL) and not download:
return media.data.uri
if media.url and not download:
return media.url.uri
content, format = await localize_image_content(media)
if include_format: