[inference api] modify content types so they follow a more standard structure (#841)

Some small updates to the inference types to make them more standard

Specifically:
- image data is now located in a "image" subkey
- similarly tool call data is located in a "tool_call" subkey

The pattern followed is `dict(type="foo", foo=<...>)`
This commit is contained in:
Ashwin Bharambe 2025-01-22 12:16:18 -08:00 committed by GitHub
parent caa8387dd2
commit 07b87365ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 104 additions and 76 deletions

View file

@ -472,7 +472,7 @@ class TestConvertStreamChatCompletionResponse:
iter = converted.__aiter__()
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.start
assert chunk.event.delta.content == ToolCall(
assert chunk.event.delta.tool_call == ToolCall(
call_id="tool_call_id",
tool_name="get_flight_info",
arguments={"origin": "AU", "destination": "LAX"},

View file

@ -470,16 +470,16 @@ class TestInference:
)
first = grouped[ChatCompletionResponseEventType.progress][0]
if not isinstance(
first.event.delta.content, ToolCall
first.event.delta.tool_call, ToolCall
): # first chunk may contain entire call
assert first.event.delta.parse_status == ToolCallParseStatus.started
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
assert isinstance(last.event.delta.content, ToolCall)
assert isinstance(last.event.delta.tool_call, ToolCall)
call = last.event.delta.content
call = last.event.delta.tool_call
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]

View file

@ -32,13 +32,15 @@ class TestVisionModelInference:
"image, expected_strings",
[
(
ImageContentItem(data=PASTA_IMAGE),
ImageContentItem(image=dict(data=PASTA_IMAGE)),
["spaghetti"],
),
(
ImageContentItem(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
image=dict(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
)
),
["puppy"],
@ -103,8 +105,10 @@ class TestVisionModelInference:
images = [
ImageContentItem(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
image=dict(
url=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
)
),
]