inference: Add model option to client (#170)

I was running this client for testing purposes, and being able to
specify which model to use is a convenient addition. This change makes
that possible.
Russell Bryant 2024-10-03 14:18:57 -04:00 committed by GitHub
parent 210b71b0ba
commit 06db9213b1

@@ -6,6 +6,7 @@
 import asyncio
 import json
+import sys
 from typing import Any, AsyncGenerator, List, Optional
 
 import fire
@@ -100,15 +101,18 @@ class InferenceClient(Inference):
         print(f"Error with parsing or validation: {e}")
 
 
-async def run_main(host: str, port: int, stream: bool):
+async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
     client = InferenceClient(f"http://{host}:{port}")
 
+    if not model:
+        model = "Llama3.1-8B-Instruct"
+
     message = UserMessage(
         content="hello world, write me a 2 sentence poem about the moon"
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Llama3.1-8B-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
     )
@@ -116,9 +120,14 @@ async def run_main(host: str, port: int, stream: bool):
     log.print()
 
 
-async def run_mm_main(host: str, port: int, stream: bool, path: str):
+async def run_mm_main(
+    host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
+):
     client = InferenceClient(f"http://{host}:{port}")
 
+    if not model:
+        model = "Llama3.2-11B-Vision-Instruct"
+
     message = UserMessage(
         content=[
             ImageMedia(image=URL(uri=f"file://{path}")),
@@ -127,7 +136,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Llama3.2-11B-Vision-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
     )
@@ -135,11 +144,18 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
     log.print()
 
 
-def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
+def main(
+    host: str,
+    port: int,
+    stream: bool = True,
+    mm: bool = False,
+    file: Optional[str] = None,
+    model: Optional[str] = None,
+):
     if mm:
-        asyncio.run(run_mm_main(host, port, stream, file))
+        asyncio.run(run_mm_main(host, port, stream, file, model))
     else:
-        asyncio.run(run_main(host, port, stream))
+        asyncio.run(run_main(host, port, stream, model))
 
 
 if __name__ == "__main__":
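
Since the script imports fire (presumably wiring main through fire.Fire), the new model parameter should also surface as a --model flag on the command line, falling back to the per-mode default when omitted. A minimal sketch of the equivalent programmatic call, assuming run_main can be imported from the client module (the import path and model name below are illustrative, not taken from the repository):

import asyncio

from client import run_main  # hypothetical import path; adjust to the real module

# Point the client at a locally served endpoint and override the default
# Llama3.1-8B-Instruct with another model identifier.
asyncio.run(run_main("localhost", 5000, stream=True, model="Llama3.1-70B-Instruct"))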