Update the ImageContentItem datatype so that url and data are separate fields rather than a confusing union

This commit is contained in:
Ashwin Bharambe 2024-12-17 10:49:47 -08:00
parent 4936794de1
commit cf9fce6b6e
6 changed files with 36 additions and 29 deletions

View file

@ -2554,27 +2554,22 @@
"ImageContentItem": { "ImageContentItem": {
"type": "object", "type": "object",
"properties": { "properties": {
"url": {
"$ref": "#/components/schemas/URL"
},
"data": {
"type": "string",
"contentEncoding": "base64"
},
"type": { "type": {
"type": "string", "type": "string",
"const": "image", "const": "image",
"default": "image" "default": "image"
},
"data": {
"oneOf": [
{
"type": "string",
"contentEncoding": "base64"
},
{
"$ref": "#/components/schemas/URL"
}
]
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type", "type"
"data"
] ]
}, },
"InterleavedContent": { "InterleavedContent": {

View file

@ -1043,17 +1043,16 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
data: data:
oneOf: contentEncoding: base64
- contentEncoding: base64
type: string type: string
- $ref: '#/components/schemas/URL'
type: type:
const: image const: image
default: image default: image
type: string type: string
url:
$ref: '#/components/schemas/URL'
required: required:
- type - type
- data
type: object type: object
InferenceStep: InferenceStep:
additionalProperties: false additionalProperties: false

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Annotated, List, Literal, Union from typing import Annotated, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type, register_schema from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
@json_schema_type( @json_schema_type(
@ -21,10 +21,21 @@ class URL(BaseModel):
return self.uri return self.uri
class _URLOrData(BaseModel):
    # Shared base for content items that carry either a URL reference (`url`)
    # or inline bytes (`data`). Both fields are optional and default to None.
    url: Optional[URL] = None
    data: Optional[bytes] = None

    @model_validator(mode="before")
    @classmethod
    def validator(cls, values):
        """Pass dict input through untouched; wrap anything else as the `url` field."""
        return values if isinstance(values, dict) else {"url": values}
@json_schema_type @json_schema_type
class ImageContentItem(BaseModel): class ImageContentItem(_URLOrData):
type: Literal["image"] = "image" type: Literal["image"] = "image"
data: Union[bytes, URL]
@json_schema_type @json_schema_type

View file

@ -11,6 +11,8 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL
@json_schema_type @json_schema_type
class RestAPIMethod(Enum): class RestAPIMethod(Enum):

View file

@ -11,7 +11,7 @@ import pytest
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
from .utils import group_chunks from .utils import group_chunks
@ -32,7 +32,7 @@ class TestVisionModelInference:
), ),
( (
ImageContentItem( ImageContentItem(
data=URL( url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
) )
), ),
@ -98,7 +98,7 @@ class TestVisionModelInference:
images = [ images = [
ImageContentItem( ImageContentItem(
data=URL( url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
) )
), ),

View file

@ -139,9 +139,9 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]: async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
if isinstance(media.data, URL) and media.data.uri.startswith("http"): if media.url and media.url.uri.startswith("http"):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(media.data.uri) r = await client.get(media.url.uri)
content = r.content content = r.content
content_type = r.headers.get("content-type") content_type = r.headers.get("content-type")
if content_type: if content_type:
@ -157,8 +157,8 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
async def convert_image_content_to_url( async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str: ) -> str:
if isinstance(media.data, URL) and not download: if media.url and not download:
return media.data.uri return media.url.uri
content, format = await localize_image_content(media) content, format = await localize_image_content(media)
if include_format: if include_format: