[inference api] modify content types so they follow a more standard structure (#841)

Some small updates to the inference types to make them more standard

Specifically:
- image data is now located in an "image" subkey
- similarly, tool call data is located in a "tool_call" subkey

The pattern followed is `dict(type="foo", foo=<...>)`
This commit is contained in:
Ashwin Bharambe 2025-01-22 12:16:18 -08:00 committed by GitHub
parent caa8387dd2
commit 07b87365ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 104 additions and 76 deletions

View file

@ -3761,22 +3761,29 @@
"ImageContentItem": {
"type": "object",
"properties": {
"url": {
"$ref": "#/components/schemas/URL"
},
"data": {
"type": "string",
"contentEncoding": "base64"
},
"type": {
"type": "string",
"const": "image",
"default": "image"
},
"image": {
"type": "object",
"properties": {
"url": {
"$ref": "#/components/schemas/URL"
},
"data": {
"type": "string",
"contentEncoding": "base64"
}
},
"additionalProperties": false
}
},
"additionalProperties": false,
"required": [
"type"
"type",
"image"
]
},
"InterleavedContent": {
@ -4518,7 +4525,7 @@
"const": "image",
"default": "image"
},
"data": {
"image": {
"type": "string",
"contentEncoding": "base64"
}
@ -4526,7 +4533,7 @@
"additionalProperties": false,
"required": [
"type",
"data"
"image"
]
},
"TextDelta": {
@ -4570,7 +4577,7 @@
"const": "tool_call",
"default": "tool_call"
},
"content": {
"tool_call": {
"oneOf": [
{
"type": "string"
@ -4587,7 +4594,7 @@
"additionalProperties": false,
"required": [
"type",
"content",
"tool_call",
"parse_status"
]
},

View file

@ -926,22 +926,27 @@ components:
ImageContentItem:
additionalProperties: false
properties:
data:
contentEncoding: base64
type: string
image:
additionalProperties: false
properties:
data:
contentEncoding: base64
type: string
url:
$ref: '#/components/schemas/URL'
type: object
type:
const: image
default: image
type: string
url:
$ref: '#/components/schemas/URL'
required:
- type
- image
type: object
ImageDelta:
additionalProperties: false
properties:
data:
image:
contentEncoding: base64
type: string
type:
@ -950,7 +955,7 @@ components:
type: string
required:
- type
- data
- image
type: object
InferenceStep:
additionalProperties: false
@ -2748,19 +2753,19 @@ components:
ToolCallDelta:
additionalProperties: false
properties:
content:
parse_status:
$ref: '#/components/schemas/ToolCallParseStatus'
tool_call:
oneOf:
- type: string
- $ref: '#/components/schemas/ToolCall'
parse_status:
$ref: '#/components/schemas/ToolCallParseStatus'
type:
const: tool_call
default: tool_call
type: string
required:
- type
- content
- tool_call
- parse_status
type: object
ToolCallParseStatus:

View file

@ -137,7 +137,7 @@ class EventLogger:
event,
LogEvent(
role=None,
content=delta.content,
content=delta.tool_call,
end="",
color="cyan",
),

View file

@ -38,8 +38,9 @@ class _URLOrData(BaseModel):
@json_schema_type
class ImageContentItem(_URLOrData):
class ImageContentItem(BaseModel):
type: Literal["image"] = "image"
image: _URLOrData
@json_schema_type
@ -73,7 +74,7 @@ class TextDelta(BaseModel):
@json_schema_type
class ImageDelta(BaseModel):
type: Literal["image"] = "image"
data: bytes
image: bytes
@json_schema_type
@ -91,7 +92,7 @@ class ToolCallDelta(BaseModel):
# you either send an in-progress tool call so the client can stream a long
# code generation or you send the final parsed tool call at the end of the
# stream
content: Union[str, ToolCall]
tool_call: Union[str, ToolCall]
parse_status: ToolCallParseStatus

View file

@ -423,7 +423,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.succeeded,
content=ToolCall(
tool_call=ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
@ -525,7 +525,7 @@ class ChatAgent(ShieldRunnerMixin):
delta = event.delta
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.succeeded:
tool_calls.append(delta.content)
tool_calls.append(delta.tool_call)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -639,7 +639,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_call=tool_call,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
content=tool_call,
tool_call=tool_call,
),
)
)

View file

@ -377,7 +377,7 @@ class MetaReferenceInferenceImpl(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
tool_call="",
parse_status=ToolCallParseStatus.started,
),
)
@ -395,7 +395,7 @@ class MetaReferenceInferenceImpl(
if ipython:
delta = ToolCallDelta(
content=text,
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
@ -434,7 +434,7 @@ class MetaReferenceInferenceImpl(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
tool_call="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
@ -446,7 +446,7 @@ class MetaReferenceInferenceImpl(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,

View file

@ -218,7 +218,7 @@ async def convert_chat_completion_response_stream(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
content=tool_call,
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
)

View file

@ -505,7 +505,9 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[
0
],
parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(choice.logprobs),

View file

@ -472,7 +472,7 @@ class TestConvertStreamChatCompletionResponse:
iter = converted.__aiter__()
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.start
assert chunk.event.delta.content == ToolCall(
assert chunk.event.delta.tool_call == ToolCall(
call_id="tool_call_id",
tool_name="get_flight_info",
arguments={"origin": "AU", "destination": "LAX"},

View file

@ -470,16 +470,16 @@ class TestInference:
)
first = grouped[ChatCompletionResponseEventType.progress][0]
if not isinstance(
first.event.delta.content, ToolCall
first.event.delta.tool_call, ToolCall
): # first chunk may contain entire call
assert first.event.delta.parse_status == ToolCallParseStatus.started
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
assert isinstance(last.event.delta.content, ToolCall)
assert isinstance(last.event.delta.tool_call, ToolCall)
call = last.event.delta.content
call = last.event.delta.tool_call
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]

View file

@ -32,13 +32,15 @@ class TestVisionModelInference:
"image, expected_strings",
[
(
ImageContentItem(data=PASTA_IMAGE),
ImageContentItem(image=dict(data=PASTA_IMAGE)),
["spaghetti"],
),
(
ImageContentItem(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
image=dict(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
)
),
["puppy"],
@ -103,8 +105,10 @@ class TestVisionModelInference:
images = [
ImageContentItem(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
image=dict(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
)
),
]

View file

@ -240,7 +240,7 @@ async def process_chat_completion_stream_response(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
tool_call="",
parse_status=ToolCallParseStatus.started,
),
)
@ -260,7 +260,7 @@ async def process_chat_completion_stream_response(
if ipython:
buffer += text
delta = ToolCallDelta(
content=text,
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
@ -289,7 +289,7 @@ async def process_chat_completion_stream_response(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
tool_call="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
@ -301,7 +301,7 @@ async def process_chat_completion_stream_response(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,

View file

@ -113,28 +113,29 @@ async def interleaved_content_convert_to_raw(
elif isinstance(c, TextContentItem):
return RawTextItem(text=c.text)
elif isinstance(c, ImageContentItem):
if c.url:
image = c.image
if image.url:
# Load image bytes from URL
if c.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri)
if image.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
if not match:
raise ValueError(
f"Invalid data URL format, {c.url.uri[:40]}..."
f"Invalid data URL format, {image.url.uri[:40]}..."
)
_, image_data = match.groups()
data = base64.b64decode(image_data)
elif c.url.uri.startswith("file://"):
path = c.url.uri[len("file://") :]
elif image.url.uri.startswith("file://"):
path = image.url.uri[len("file://") :]
with open(path, "rb") as f:
data = f.read() # type: ignore
elif c.url.uri.startswith("http"):
elif image.url.uri.startswith("http"):
async with httpx.AsyncClient() as client:
response = await client.get(c.url.uri)
response = await client.get(image.url.uri)
data = response.content
else:
raise ValueError("Unsupported URL type")
elif c.data:
data = c.data
elif image.data:
data = image.data
else:
raise ValueError("No data or URL provided")
@ -170,26 +171,29 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
if media.url and media.url.uri.startswith("http"):
image = media.image
if image.url and image.url.uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(media.url.uri)
r = await client.get(image.url.uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
format = content_type.split("/")[-1]
else:
format = "png"
return content, format
else:
image = PIL_Image.open(io.BytesIO(media.data))
return media.data, image.format
pil_image = PIL_Image.open(io.BytesIO(image.data))
return image.data, pil_image.format
async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
if media.url and (not download or media.url.uri.startswith("data")):
return media.url.uri
image = media.image
if image.url and (not download or image.url.uri.startswith("data")):
return image.url.uri
content, format = await localize_image_content(media)
if include_format:

View file

@ -258,7 +258,7 @@ def extract_tool_invocation_content(response):
for chunk in response:
delta = chunk.event.delta
if delta.type == "tool_call" and delta.parse_status == "succeeded":
call = delta.content
call = delta.tool_call
tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
return tool_invocation_content
@ -321,9 +321,11 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
"content": [
{
"type": "image",
"url": {
# TODO: Replace with Github based URI to resources/sample1.jpg
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
"image": {
"url": {
# TODO: Replace with Github based URI to resources/sample1.jpg
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
},
},
},
{
@ -348,9 +350,11 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
"content": [
{
"type": "image",
"url": {
# TODO: Replace with Github based URI to resources/sample1.jpg
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
"image": {
"url": {
# TODO: Replace with Github based URI to resources/sample1.jpg
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
},
},
},
{
@ -374,14 +378,15 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
def test_image_chat_completion_base64_url(
llama_stack_client, vision_model_id, base64_image_url
):
message = {
"role": "user",
"content": [
{
"type": "image",
"url": {
"uri": base64_image_url,
"image": {
"url": {
"uri": base64_image_url,
},
},
},
{

View file

@ -141,7 +141,7 @@ def test_safety_with_image(llama_stack_client, model_providers):
},
{
"type": "image",
"url": {"uri": data_url_from_image(file_path)},
"image": {"url": {"uri": data_url_from_image(file_path)}},
},
],
}