Update the ImageContentItem datatype so that url and data are separate fields instead of a confusing union

This commit is contained in:
Ashwin Bharambe 2024-12-17 10:49:47 -08:00
parent 4936794de1
commit cf9fce6b6e
6 changed files with 36 additions and 29 deletions

View file

@ -2554,27 +2554,22 @@
"ImageContentItem": {
"type": "object",
"properties": {
"url": {
"$ref": "#/components/schemas/URL"
},
"data": {
"type": "string",
"contentEncoding": "base64"
},
"type": {
"type": "string",
"const": "image",
"default": "image"
},
"data": {
"oneOf": [
{
"type": "string",
"contentEncoding": "base64"
},
{
"$ref": "#/components/schemas/URL"
}
]
}
},
"additionalProperties": false,
"required": [
"type",
"data"
"type"
]
},
"InterleavedContent": {

View file

@ -1043,17 +1043,16 @@ components:
additionalProperties: false
properties:
data:
oneOf:
- contentEncoding: base64
type: string
- $ref: '#/components/schemas/URL'
contentEncoding: base64
type: string
type:
const: image
default: image
type: string
url:
$ref: '#/components/schemas/URL'
required:
- type
- data
type: object
InferenceStep:
additionalProperties: false

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, List, Literal, Union
from typing import Annotated, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
@json_schema_type(
@ -21,10 +21,21 @@ class URL(BaseModel):
return self.uri
# Base model carrying two optional fields: `url` (a URL model) and `data`
# (bytes). Both default to None.
# NOTE(review): nothing in this class checks that at least one of `url` /
# `data` is actually provided — confirm whether callers depend on that.
class _URLOrData(BaseModel):
url: Optional[URL] = None
data: Optional[bytes] = None
# Registered as a mode="before" model validator (presumably invoked on the
# raw input prior to field validation — see pydantic docs). Dict input is
# returned unchanged; any non-dict input is wrapped as the `url` field.
@model_validator(mode="before")
@classmethod
def validator(cls, values):
if isinstance(values, dict):
return values
return {"url": values}
@json_schema_type
class ImageContentItem(BaseModel):
class ImageContentItem(_URLOrData):
type: Literal["image"] = "image"
data: Union[bytes, URL]
@json_schema_type

View file

@ -11,6 +11,8 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL
@json_schema_type
class RestAPIMethod(Enum):

View file

@ -11,7 +11,7 @@ import pytest
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
from .utils import group_chunks
@ -32,7 +32,7 @@ class TestVisionModelInference:
),
(
ImageContentItem(
data=URL(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),
@ -98,7 +98,7 @@ class TestVisionModelInference:
images = [
ImageContentItem(
data=URL(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),

View file

@ -139,9 +139,9 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
if isinstance(media.data, URL) and media.data.uri.startswith("http"):
if media.url and media.url.uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(media.data.uri)
r = await client.get(media.url.uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
@ -157,8 +157,8 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
if isinstance(media.data, URL) and not download:
return media.data.uri
if media.url and not download:
return media.url.uri
content, format = await localize_image_content(media)
if include_format: