mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
Update to the ImageContentItem datatype so url + data is not in a confusing union
This commit is contained in:
parent
4936794de1
commit
cf9fce6b6e
6 changed files with 36 additions and 29 deletions
|
@ -2554,27 +2554,22 @@
|
||||||
"ImageContentItem": {
|
"ImageContentItem": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"$ref": "#/components/schemas/URL"
|
||||||
|
},
|
||||||
|
"data": {
|
||||||
|
"type": "string",
|
||||||
|
"contentEncoding": "base64"
|
||||||
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"const": "image",
|
"const": "image",
|
||||||
"default": "image"
|
"default": "image"
|
||||||
},
|
|
||||||
"data": {
|
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"type": "string",
|
|
||||||
"contentEncoding": "base64"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/URL"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"type",
|
"type"
|
||||||
"data"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"InterleavedContent": {
|
"InterleavedContent": {
|
||||||
|
|
|
@ -1043,17 +1043,16 @@ components:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
data:
|
data:
|
||||||
oneOf:
|
contentEncoding: base64
|
||||||
- contentEncoding: base64
|
type: string
|
||||||
type: string
|
|
||||||
- $ref: '#/components/schemas/URL'
|
|
||||||
type:
|
type:
|
||||||
const: image
|
const: image
|
||||||
default: image
|
default: image
|
||||||
type: string
|
type: string
|
||||||
|
url:
|
||||||
|
$ref: '#/components/schemas/URL'
|
||||||
required:
|
required:
|
||||||
- type
|
- type
|
||||||
- data
|
|
||||||
type: object
|
type: object
|
||||||
InferenceStep:
|
InferenceStep:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
|
|
|
@ -4,11 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Annotated, List, Literal, Union
|
from typing import Annotated, List, Literal, Optional, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, register_schema
|
from llama_models.schema_utils import json_schema_type, register_schema
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type(
|
@json_schema_type(
|
||||||
|
@ -21,10 +21,21 @@ class URL(BaseModel):
|
||||||
return self.uri
|
return self.uri
|
||||||
|
|
||||||
|
|
||||||
|
class _URLOrData(BaseModel):
|
||||||
|
url: Optional[URL] = None
|
||||||
|
data: Optional[bytes] = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validator(cls, values):
|
||||||
|
if isinstance(values, dict):
|
||||||
|
return values
|
||||||
|
return {"url": values}
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ImageContentItem(BaseModel):
|
class ImageContentItem(_URLOrData):
|
||||||
type: Literal["image"] = "image"
|
type: Literal["image"] = "image"
|
||||||
data: Union[bytes, URL]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -11,6 +11,8 @@ from llama_models.schema_utils import json_schema_type
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RestAPIMethod(Enum):
|
class RestAPIMethod(Enum):
|
||||||
|
|
|
@ -11,7 +11,7 @@ import pytest
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
|
||||||
|
|
||||||
from .utils import group_chunks
|
from .utils import group_chunks
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ class TestVisionModelInference:
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
ImageContentItem(
|
ImageContentItem(
|
||||||
data=URL(
|
url=URL(
|
||||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
@ -98,7 +98,7 @@ class TestVisionModelInference:
|
||||||
|
|
||||||
images = [
|
images = [
|
||||||
ImageContentItem(
|
ImageContentItem(
|
||||||
data=URL(
|
url=URL(
|
||||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
|
|
@ -139,9 +139,9 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
|
||||||
|
|
||||||
|
|
||||||
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
||||||
if isinstance(media.data, URL) and media.data.uri.startswith("http"):
|
if media.url and media.url.uri.startswith("http"):
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.get(media.data.uri)
|
r = await client.get(media.url.uri)
|
||||||
content = r.content
|
content = r.content
|
||||||
content_type = r.headers.get("content-type")
|
content_type = r.headers.get("content-type")
|
||||||
if content_type:
|
if content_type:
|
||||||
|
@ -157,8 +157,8 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
||||||
async def convert_image_content_to_url(
|
async def convert_image_content_to_url(
|
||||||
media: ImageContentItem, download: bool = False, include_format: bool = True
|
media: ImageContentItem, download: bool = False, include_format: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
if isinstance(media.data, URL) and not download:
|
if media.url and not download:
|
||||||
return media.data.uri
|
return media.url.uri
|
||||||
|
|
||||||
content, format = await localize_image_content(media)
|
content, format = await localize_image_content(media)
|
||||||
if include_format:
|
if include_format:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue