forked from phoenix-oss/llama-stack-mirror
[inference api] modify content types so they follow a more standard structure (#841)
Some small updates to the inference types to make them more standard Specifically: - image data is now located in a "image" subkey - similarly tool call data is located in a "tool_call" subkey The pattern followed is `dict(type="foo", foo=<...>)`
This commit is contained in:
parent
caa8387dd2
commit
07b87365ab
15 changed files with 104 additions and 76 deletions
|
@ -113,28 +113,29 @@ async def interleaved_content_convert_to_raw(
|
|||
elif isinstance(c, TextContentItem):
|
||||
return RawTextItem(text=c.text)
|
||||
elif isinstance(c, ImageContentItem):
|
||||
if c.url:
|
||||
image = c.image
|
||||
if image.url:
|
||||
# Load image bytes from URL
|
||||
if c.url.uri.startswith("data"):
|
||||
match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri)
|
||||
if image.url.uri.startswith("data"):
|
||||
match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"Invalid data URL format, {c.url.uri[:40]}..."
|
||||
f"Invalid data URL format, {image.url.uri[:40]}..."
|
||||
)
|
||||
_, image_data = match.groups()
|
||||
data = base64.b64decode(image_data)
|
||||
elif c.url.uri.startswith("file://"):
|
||||
path = c.url.uri[len("file://") :]
|
||||
elif image.url.uri.startswith("file://"):
|
||||
path = image.url.uri[len("file://") :]
|
||||
with open(path, "rb") as f:
|
||||
data = f.read() # type: ignore
|
||||
elif c.url.uri.startswith("http"):
|
||||
elif image.url.uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(c.url.uri)
|
||||
response = await client.get(image.url.uri)
|
||||
data = response.content
|
||||
else:
|
||||
raise ValueError("Unsupported URL type")
|
||||
elif c.data:
|
||||
data = c.data
|
||||
elif image.data:
|
||||
data = image.data
|
||||
else:
|
||||
raise ValueError("No data or URL provided")
|
||||
|
||||
|
@ -170,26 +171,29 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
|
|||
|
||||
|
||||
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
||||
if media.url and media.url.uri.startswith("http"):
|
||||
image = media.image
|
||||
if image.url and image.url.uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(media.url.uri)
|
||||
r = await client.get(image.url.uri)
|
||||
content = r.content
|
||||
content_type = r.headers.get("content-type")
|
||||
if content_type:
|
||||
format = content_type.split("/")[-1]
|
||||
else:
|
||||
format = "png"
|
||||
|
||||
return content, format
|
||||
else:
|
||||
image = PIL_Image.open(io.BytesIO(media.data))
|
||||
return media.data, image.format
|
||||
pil_image = PIL_Image.open(io.BytesIO(image.data))
|
||||
return image.data, pil_image.format
|
||||
|
||||
|
||||
async def convert_image_content_to_url(
|
||||
media: ImageContentItem, download: bool = False, include_format: bool = True
|
||||
) -> str:
|
||||
if media.url and (not download or media.url.uri.startswith("data")):
|
||||
return media.url.uri
|
||||
image = media.image
|
||||
if image.url and (not download or image.url.uri.startswith("data")):
|
||||
return image.url.uri
|
||||
|
||||
content, format = await localize_image_content(media)
|
||||
if include_format:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue