diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index f372257a0..cd92a10f5 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -2554,27 +2554,22 @@ "ImageContentItem": { "type": "object", "properties": { + "url": { + "$ref": "#/components/schemas/URL" + }, + "data": { + "type": "string", + "contentEncoding": "base64" + }, "type": { "type": "string", "const": "image", "default": "image" - }, - "data": { - "oneOf": [ - { - "type": "string", - "contentEncoding": "base64" - }, - { - "$ref": "#/components/schemas/URL" - } - ] } }, "additionalProperties": false, "required": [ - "type", - "data" + "type" ] }, "InterleavedContent": { diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 52c3aaac6..08db0699e 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -1043,17 +1043,16 @@ components: additionalProperties: false properties: data: - oneOf: - - contentEncoding: base64 - type: string - - $ref: '#/components/schemas/URL' + contentEncoding: base64 + type: string type: const: image default: image type: string + url: + $ref: '#/components/schemas/URL' required: - type - - data type: object InferenceStep: additionalProperties: false diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 1403dd782..316a4a5d6 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -4,11 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Annotated, List, Literal, Union +from typing import Annotated, List, Literal, Optional, Union from llama_models.schema_utils import json_schema_type, register_schema -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator @json_schema_type( @@ -21,10 +21,21 @@ class URL(BaseModel): return self.uri +class _URLOrData(BaseModel): + url: Optional[URL] = None + data: Optional[bytes] = None + + @model_validator(mode="before") + @classmethod + def validator(cls, values): + if isinstance(values, dict): + return values + return {"url": values} + + @json_schema_type -class ImageContentItem(BaseModel): +class ImageContentItem(_URLOrData): type: Literal["image"] = "image" - data: Union[bytes, URL] @json_schema_type diff --git a/llama_stack/apis/common/deployment_types.py b/llama_stack/apis/common/deployment_types.py index 67096ac52..24de0cc91 100644 --- a/llama_stack/apis/common/deployment_types.py +++ b/llama_stack/apis/common/deployment_types.py @@ -11,6 +11,8 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel +from llama_stack.apis.common.content_types import URL + @json_schema_type class RestAPIMethod(Enum): diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index d29ace491..d58164676 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -11,7 +11,7 @@ import pytest from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem +from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL from .utils import group_chunks @@ -32,7 +32,7 @@ class TestVisionModelInference: ), ( ImageContentItem( - data=URL( + url=URL( uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" ) ), @@ -98,7 +98,7 @@ class TestVisionModelInference: images = [ ImageContentItem( - data=URL( + url=URL( uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" ) ), diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 4f51467c2..42aa987c3 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -139,9 +139,9 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]): async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]: - if isinstance(media.data, URL) and media.data.uri.startswith("http"): + if media.url and media.url.uri.startswith("http"): async with httpx.AsyncClient() as client: - r = await client.get(media.data.uri) + r = await client.get(media.url.uri) content = r.content content_type = r.headers.get("content-type") if content_type: @@ -157,8 +157,8 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]: async def convert_image_content_to_url( media: ImageContentItem, download: bool = False, include_format: bool = True ) -> str: - if isinstance(media.data, URL) and not download: - return media.data.uri + if media.url and not download: + return media.url.uri content, format = await localize_image_content(media) if include_format: