mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
inference: Add model option to client (#170)
I was running this client for testing purposes and being able to specify which model to use is a convenient addition. This change makes that possible.
This commit is contained in:
parent
210b71b0ba
commit
06db9213b1
1 changed files with 23 additions and 7 deletions
|
@ -6,6 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, AsyncGenerator, List, Optional
|
||||
|
||||
import fire
|
||||
|
@ -100,15 +101,18 @@ class InferenceClient(Inference):
|
|||
print(f"Error with parsing or validation: {e}")
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
|
||||
client = InferenceClient(f"http://{host}:{port}")
|
||||
|
||||
if not model:
|
||||
model = "Llama3.1-8B-Instruct"
|
||||
|
||||
message = UserMessage(
|
||||
content="hello world, write me a 2 sentence poem about the moon"
|
||||
)
|
||||
cprint(f"User>{message.content}", "green")
|
||||
iterator = client.chat_completion(
|
||||
model="Llama3.1-8B-Instruct",
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
)
|
||||
|
@ -116,9 +120,14 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
log.print()
|
||||
|
||||
|
||||
async def run_mm_main(host: str, port: int, stream: bool, path: str):
|
||||
async def run_mm_main(
|
||||
host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
|
||||
):
|
||||
client = InferenceClient(f"http://{host}:{port}")
|
||||
|
||||
if not model:
|
||||
model = "Llama3.2-11B-Vision-Instruct"
|
||||
|
||||
message = UserMessage(
|
||||
content=[
|
||||
ImageMedia(image=URL(uri=f"file://{path}")),
|
||||
|
@ -127,7 +136,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
|
|||
)
|
||||
cprint(f"User>{message.content}", "green")
|
||||
iterator = client.chat_completion(
|
||||
model="Llama3.2-11B-Vision-Instruct",
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
)
|
||||
|
@ -135,11 +144,18 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
|
|||
log.print()
|
||||
|
||||
|
||||
def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
|
||||
def main(
|
||||
host: str,
|
||||
port: int,
|
||||
stream: bool = True,
|
||||
mm: bool = False,
|
||||
file: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
if mm:
|
||||
asyncio.run(run_mm_main(host, port, stream, file))
|
||||
asyncio.run(run_mm_main(host, port, stream, file, model))
|
||||
else:
|
||||
asyncio.run(run_main(host, port, stream))
|
||||
asyncio.run(run_main(host, port, stream, model))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue