# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
from typing import AsyncGenerator

import fire
import httpx
from termcolor import cprint

from .api import (
    ChatCompletionRequest,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    Inference,
    InstructModel,
    UserMessage,
)
from .event_logger import EventLogger

# Prefix that marks a payload line in the streamed response. The space after
# the colon is treated as optional, so it is stripped separately.
_DATA_PREFIX = "data:"


class InferenceClient(Inference):
    """`Inference` implementation that talks to a remote server over HTTP."""

    def __init__(self, base_url: str):
        """Remember the server root URL, e.g. ``http://host:port``."""
        print(f"Initializing client for {base_url}")
        self.base_url = base_url

    async def initialize(self) -> None:
        """No-op; the client holds no resources that need setting up."""
        pass

    async def shutdown(self) -> None:
        """No-op; each request opens and closes its own HTTP client."""
        pass

    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Not supported by this client.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        """Stream a chat completion from the server.

        POSTs the JSON-serialized request to
        ``{base_url}/inference/chat_completion`` and reads the response line
        by line. Every line starting with ``data:`` is decoded as JSON and
        yielded as a `ChatCompletionResponseStreamChunk`; other lines are
        ignored.

        A payload that fails to decode or validate is printed together with
        the error and skipped, so one bad chunk does not end the stream.
        """
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                f"{self.base_url}/inference/chat_completion",
                data=request.json(),
                headers={"Content-Type": "application/json"},
                timeout=20,
            ) as response:
                async for line in response.aiter_lines():
                    if not line.startswith(_DATA_PREFIX):
                        continue

                    # Slice on exactly the prefix that was tested, then drop
                    # the optional whitespace; slicing len("data: ") would
                    # eat the first payload character of "data:{...}".
                    data = line[len(_DATA_PREFIX) :].strip()

                    # Only parsing/validation is guarded. The yield stays
                    # outside the try so exceptions raised at the yield point
                    # are not mistaken for a malformed chunk.
                    try:
                        chunk = ChatCompletionResponseStreamChunk(**json.loads(data))
                    except Exception as e:
                        # Deliberately best-effort: report and keep streaming.
                        print(data)
                        print(f"Error with parsing or validation: {e}")
                        continue

                    yield chunk


async def run_main(host: str, port: int):
    """Send one fixed user message to the server and print the streamed reply.

    Args:
        host: Hostname of the inference server.
        port: Port of the inference server.
    """
    client = InferenceClient(f"http://{host}:{port}")

    message = UserMessage(content="hello world, help me out here")
    cprint(f"User>{message.content}", "green")

    iterator = client.chat_completion(
        ChatCompletionRequest(
            model=InstructModel.llama3_8b_chat,
            messages=[message],
            stream=True,
        )
    )
    async for log in EventLogger().log(iterator):
        log.print()


def main(host: str, port: int):
    """Synchronous entry point: run `run_main` on a fresh event loop."""
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    # Expose `main` as a command-line interface.
    fire.Fire(main)