From 94051cfe9e5d5255dad33391418cfd6401bc0e79 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 30 Jan 2025 15:58:23 -0800 Subject: [PATCH] fix ImageContentItem to take base64 string as image.data (#909) # What does this PR do? - Discussion in https://github.com/meta-llama/llama-stack/pull/906#discussion_r1936260819 - image.data should accept base64 string as input instead of binary bytes, change prompt_adapter to account for that. ## Test Plan ``` pytest -v tests/client-sdk/inference/test_inference.py ``` with test in https://github.com/meta-llama/llama-stack/pull/906 ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- .../openapi_generator/strong_typing/schema.py | 39 +++++++++++++------ docs/resources/llama-stack-spec.html | 24 ++++++++---- docs/resources/llama-stack-spec.yaml | 13 +++++++ llama_stack/apis/common/content_types.py | 31 ++++++++++----- .../utils/inference/prompt_adapter.py | 9 +++-- 5 files changed, 85 insertions(+), 31 deletions(-) diff --git a/docs/openapi_generator/strong_typing/schema.py b/docs/openapi_generator/strong_typing/schema.py index f4393041f..577428035 100644 --- a/docs/openapi_generator/strong_typing/schema.py +++ b/docs/openapi_generator/strong_typing/schema.py @@ -248,7 +248,9 @@ class JsonSchemaGenerator: type_schema.update(self._metadata_to_schema(m)) return type_schema - def _simple_type_to_schema(self, typ: TypeLike) -> Optional[Schema]: + def _simple_type_to_schema( + self, typ: TypeLike, json_schema_extra: Optional[dict] = None + ) -> Optional[Schema]: """ Returns the JSON schema associated with a simple, unrestricted type. @@ -264,6 +266,11 @@ class JsonSchemaGenerator: elif typ is float: return {"type": "number"} elif typ is str: + if json_schema_extra and "contentEncoding" in json_schema_extra: + return { + "type": "string", + "contentEncoding": json_schema_extra["contentEncoding"], + } return {"type": "string"} elif typ is bytes: return {"type": "string", "contentEncoding": "base64"} @@ -303,7 +310,12 @@ class JsonSchemaGenerator: # not a simple type return None - def type_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> Schema: + def type_to_schema( + self, + data_type: TypeLike, + force_expand: bool = False, + json_schema_extra: Optional[dict] = None, + ) -> Schema: """ Returns the JSON schema associated with a type. @@ -313,7 +325,7 @@ class JsonSchemaGenerator: """ # short-circuit for common simple types - schema = self._simple_type_to_schema(data_type) + schema = self._simple_type_to_schema(data_type, json_schema_extra) if schema is not None: return schema @@ -486,15 +498,9 @@ class JsonSchemaGenerator: property_docstrings = get_class_property_docstrings( typ, self.options.property_description_fun ) - properties: Dict[str, Schema] = {} required: List[str] = [] for property_name, property_type in get_class_properties(typ): - defaults = {} - if "model_fields" in members: - f = members["model_fields"] - defaults = {k: finfo.default for k, finfo in f.items()} - # rename property if an alias name is specified alias = get_annotation(property_type, Alias) if alias: @@ -502,11 +508,22 @@ class JsonSchemaGenerator: else: output_name = property_name + defaults = {} + json_schema_extra = None + if "model_fields" in members: + f = members["model_fields"] + defaults = {k: finfo.default for k, finfo in f.items()} + json_schema_extra = f.get(output_name, None).json_schema_extra + if is_type_optional(property_type): optional_type: type = unwrap_optional_type(property_type) - property_def = self.type_to_schema(optional_type) + property_def = self.type_to_schema( + optional_type, json_schema_extra=json_schema_extra + ) else: - property_def = self.type_to_schema(property_type) + property_def = self.type_to_schema( + property_type, json_schema_extra=json_schema_extra + ) required.append(output_name) # check if attribute has a default value initializer diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 0454e22ec..5b3771340 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -2439,27 +2439,32 @@ "type": { "type": "string", "const": "image", - "default": "image" + "default": "image", + "description": "Discriminator type of the content item. Always \"image\"" }, "image": { "type": "object", "properties": { "url": { - "$ref": "#/components/schemas/URL" + "$ref": "#/components/schemas/URL", + "description": "A URL of the image or data URL in the format of data:image/{type};base64,{data}. Note that URL could have length limits." }, "data": { "type": "string", - "contentEncoding": "base64" + "contentEncoding": "base64", + "description": "base64 encoded image data as string" } }, - "additionalProperties": false + "additionalProperties": false, + "description": "Image as a base64 encoded string or an URL" } }, "additionalProperties": false, "required": [ "type", "image" - ] + ], + "title": "A image content item" }, "InterleavedContent": { "oneOf": [ @@ -2647,17 +2652,20 @@ "type": { "type": "string", "const": "text", - "default": "text" + "default": "text", + "description": "Discriminator type of the content item. Always \"text\"" }, "text": { - "type": "string" + "type": "string", + "description": "Text content" } }, "additionalProperties": false, "required": [ "type", "text" - ] + ], + "title": "A text content item" }, "ToolCall": { "type": "object", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 0734ef236..01232c001 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -1466,19 +1466,28 @@ components: type: string const: image default: image + description: >- + Discriminator type of the content item. Always "image" image: type: object properties: url: $ref: '#/components/schemas/URL' + description: >- + A URL of the image or data URL in the format of data:image/{type};base64,{data}. + Note that URL could have length limits. data: type: string contentEncoding: base64 + description: base64 encoded image data as string additionalProperties: false + description: >- + Image as a base64 encoded string or an URL additionalProperties: false required: - type - image + title: A image content item InterleavedContent: oneOf: - type: string @@ -1598,12 +1607,16 @@ components: type: string const: text default: text + description: >- + Discriminator type of the content item. Always "text" text: type: string + description: Text content additionalProperties: false required: - type - text + title: A text content item ToolCall: type: object properties: diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 0b27a0196..8e56f59b1 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -4,14 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import base64 from enum import Enum from typing import Annotated, List, Literal, Optional, Union from llama_models.llama3.api.datatypes import ToolCall from llama_models.schema_utils import json_schema_type, register_schema -from pydantic import BaseModel, Field, field_serializer, model_validator +from pydantic import BaseModel, Field, model_validator @json_schema_type @@ -20,8 +19,16 @@ class URL(BaseModel): class _URLOrData(BaseModel): + """ + A URL or a base64 encoded string + + :param url: A URL of the image or data URL in the format of data:image/{type};base64,{data}. Note that URL could have length limits. + :param data: base64 encoded image data as string + """ + url: Optional[URL] = None - data: Optional[bytes] = None + # data is a base64 encoded string, hint with contentEncoding=base64 + data: Optional[str] = Field(contentEncoding="base64", default=None) @model_validator(mode="before") @classmethod @@ -30,21 +37,27 @@ class _URLOrData(BaseModel): return values return {"url": values} - @field_serializer("data") - def serialize_data(self, data: Optional[bytes], _info): - if data is None: - return None - return base64.b64encode(data).decode("utf-8") - @json_schema_type class ImageContentItem(BaseModel): + """A image content item + + :param type: Discriminator type of the content item. Always "image" + :param image: Image as a base64 encoded string or an URL + """ + type: Literal["image"] = "image" image: _URLOrData @json_schema_type class TextContentItem(BaseModel): + """A text content item + + :param type: Discriminator type of the content item. Always "text" + :param text: Text content + """ + type: Literal["text"] = "text" text: str diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index f5298d844..e49771980 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -135,7 +135,8 @@ async def interleaved_content_convert_to_raw( else: raise ValueError("Unsupported URL type") elif image.data: - data = image.data + # data is a base64 encoded string, decode it to bytes for RawMediaItem + data = base64.b64decode(image.data) else: raise ValueError("No data or URL provided") @@ -184,8 +185,10 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]: return content, format else: - pil_image = PIL_Image.open(io.BytesIO(image.data)) - return image.data, pil_image.format + # data is a base64 encoded string, decode it to bytes first + data_bytes = base64.b64decode(image.data) + pil_image = PIL_Image.open(io.BytesIO(data_bytes)) + return data_bytes, pil_image.format async def convert_image_content_to_url(