llama-stack-mirror/llama_stack/apis/inference/client.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
from typing import Any, AsyncGenerator, List, Optional

import fire
import httpx
from llama_models.llama3.api.datatypes import ImageMedia, URL
from PIL import Image as PIL_Image
from pydantic import BaseModel

from llama_models.llama3.api import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403

from termcolor import cprint

from llama_stack.distribution.datatypes import RemoteProviderConfig

from .event_logger import EventLogger


async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
    return InferenceClient(config.url)
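

# Round-trip a pydantic model through its JSON encoder so the result is a plain,
# JSON-serializable dict suitable for use as an httpx request body.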
def encodable_dict(d: BaseModel):
    return json.loads(d.json())
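

# HTTP client implementation of the Inference API: requests are forwarded to a
# remote inference server reachable at `base_url`.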
class InferenceClient(Inference):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
        raise NotImplementedError()
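
    # Build a ChatCompletionRequest from the arguments and POST it to the server's
    # /inference/chat_completion endpoint, yielding parsed response objects.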
    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                f"{self.base_url}/inference/chat_completion",
                json=encodable_dict(request),
                headers={"Content-Type": "application/json"},
                timeout=20,
            ) as response:
                if response.status_code != 200:
                    content = await response.aread()
                    cprint(
                        f"Error: HTTP {response.status_code} {content.decode()}", "red"
                    )
                    return
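
                # Responses arrive as server-sent-event style lines ("data: {...}");
                # each payload is a JSON-encoded response object.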
                async for line in response.aiter_lines():
                    if line.startswith("data:"):
                        data = line[len("data: ") :]
                        try:
                            if request.stream:
                                if "error" in data:
                                    cprint(data, "red")
                                    continue

                                yield ChatCompletionResponseStreamChunk(
                                    **json.loads(data)
                                )
                            else:
                                yield ChatCompletionResponse(**json.loads(data))
                        except Exception as e:
                            print(data)
                            print(f"Error with parsing or validation: {e}")
async def run_main(host: str, port: int, stream: bool):
    client = InferenceClient(f"http://{host}:{port}")

    message = UserMessage(
        content="hello world, write me a 2 sentence poem about the moon"
    )
    cprint(f"User>{message.content}", "green")
    iterator = client.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[message],
        stream=stream,
    )
    async for log in EventLogger().log(iterator):
        log.print()
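

# Multimodal variant: the image is referenced by a file:// URL; passing the
# decoded PIL image inline is left commented out below.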
async def run_mm_main(host: str, port: int, stream: bool, path: str):
    client = InferenceClient(f"http://{host}:{port}")

    with open(path, "rb") as f:
        img = PIL_Image.open(f).convert("RGB")

    message = UserMessage(
        content=[
            ImageMedia(image=URL(uri=f"file://{path}")),
            # ImageMedia(image=img),
            "Describe this image in two sentences",
        ],
    )
    cprint(f"User>{message.content}", "green")
    iterator = client.chat_completion(
        model="Llama3.2-11B-Vision-Instruct",
        messages=[message],
        stream=stream,
    )
    async for log in EventLogger().log(iterator):
        log.print()
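

# fire exposes `main` as a CLI; `--mm True` switches to the multimodal flow and
# requires `--file` to point at a local image.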
def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
    if mm:
        asyncio.run(run_mm_main(host, port, stream, file))
    else:
        asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)
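
# Example invocations (a sketch, assuming an inference server is listening on
# localhost:5000; adjust host/port to your setup):
#
#   python -m llama_stack.apis.inference.client localhost 5000
#   python -m llama_stack.apis.inference.client localhost 5000 --stream False
#   python -m llama_stack.apis.inference.client localhost 5000 --mm True --file /path/to/image.jpg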