[inference api] modify content types so they follow a more standard structure (#841)

Some small updates to the inference types to make them more standard

Specifically:
- image data is now located in an "image" subkey
- similarly, tool call data is now located in a "tool_call" subkey

The pattern followed is `dict(type="foo", foo=<...>)`
This commit is contained in:
Ashwin Bharambe 2025-01-22 12:16:18 -08:00 committed by GitHub
parent caa8387dd2
commit 07b87365ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 104 additions and 76 deletions

View file

@ -3759,6 +3759,14 @@
] ]
}, },
"ImageContentItem": { "ImageContentItem": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "image",
"default": "image"
},
"image": {
"type": "object", "type": "object",
"properties": { "properties": {
"url": { "url": {
@ -3767,16 +3775,15 @@
"data": { "data": {
"type": "string", "type": "string",
"contentEncoding": "base64" "contentEncoding": "base64"
}
}, },
"type": { "additionalProperties": false
"type": "string",
"const": "image",
"default": "image"
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type" "type",
"image"
] ]
}, },
"InterleavedContent": { "InterleavedContent": {
@ -4518,7 +4525,7 @@
"const": "image", "const": "image",
"default": "image" "default": "image"
}, },
"data": { "image": {
"type": "string", "type": "string",
"contentEncoding": "base64" "contentEncoding": "base64"
} }
@ -4526,7 +4533,7 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type", "type",
"data" "image"
] ]
}, },
"TextDelta": { "TextDelta": {
@ -4570,7 +4577,7 @@
"const": "tool_call", "const": "tool_call",
"default": "tool_call" "default": "tool_call"
}, },
"content": { "tool_call": {
"oneOf": [ "oneOf": [
{ {
"type": "string" "type": "string"
@ -4587,7 +4594,7 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type", "type",
"content", "tool_call",
"parse_status" "parse_status"
] ]
}, },

View file

@ -926,31 +926,36 @@ components:
ImageContentItem: ImageContentItem:
additionalProperties: false additionalProperties: false
properties: properties:
data: image:
contentEncoding: base64
type: string
type:
const: image
default: image
type: string
url:
$ref: '#/components/schemas/URL'
required:
- type
type: object
ImageDelta:
additionalProperties: false additionalProperties: false
properties: properties:
data: data:
contentEncoding: base64 contentEncoding: base64
type: string type: string
url:
$ref: '#/components/schemas/URL'
type: object
type: type:
const: image const: image
default: image default: image
type: string type: string
required: required:
- type - type
- data - image
type: object
ImageDelta:
additionalProperties: false
properties:
image:
contentEncoding: base64
type: string
type:
const: image
default: image
type: string
required:
- type
- image
type: object type: object
InferenceStep: InferenceStep:
additionalProperties: false additionalProperties: false
@ -2748,19 +2753,19 @@ components:
ToolCallDelta: ToolCallDelta:
additionalProperties: false additionalProperties: false
properties: properties:
content: parse_status:
$ref: '#/components/schemas/ToolCallParseStatus'
tool_call:
oneOf: oneOf:
- type: string - type: string
- $ref: '#/components/schemas/ToolCall' - $ref: '#/components/schemas/ToolCall'
parse_status:
$ref: '#/components/schemas/ToolCallParseStatus'
type: type:
const: tool_call const: tool_call
default: tool_call default: tool_call
type: string type: string
required: required:
- type - type
- content - tool_call
- parse_status - parse_status
type: object type: object
ToolCallParseStatus: ToolCallParseStatus:

View file

@ -137,7 +137,7 @@ class EventLogger:
event, event,
LogEvent( LogEvent(
role=None, role=None,
content=delta.content, content=delta.tool_call,
end="", end="",
color="cyan", color="cyan",
), ),

View file

@ -38,8 +38,9 @@ class _URLOrData(BaseModel):
@json_schema_type @json_schema_type
class ImageContentItem(_URLOrData): class ImageContentItem(BaseModel):
type: Literal["image"] = "image" type: Literal["image"] = "image"
image: _URLOrData
@json_schema_type @json_schema_type
@ -73,7 +74,7 @@ class TextDelta(BaseModel):
@json_schema_type @json_schema_type
class ImageDelta(BaseModel): class ImageDelta(BaseModel):
type: Literal["image"] = "image" type: Literal["image"] = "image"
data: bytes image: bytes
@json_schema_type @json_schema_type
@ -91,7 +92,7 @@ class ToolCallDelta(BaseModel):
# you either send an in-progress tool call so the client can stream a long # you either send an in-progress tool call so the client can stream a long
# code generation or you send the final parsed tool call at the end of the # code generation or you send the final parsed tool call at the end of the
# stream # stream
content: Union[str, ToolCall] tool_call: Union[str, ToolCall]
parse_status: ToolCallParseStatus parse_status: ToolCallParseStatus

View file

@ -423,7 +423,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=step_id, step_id=step_id,
delta=ToolCallDelta( delta=ToolCallDelta(
parse_status=ToolCallParseStatus.succeeded, parse_status=ToolCallParseStatus.succeeded,
content=ToolCall( tool_call=ToolCall(
call_id="", call_id="",
tool_name=MEMORY_QUERY_TOOL, tool_name=MEMORY_QUERY_TOOL,
arguments={}, arguments={},
@ -525,7 +525,7 @@ class ChatAgent(ShieldRunnerMixin):
delta = event.delta delta = event.delta
if delta.type == "tool_call": if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.succeeded: if delta.parse_status == ToolCallParseStatus.succeeded:
tool_calls.append(delta.content) tool_calls.append(delta.tool_call)
if stream: if stream:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
@ -639,7 +639,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_call=tool_call, tool_call=tool_call,
delta=ToolCallDelta( delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,
content=tool_call, tool_call=tool_call,
), ),
) )
) )

View file

@ -377,7 +377,7 @@ class MetaReferenceInferenceImpl(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content="", tool_call="",
parse_status=ToolCallParseStatus.started, parse_status=ToolCallParseStatus.started,
), ),
) )
@ -395,7 +395,7 @@ class MetaReferenceInferenceImpl(
if ipython: if ipython:
delta = ToolCallDelta( delta = ToolCallDelta(
content=text, tool_call=text,
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,
) )
else: else:
@ -434,7 +434,7 @@ class MetaReferenceInferenceImpl(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content="", tool_call="",
parse_status=ToolCallParseStatus.failed, parse_status=ToolCallParseStatus.failed,
), ),
stop_reason=stop_reason, stop_reason=stop_reason,
@ -446,7 +446,7 @@ class MetaReferenceInferenceImpl(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content=tool_call, tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded, parse_status=ToolCallParseStatus.succeeded,
), ),
stop_reason=stop_reason, stop_reason=stop_reason,

View file

@ -218,7 +218,7 @@ async def convert_chat_completion_response_stream(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=event_type, event_type=event_type,
delta=ToolCallDelta( delta=ToolCallDelta(
content=tool_call, tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded, parse_status=ToolCallParseStatus.succeeded,
), ),
) )

View file

@ -505,7 +505,9 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=next(event_type),
delta=ToolCallDelta( delta=ToolCallDelta(
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0], tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[
0
],
parse_status=ToolCallParseStatus.succeeded, parse_status=ToolCallParseStatus.succeeded,
), ),
logprobs=_convert_openai_logprobs(choice.logprobs), logprobs=_convert_openai_logprobs(choice.logprobs),

View file

@ -472,7 +472,7 @@ class TestConvertStreamChatCompletionResponse:
iter = converted.__aiter__() iter = converted.__aiter__()
chunk = await iter.__anext__() chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.start assert chunk.event.event_type == ChatCompletionResponseEventType.start
assert chunk.event.delta.content == ToolCall( assert chunk.event.delta.tool_call == ToolCall(
call_id="tool_call_id", call_id="tool_call_id",
tool_name="get_flight_info", tool_name="get_flight_info",
arguments={"origin": "AU", "destination": "LAX"}, arguments={"origin": "AU", "destination": "LAX"},

View file

@ -470,16 +470,16 @@ class TestInference:
) )
first = grouped[ChatCompletionResponseEventType.progress][0] first = grouped[ChatCompletionResponseEventType.progress][0]
if not isinstance( if not isinstance(
first.event.delta.content, ToolCall first.event.delta.tool_call, ToolCall
): # first chunk may contain entire call ): # first chunk may contain entire call
assert first.event.delta.parse_status == ToolCallParseStatus.started assert first.event.delta.parse_status == ToolCallParseStatus.started
last = grouped[ChatCompletionResponseEventType.progress][-1] last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason # assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
assert isinstance(last.event.delta.content, ToolCall) assert isinstance(last.event.delta.tool_call, ToolCall)
call = last.event.delta.content call = last.event.delta.tool_call
assert call.tool_name == "get_weather" assert call.tool_name == "get_weather"
assert "location" in call.arguments assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"] assert "San Francisco" in call.arguments["location"]

View file

@ -32,14 +32,16 @@ class TestVisionModelInference:
"image, expected_strings", "image, expected_strings",
[ [
( (
ImageContentItem(data=PASTA_IMAGE), ImageContentItem(image=dict(data=PASTA_IMAGE)),
["spaghetti"], ["spaghetti"],
), ),
( (
ImageContentItem( ImageContentItem(
image=dict(
url=URL( url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
) )
)
), ),
["puppy"], ["puppy"],
), ),
@ -103,9 +105,11 @@ class TestVisionModelInference:
images = [ images = [
ImageContentItem( ImageContentItem(
image=dict(
url=URL( url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
) )
)
), ),
] ]
expected_strings_to_check = [ expected_strings_to_check = [

View file

@ -240,7 +240,7 @@ async def process_chat_completion_stream_response(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content="", tool_call="",
parse_status=ToolCallParseStatus.started, parse_status=ToolCallParseStatus.started,
), ),
) )
@ -260,7 +260,7 @@ async def process_chat_completion_stream_response(
if ipython: if ipython:
buffer += text buffer += text
delta = ToolCallDelta( delta = ToolCallDelta(
content=text, tool_call=text,
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,
) )
@ -289,7 +289,7 @@ async def process_chat_completion_stream_response(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content="", tool_call="",
parse_status=ToolCallParseStatus.failed, parse_status=ToolCallParseStatus.failed,
), ),
stop_reason=stop_reason, stop_reason=stop_reason,
@ -301,7 +301,7 @@ async def process_chat_completion_stream_response(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content=tool_call, tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded, parse_status=ToolCallParseStatus.succeeded,
), ),
stop_reason=stop_reason, stop_reason=stop_reason,

View file

@ -113,28 +113,29 @@ async def interleaved_content_convert_to_raw(
elif isinstance(c, TextContentItem): elif isinstance(c, TextContentItem):
return RawTextItem(text=c.text) return RawTextItem(text=c.text)
elif isinstance(c, ImageContentItem): elif isinstance(c, ImageContentItem):
if c.url: image = c.image
if image.url:
# Load image bytes from URL # Load image bytes from URL
if c.url.uri.startswith("data"): if image.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri) match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
if not match: if not match:
raise ValueError( raise ValueError(
f"Invalid data URL format, {c.url.uri[:40]}..." f"Invalid data URL format, {image.url.uri[:40]}..."
) )
_, image_data = match.groups() _, image_data = match.groups()
data = base64.b64decode(image_data) data = base64.b64decode(image_data)
elif c.url.uri.startswith("file://"): elif image.url.uri.startswith("file://"):
path = c.url.uri[len("file://") :] path = image.url.uri[len("file://") :]
with open(path, "rb") as f: with open(path, "rb") as f:
data = f.read() # type: ignore data = f.read() # type: ignore
elif c.url.uri.startswith("http"): elif image.url.uri.startswith("http"):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get(c.url.uri) response = await client.get(image.url.uri)
data = response.content data = response.content
else: else:
raise ValueError("Unsupported URL type") raise ValueError("Unsupported URL type")
elif c.data: elif image.data:
data = c.data data = image.data
else: else:
raise ValueError("No data or URL provided") raise ValueError("No data or URL provided")
@ -170,26 +171,29 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]: async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
if media.url and media.url.uri.startswith("http"): image = media.image
if image.url and image.url.uri.startswith("http"):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(media.url.uri) r = await client.get(image.url.uri)
content = r.content content = r.content
content_type = r.headers.get("content-type") content_type = r.headers.get("content-type")
if content_type: if content_type:
format = content_type.split("/")[-1] format = content_type.split("/")[-1]
else: else:
format = "png" format = "png"
return content, format return content, format
else: else:
image = PIL_Image.open(io.BytesIO(media.data)) pil_image = PIL_Image.open(io.BytesIO(image.data))
return media.data, image.format return image.data, pil_image.format
async def convert_image_content_to_url( async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str: ) -> str:
if media.url and (not download or media.url.uri.startswith("data")): image = media.image
return media.url.uri if image.url and (not download or image.url.uri.startswith("data")):
return image.url.uri
content, format = await localize_image_content(media) content, format = await localize_image_content(media)
if include_format: if include_format:

View file

@ -258,7 +258,7 @@ def extract_tool_invocation_content(response):
for chunk in response: for chunk in response:
delta = chunk.event.delta delta = chunk.event.delta
if delta.type == "tool_call" and delta.parse_status == "succeeded": if delta.type == "tool_call" and delta.parse_status == "succeeded":
call = delta.content call = delta.tool_call
tool_invocation_content += f"[{call.tool_name}, {call.arguments}]" tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
return tool_invocation_content return tool_invocation_content
@ -321,11 +321,13 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
"content": [ "content": [
{ {
"type": "image", "type": "image",
"image": {
"url": { "url": {
# TODO: Replace with Github based URI to resources/sample1.jpg # TODO: Replace with Github based URI to resources/sample1.jpg
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
}, },
}, },
},
{ {
"type": "text", "type": "text",
"text": "Describe what is in this image.", "text": "Describe what is in this image.",
@ -348,11 +350,13 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
"content": [ "content": [
{ {
"type": "image", "type": "image",
"image": {
"url": { "url": {
# TODO: Replace with Github based URI to resources/sample1.jpg # TODO: Replace with Github based URI to resources/sample1.jpg
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
}, },
}, },
},
{ {
"type": "text", "type": "text",
"text": "Describe what is in this image.", "text": "Describe what is in this image.",
@ -374,16 +378,17 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
def test_image_chat_completion_base64_url( def test_image_chat_completion_base64_url(
llama_stack_client, vision_model_id, base64_image_url llama_stack_client, vision_model_id, base64_image_url
): ):
message = { message = {
"role": "user", "role": "user",
"content": [ "content": [
{ {
"type": "image", "type": "image",
"image": {
"url": { "url": {
"uri": base64_image_url, "uri": base64_image_url,
}, },
}, },
},
{ {
"type": "text", "type": "text",
"text": "Describe what is in this image.", "text": "Describe what is in this image.",

View file

@ -141,7 +141,7 @@ def test_safety_with_image(llama_stack_client, model_providers):
}, },
{ {
"type": "image", "type": "image",
"url": {"uri": data_url_from_image(file_path)}, "image": {"url": {"uri": data_url_from_image(file_path)}},
}, },
], ],
} }