diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 92acc3e14..5cfae633c 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -6,6 +6,7 @@
 
 import asyncio
 import json
+import sys
 from typing import Any, AsyncGenerator, List, Optional
 
 import fire
@@ -100,15 +101,18 @@ class InferenceClient(Inference):
                             print(f"Error with parsing or validation: {e}")
 
 
-async def run_main(host: str, port: int, stream: bool):
+async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
     client = InferenceClient(f"http://{host}:{port}")
 
+    if not model:
+        model = "Llama3.1-8B-Instruct"
+
     message = UserMessage(
         content="hello world, write me a 2 sentence poem about the moon"
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Llama3.1-8B-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
     )
@@ -116,9 +120,14 @@ async def run_main(host: str, port: int, stream: bool):
         log.print()
 
 
-async def run_mm_main(host: str, port: int, stream: bool, path: str):
+async def run_mm_main(
+    host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
+):
     client = InferenceClient(f"http://{host}:{port}")
 
+    if not model:
+        model = "Llama3.2-11B-Vision-Instruct"
+
     message = UserMessage(
         content=[
             ImageMedia(image=URL(uri=f"file://{path}")),
@@ -127,7 +136,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Llama3.2-11B-Vision-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
     )
@@ -135,11 +144,18 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
         log.print()
 
 
-def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
+def main(
+    host: str,
+    port: int,
+    stream: bool = True,
+    mm: bool = False,
+    file: Optional[str] = None,
+    model: Optional[str] = None,
+):
     if mm:
-        asyncio.run(run_mm_main(host, port, stream, file))
+        asyncio.run(run_mm_main(host, port, stream, file, model))
     else:
-        asyncio.run(run_main(host, port, stream))
+        asyncio.run(run_main(host, port, stream, model))
 
 
 if __name__ == "__main__":
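
With this change applied, the model can be selected programmatically as well as through the CLI flags that fire exposes. A minimal sketch of exercising the new parameter directly, assuming a Llama Stack server is already listening on localhost:5000; the host, port, and "Llama3.1-70B-Instruct" name below are illustrative, not prescribed by the diff:

    import asyncio

    from llama_stack.apis.inference.client import run_main

    # An explicit model overrides the "Llama3.1-8B-Instruct" default;
    # passing model=None falls back to it via the new `if not model:` guard.
    asyncio.run(run_main("localhost", 5000, stream=True, model="Llama3.1-70B-Instruct"))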