mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
Update to the ImageContentItem datatype so url + data is not in a confusing union
This commit is contained in:
parent
4936794de1
commit
cf9fce6b6e
6 changed files with 36 additions and 29 deletions
|
@ -2554,27 +2554,22 @@
|
|||
"ImageContentItem": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"$ref": "#/components/schemas/URL"
|
||||
},
|
||||
"data": {
|
||||
"type": "string",
|
||||
"contentEncoding": "base64"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "image",
|
||||
"default": "image"
|
||||
},
|
||||
"data": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"contentEncoding": "base64"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/URL"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"data"
|
||||
"type"
|
||||
]
|
||||
},
|
||||
"InterleavedContent": {
|
||||
|
|
|
@ -1043,17 +1043,16 @@ components:
|
|||
additionalProperties: false
|
||||
properties:
|
||||
data:
|
||||
oneOf:
|
||||
- contentEncoding: base64
|
||||
type: string
|
||||
- $ref: '#/components/schemas/URL'
|
||||
contentEncoding: base64
|
||||
type: string
|
||||
type:
|
||||
const: image
|
||||
default: image
|
||||
type: string
|
||||
url:
|
||||
$ref: '#/components/schemas/URL'
|
||||
required:
|
||||
- type
|
||||
- data
|
||||
type: object
|
||||
InferenceStep:
|
||||
additionalProperties: false
|
||||
|
|
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, List, Literal, Union
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
@json_schema_type(
|
||||
|
@ -21,10 +21,21 @@ class URL(BaseModel):
|
|||
return self.uri
|
||||
|
||||
|
||||
class _URLOrData(BaseModel):
|
||||
url: Optional[URL] = None
|
||||
data: Optional[bytes] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validator(cls, values):
|
||||
if isinstance(values, dict):
|
||||
return values
|
||||
return {"url": values}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ImageContentItem(BaseModel):
|
||||
class ImageContentItem(_URLOrData):
|
||||
type: Literal["image"] = "image"
|
||||
data: Union[bytes, URL]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -11,6 +11,8 @@ from llama_models.schema_utils import json_schema_type
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RestAPIMethod(Enum):
|
||||
|
|
|
@ -11,7 +11,7 @@ import pytest
|
|||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
|
||||
|
||||
from .utils import group_chunks
|
||||
|
||||
|
@ -32,7 +32,7 @@ class TestVisionModelInference:
|
|||
),
|
||||
(
|
||||
ImageContentItem(
|
||||
data=URL(
|
||||
url=URL(
|
||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
)
|
||||
),
|
||||
|
@ -98,7 +98,7 @@ class TestVisionModelInference:
|
|||
|
||||
images = [
|
||||
ImageContentItem(
|
||||
data=URL(
|
||||
url=URL(
|
||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
)
|
||||
),
|
||||
|
|
|
@ -139,9 +139,9 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
|
|||
|
||||
|
||||
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
||||
if isinstance(media.data, URL) and media.data.uri.startswith("http"):
|
||||
if media.url and media.url.uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(media.data.uri)
|
||||
r = await client.get(media.url.uri)
|
||||
content = r.content
|
||||
content_type = r.headers.get("content-type")
|
||||
if content_type:
|
||||
|
@ -157,8 +157,8 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
|||
async def convert_image_content_to_url(
|
||||
media: ImageContentItem, download: bool = False, include_format: bool = True
|
||||
) -> str:
|
||||
if isinstance(media.data, URL) and not download:
|
||||
return media.data.uri
|
||||
if media.url and not download:
|
||||
return media.url.uri
|
||||
|
||||
content, format = await localize_image_content(media)
|
||||
if include_format:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue