mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
[inference api] modify content types so they follow a more standard structure (#841)
Some small updates to the inference types to make them more standard Specifically: - image data is now located in a "image" subkey - similarly tool call data is located in a "tool_call" subkey The pattern followed is `dict(type="foo", foo=<...>)`
This commit is contained in:
parent
caa8387dd2
commit
07b87365ab
15 changed files with 104 additions and 76 deletions
|
@ -3761,22 +3761,29 @@
|
|||
"ImageContentItem": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"$ref": "#/components/schemas/URL"
|
||||
},
|
||||
"data": {
|
||||
"type": "string",
|
||||
"contentEncoding": "base64"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "image",
|
||||
"default": "image"
|
||||
},
|
||||
"image": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"$ref": "#/components/schemas/URL"
|
||||
},
|
||||
"data": {
|
||||
"type": "string",
|
||||
"contentEncoding": "base64"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type"
|
||||
"type",
|
||||
"image"
|
||||
]
|
||||
},
|
||||
"InterleavedContent": {
|
||||
|
@ -4518,7 +4525,7 @@
|
|||
"const": "image",
|
||||
"default": "image"
|
||||
},
|
||||
"data": {
|
||||
"image": {
|
||||
"type": "string",
|
||||
"contentEncoding": "base64"
|
||||
}
|
||||
|
@ -4526,7 +4533,7 @@
|
|||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"data"
|
||||
"image"
|
||||
]
|
||||
},
|
||||
"TextDelta": {
|
||||
|
@ -4570,7 +4577,7 @@
|
|||
"const": "tool_call",
|
||||
"default": "tool_call"
|
||||
},
|
||||
"content": {
|
||||
"tool_call": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
|
@ -4587,7 +4594,7 @@
|
|||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"content",
|
||||
"tool_call",
|
||||
"parse_status"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -926,22 +926,27 @@ components:
|
|||
ImageContentItem:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
data:
|
||||
contentEncoding: base64
|
||||
type: string
|
||||
image:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
data:
|
||||
contentEncoding: base64
|
||||
type: string
|
||||
url:
|
||||
$ref: '#/components/schemas/URL'
|
||||
type: object
|
||||
type:
|
||||
const: image
|
||||
default: image
|
||||
type: string
|
||||
url:
|
||||
$ref: '#/components/schemas/URL'
|
||||
required:
|
||||
- type
|
||||
- image
|
||||
type: object
|
||||
ImageDelta:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
data:
|
||||
image:
|
||||
contentEncoding: base64
|
||||
type: string
|
||||
type:
|
||||
|
@ -950,7 +955,7 @@ components:
|
|||
type: string
|
||||
required:
|
||||
- type
|
||||
- data
|
||||
- image
|
||||
type: object
|
||||
InferenceStep:
|
||||
additionalProperties: false
|
||||
|
@ -2748,19 +2753,19 @@ components:
|
|||
ToolCallDelta:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
content:
|
||||
parse_status:
|
||||
$ref: '#/components/schemas/ToolCallParseStatus'
|
||||
tool_call:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/ToolCall'
|
||||
parse_status:
|
||||
$ref: '#/components/schemas/ToolCallParseStatus'
|
||||
type:
|
||||
const: tool_call
|
||||
default: tool_call
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
- content
|
||||
- tool_call
|
||||
- parse_status
|
||||
type: object
|
||||
ToolCallParseStatus:
|
||||
|
|
|
@ -137,7 +137,7 @@ class EventLogger:
|
|||
event,
|
||||
LogEvent(
|
||||
role=None,
|
||||
content=delta.content,
|
||||
content=delta.tool_call,
|
||||
end="",
|
||||
color="cyan",
|
||||
),
|
||||
|
|
|
@ -38,8 +38,9 @@ class _URLOrData(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ImageContentItem(_URLOrData):
|
||||
class ImageContentItem(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
image: _URLOrData
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -73,7 +74,7 @@ class TextDelta(BaseModel):
|
|||
@json_schema_type
|
||||
class ImageDelta(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
data: bytes
|
||||
image: bytes
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -91,7 +92,7 @@ class ToolCallDelta(BaseModel):
|
|||
# you either send an in-progress tool call so the client can stream a long
|
||||
# code generation or you send the final parsed tool call at the end of the
|
||||
# stream
|
||||
content: Union[str, ToolCall]
|
||||
tool_call: Union[str, ToolCall]
|
||||
parse_status: ToolCallParseStatus
|
||||
|
||||
|
||||
|
|
|
@ -423,7 +423,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_id=step_id,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
content=ToolCall(
|
||||
tool_call=ToolCall(
|
||||
call_id="",
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
arguments={},
|
||||
|
@ -525,7 +525,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
delta = event.delta
|
||||
if delta.type == "tool_call":
|
||||
if delta.parse_status == ToolCallParseStatus.succeeded:
|
||||
tool_calls.append(delta.content)
|
||||
tool_calls.append(delta.tool_call)
|
||||
if stream:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -639,7 +639,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_call=tool_call,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
content=tool_call,
|
||||
tool_call=tool_call,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
|
@ -377,7 +377,7 @@ class MetaReferenceInferenceImpl(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
|
@ -395,7 +395,7 @@ class MetaReferenceInferenceImpl(
|
|||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
tool_call=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
|
@ -434,7 +434,7 @@ class MetaReferenceInferenceImpl(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
@ -446,7 +446,7 @@ class MetaReferenceInferenceImpl(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
tool_call=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
|
|
@ -218,7 +218,7 @@ async def convert_chat_completion_response_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
tool_call=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -505,7 +505,9 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_type),
|
||||
delta=ToolCallDelta(
|
||||
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
|
||||
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[
|
||||
0
|
||||
],
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
|
|
|
@ -472,7 +472,7 @@ class TestConvertStreamChatCompletionResponse:
|
|||
iter = converted.__aiter__()
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.start
|
||||
assert chunk.event.delta.content == ToolCall(
|
||||
assert chunk.event.delta.tool_call == ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name="get_flight_info",
|
||||
arguments={"origin": "AU", "destination": "LAX"},
|
||||
|
|
|
@ -470,16 +470,16 @@ class TestInference:
|
|||
)
|
||||
first = grouped[ChatCompletionResponseEventType.progress][0]
|
||||
if not isinstance(
|
||||
first.event.delta.content, ToolCall
|
||||
first.event.delta.tool_call, ToolCall
|
||||
): # first chunk may contain entire call
|
||||
assert first.event.delta.parse_status == ToolCallParseStatus.started
|
||||
|
||||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||
# assert last.event.stop_reason == expected_stop_reason
|
||||
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
|
||||
assert isinstance(last.event.delta.content, ToolCall)
|
||||
assert isinstance(last.event.delta.tool_call, ToolCall)
|
||||
|
||||
call = last.event.delta.content
|
||||
call = last.event.delta.tool_call
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
|
|
|
@ -32,13 +32,15 @@ class TestVisionModelInference:
|
|||
"image, expected_strings",
|
||||
[
|
||||
(
|
||||
ImageContentItem(data=PASTA_IMAGE),
|
||||
ImageContentItem(image=dict(data=PASTA_IMAGE)),
|
||||
["spaghetti"],
|
||||
),
|
||||
(
|
||||
ImageContentItem(
|
||||
url=URL(
|
||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
image=dict(
|
||||
url=URL(
|
||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
)
|
||||
)
|
||||
),
|
||||
["puppy"],
|
||||
|
@ -103,8 +105,10 @@ class TestVisionModelInference:
|
|||
|
||||
images = [
|
||||
ImageContentItem(
|
||||
url=URL(
|
||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
image=dict(
|
||||
url=URL(
|
||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
)
|
||||
)
|
||||
),
|
||||
]
|
||||
|
|
|
@ -240,7 +240,7 @@ async def process_chat_completion_stream_response(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
|
@ -260,7 +260,7 @@ async def process_chat_completion_stream_response(
|
|||
if ipython:
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
tool_call=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
|
@ -289,7 +289,7 @@ async def process_chat_completion_stream_response(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
@ -301,7 +301,7 @@ async def process_chat_completion_stream_response(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
tool_call=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
|
|
@ -113,28 +113,29 @@ async def interleaved_content_convert_to_raw(
|
|||
elif isinstance(c, TextContentItem):
|
||||
return RawTextItem(text=c.text)
|
||||
elif isinstance(c, ImageContentItem):
|
||||
if c.url:
|
||||
image = c.image
|
||||
if image.url:
|
||||
# Load image bytes from URL
|
||||
if c.url.uri.startswith("data"):
|
||||
match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri)
|
||||
if image.url.uri.startswith("data"):
|
||||
match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"Invalid data URL format, {c.url.uri[:40]}..."
|
||||
f"Invalid data URL format, {image.url.uri[:40]}..."
|
||||
)
|
||||
_, image_data = match.groups()
|
||||
data = base64.b64decode(image_data)
|
||||
elif c.url.uri.startswith("file://"):
|
||||
path = c.url.uri[len("file://") :]
|
||||
elif image.url.uri.startswith("file://"):
|
||||
path = image.url.uri[len("file://") :]
|
||||
with open(path, "rb") as f:
|
||||
data = f.read() # type: ignore
|
||||
elif c.url.uri.startswith("http"):
|
||||
elif image.url.uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(c.url.uri)
|
||||
response = await client.get(image.url.uri)
|
||||
data = response.content
|
||||
else:
|
||||
raise ValueError("Unsupported URL type")
|
||||
elif c.data:
|
||||
data = c.data
|
||||
elif image.data:
|
||||
data = image.data
|
||||
else:
|
||||
raise ValueError("No data or URL provided")
|
||||
|
||||
|
@ -170,26 +171,29 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
|
|||
|
||||
|
||||
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
||||
if media.url and media.url.uri.startswith("http"):
|
||||
image = media.image
|
||||
if image.url and image.url.uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(media.url.uri)
|
||||
r = await client.get(image.url.uri)
|
||||
content = r.content
|
||||
content_type = r.headers.get("content-type")
|
||||
if content_type:
|
||||
format = content_type.split("/")[-1]
|
||||
else:
|
||||
format = "png"
|
||||
|
||||
return content, format
|
||||
else:
|
||||
image = PIL_Image.open(io.BytesIO(media.data))
|
||||
return media.data, image.format
|
||||
pil_image = PIL_Image.open(io.BytesIO(image.data))
|
||||
return image.data, pil_image.format
|
||||
|
||||
|
||||
async def convert_image_content_to_url(
|
||||
media: ImageContentItem, download: bool = False, include_format: bool = True
|
||||
) -> str:
|
||||
if media.url and (not download or media.url.uri.startswith("data")):
|
||||
return media.url.uri
|
||||
image = media.image
|
||||
if image.url and (not download or image.url.uri.startswith("data")):
|
||||
return image.url.uri
|
||||
|
||||
content, format = await localize_image_content(media)
|
||||
if include_format:
|
||||
|
|
|
@ -258,7 +258,7 @@ def extract_tool_invocation_content(response):
|
|||
for chunk in response:
|
||||
delta = chunk.event.delta
|
||||
if delta.type == "tool_call" and delta.parse_status == "succeeded":
|
||||
call = delta.content
|
||||
call = delta.tool_call
|
||||
tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
|
||||
return tool_invocation_content
|
||||
|
||||
|
@ -321,9 +321,11 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
|
|||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": {
|
||||
# TODO: Replace with Github based URI to resources/sample1.jpg
|
||||
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
"image": {
|
||||
"url": {
|
||||
# TODO: Replace with Github based URI to resources/sample1.jpg
|
||||
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -348,9 +350,11 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
|
|||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": {
|
||||
# TODO: Replace with Github based URI to resources/sample1.jpg
|
||||
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
"image": {
|
||||
"url": {
|
||||
# TODO: Replace with Github based URI to resources/sample1.jpg
|
||||
"uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -374,14 +378,15 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
|
|||
def test_image_chat_completion_base64_url(
|
||||
llama_stack_client, vision_model_id, base64_image_url
|
||||
):
|
||||
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": {
|
||||
"uri": base64_image_url,
|
||||
"image": {
|
||||
"url": {
|
||||
"uri": base64_image_url,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
@ -141,7 +141,7 @@ def test_safety_with_image(llama_stack_client, model_providers):
|
|||
},
|
||||
{
|
||||
"type": "image",
|
||||
"url": {"uri": data_url_from_image(file_path)},
|
||||
"image": {"url": {"uri": data_url_from_image(file_path)}},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue