inference: Add model option to client (#170)

I was running this client for testing and found that being able to
specify which model to use would be a convenient addition. This change
makes that possible.
Russell Bryant 2024-10-03 14:18:57 -04:00 committed by GitHub
parent 210b71b0ba
commit 06db9213b1
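For context: the script imports python-fire, which by convention maps main()'s keyword parameters onto command-line flags, so the new parameter should surface as a flag along the lines of --model Llama3.1-70B-Instruct (the model name here is illustrative). When the flag is omitted, each entry point falls back to its previous hard-coded default, as the diff below shows.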

@@ -6,6 +6,7 @@
 import asyncio
 import json
 import sys
+from typing import Any, AsyncGenerator, List, Optional
 
 import fire
 import httpx
@@ -100,15 +101,18 @@ class InferenceClient(Inference):
                 print(f"Error with parsing or validation: {e}")
 
 
-async def run_main(host: str, port: int, stream: bool):
+async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
     client = InferenceClient(f"http://{host}:{port}")
 
+    if not model:
+        model = "Llama3.1-8B-Instruct"
+
     message = UserMessage(
         content="hello world, write me a 2 sentence poem about the moon"
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Llama3.1-8B-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
     )
@@ -116,9 +120,14 @@ async def run_main(host: str, port: int, stream: bool):
         log.print()
 
 
-async def run_mm_main(host: str, port: int, stream: bool, path: str):
+async def run_mm_main(
+    host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
+):
     client = InferenceClient(f"http://{host}:{port}")
 
+    if not model:
+        model = "Llama3.2-11B-Vision-Instruct"
+
     message = UserMessage(
         content=[
             ImageMedia(image=URL(uri=f"file://{path}")),
@@ -127,7 +136,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Llama3.2-11B-Vision-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
     )
@@ -135,11 +144,18 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
         log.print()
 
 
-def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
+def main(
+    host: str,
+    port: int,
+    stream: bool = True,
+    mm: bool = False,
+    file: Optional[str] = None,
+    model: Optional[str] = None,
+):
     if mm:
-        asyncio.run(run_mm_main(host, port, stream, file))
+        asyncio.run(run_mm_main(host, port, stream, file, model))
     else:
-        asyncio.run(run_main(host, port, stream))
+        asyncio.run(run_main(host, port, stream, model))
 
 
 if __name__ == "__main__":
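
For reference, a minimal sketch of driving the updated entry point programmatically, bypassing the fire CLI wrapper. The import path, host, port, and model names are illustrative assumptions, not part of this commit:

import asyncio

# The module path is an assumption for illustration; adjust to wherever
# this client script actually lives in the tree.
from client import run_main

# Override the default text model with one served by the same endpoint
# (model name is an assumption).
asyncio.run(run_main("localhost", 5000, stream=True, model="Llama3.1-70B-Instruct"))

# Passing None keeps the old behavior: run_main falls back to
# "Llama3.1-8B-Instruct", and run_mm_main to "Llama3.2-11B-Vision-Instruct".
asyncio.run(run_main("localhost", 5000, stream=True, model=None))

Keeping the fallback inside each entry point preserves the previous defaults when no model is given, while letting the text and multimodal paths share a single parameter.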