Support for Llama3.2 models and Swift SDK (#98)

This commit is contained in:
Ashwin Bharambe 2024-09-25 10:29:58 -07:00 committed by GitHub
parent 95abbf576b
commit 56aed59eb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
56 changed files with 3745 additions and 630 deletions

View file

@ -10,6 +10,10 @@ from typing import Any, AsyncGenerator, List, Optional
import fire
import httpx
from llama_models.llama3.api.datatypes import ImageMedia, URL
from PIL import Image as PIL_Image
from pydantic import BaseModel
from llama_models.llama3.api import * # noqa: F403
@ -105,7 +109,7 @@ async def run_main(host: str, port: int, stream: bool):
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
model="Meta-Llama3.1-8B-Instruct",
model="Llama3.1-8B-Instruct",
messages=[message],
stream=stream,
)
@ -113,8 +117,34 @@ async def run_main(host: str, port: int, stream: bool):
log.print()
def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))
async def run_mm_main(host: str, port: int, stream: bool, path: str):
client = InferenceClient(f"http://{host}:{port}")
with open(path, "rb") as f:
img = PIL_Image.open(f).convert("RGB")
message = UserMessage(
content=[
ImageMedia(image=URL(uri=f"file://{path}")),
# ImageMedia(image=img),
"Describe this image in two sentences",
],
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
model="Llama3.2-11B-Vision-Instruct",
messages=[message],
stream=stream,
)
async for log in EventLogger().log(iterator):
log.print()
def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
if mm:
asyncio.run(run_mm_main(host, port, stream, file))
else:
asyncio.run(run_main(host, port, stream))
if __name__ == "__main__":