mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-29 03:14:19 +00:00
inference: Add model option to client (#170)
I was running this client for testing purposes, and being able to specify which model to use is a convenient addition. This change makes that possible.
parent 210b71b0ba
commit 06db9213b1

1 changed file with 23 additions and 7 deletions
@@ -6,6 +6,7 @@
 import asyncio
 import json
+import sys
 from typing import Any, AsyncGenerator, List, Optional
 
 import fire
@@ -100,15 +101,18 @@ class InferenceClient(Inference):
             print(f"Error with parsing or validation: {e}")
 
 
-async def run_main(host: str, port: int, stream: bool):
+async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
     client = InferenceClient(f"http://{host}:{port}")
 
+    if not model:
+        model = "Llama3.1-8B-Instruct"
+
     message = UserMessage(
         content="hello world, write me a 2 sentence poem about the moon"
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Llama3.1-8B-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
     )
@@ -116,9 +120,14 @@ async def run_main(host: str, port: int, stream: bool):
         log.print()
 
 
-async def run_mm_main(host: str, port: int, stream: bool, path: str):
+async def run_mm_main(
+    host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
+):
     client = InferenceClient(f"http://{host}:{port}")
 
+    if not model:
+        model = "Llama3.2-11B-Vision-Instruct"
+
     message = UserMessage(
         content=[
             ImageMedia(image=URL(uri=f"file://{path}")),
@@ -127,7 +136,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Llama3.2-11B-Vision-Instruct",
+        model=model,
         messages=[message],
         stream=stream,
     )
@@ -135,11 +144,18 @@
         log.print()
 
 
-def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
+def main(
+    host: str,
+    port: int,
+    stream: bool = True,
+    mm: bool = False,
+    file: Optional[str] = None,
+    model: Optional[str] = None,
+):
     if mm:
-        asyncio.run(run_mm_main(host, port, stream, file))
+        asyncio.run(run_mm_main(host, port, stream, file, model))
     else:
-        asyncio.run(run_main(host, port, stream))
+        asyncio.run(run_main(host, port, stream, model))
 
 
 if __name__ == "__main__":
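For reference, a minimal sketch of how the new option might be exercised from the command line. It assumes the script's entry point hands main to fire.Fire (the entry point is outside these hunks), that the file is named client.py, and that the host, port, model name, and image path below are placeholder values:

    # text path: override the Llama3.1-8B-Instruct default
    python client.py localhost 5000 --model Llama3.1-70B-Instruct

    # multimodal path: --model omitted, so run_mm_main falls back to its default
    python client.py localhost 5000 --mm True --file /path/to/image.png

When --model is not passed, run_main and run_mm_main default to Llama3.1-8B-Instruct and Llama3.2-11B-Vision-Instruct respectively, preserving the previously hard-coded behaviour.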