Tests pass with Ollama now

This commit is contained in:
Ashwin Bharambe 2024-12-15 17:31:21 -08:00
parent a9a041a1de
commit e51154964f
27 changed files with 83 additions and 65 deletions

View file

@ -6,6 +6,7 @@
import asyncio
import base64
import io
import json
import logging
import re
@ -21,7 +22,6 @@ from llama_models.llama3.api.datatypes import (
RawMediaItem,
RawTextItem,
Role,
ToolChoice,
ToolPromptFormat,
)
from llama_models.llama3.prompt_templates import (
@ -47,6 +47,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SystemMessage,
TextContentItem,
ToolChoice,
UserMessage,
)
@ -136,7 +137,7 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
if isinstance(media.data, URL) and media.data.uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(media.image.uri)
r = await client.get(media.data.uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
@ -145,7 +146,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
format = "png"
return content, format
else:
image = PIL_Image.open(media.data)
image = PIL_Image.open(io.BytesIO(media.data))
return media.data, image.format
@ -153,7 +154,7 @@ async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
if isinstance(media.data, URL) and not download:
return media.image.uri
return media.data.uri
content, format = await localize_image_content(media)
if include_format: