diff --git a/llama_toolchain/data/default_inference_config.yaml b/llama_toolchain/data/default_inference_config.yaml
index 253e0e143..29e40c9b8 100644
--- a/llama_toolchain/data/default_inference_config.yaml
+++ b/llama_toolchain/data/default_inference_config.yaml
@@ -1,9 +1,14 @@
 inference_config:
-  impl_type: "inline"
-  inline_config:
-    checkpoint_type: "pytorch"
-    checkpoint_dir: {checkpoint_dir}/
-    tokenizer_path: {checkpoint_dir}/tokenizer.model
-    model_parallel_size: {model_parallel_size}
+  impl_config:
+    impl_type: "inline"
+    checkpoint_config:
+      checkpoint:
+        checkpoint_type: "pytorch"
+        checkpoint_dir: {checkpoint_dir}/
+        tokenizer_path: {checkpoint_dir}/tokenizer.model
+        model_parallel_size: {model_parallel_size}
+        quantization_format: bf16
+    quantization: null
+    torch_seed: null
     max_seq_len: 2048
     max_batch_size: 1
diff --git a/llama_toolchain/inference/api/config.py b/llama_toolchain/inference/api/config.py
index 538345e1f..bf069c0f2 100644
--- a/llama_toolchain/inference/api/config.py
+++ b/llama_toolchain/inference/api/config.py
@@ -8,6 +8,7 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Literal, Optional, Union
 
+from hydra_zen import builds
 from hydra.core.config_store import ConfigStore
 
 from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
@@ -78,78 +79,7 @@ class InferenceConfig(BaseModel):
     ]
 
 
-# Hydra does not like unions of containers and
-# Pydantic does not like Literals
-# Adding a simple dataclass with custom coversion
-# to config classes
-
-
-@dataclass
-class InlineImplHydraConfig:
-    checkpoint_type: str  # "pytorch" / "HF"
-    # pytorch checkpoint required args
-    checkpoint_dir: str
-    tokenizer_path: str
-    model_parallel_size: int
-    max_seq_len: int
-    max_batch_size: int = 1
-    quantization: Optional[QuantizationConfig] = None
-    # TODO: huggingface checkpoint required args
-
-    def convert_to_inline_impl_config(self):
-        if self.checkpoint_type == "pytorch":
-            return InlineImplConfig(
-                checkpoint_config=ModelCheckpointConfig(
-                    checkpoint=PytorchCheckpoint(
-                        checkpoint_type=CheckpointType.pytorch.value,
-                        checkpoint_dir=self.checkpoint_dir,
-                        tokenizer_path=self.tokenizer_path,
-                        model_parallel_size=self.model_parallel_size,
-                    )
-                ),
-                quantization=self.quantization,
-                max_seq_len=self.max_seq_len,
-                max_batch_size=self.max_batch_size,
-            )
-        else:
-            raise NotImplementedError("HF Checkpoint not supported yet")
-
-
-@dataclass
-class RemoteImplHydraConfig:
-    url: str
-
-    def convert_to_remote_impl_config(self):
-        return RemoteImplConfig(
-            url=self.url,
-        )
-
-
-@dataclass
-class InferenceHydraConfig:
-    impl_type: str
-    inline_config: Optional[InlineImplHydraConfig] = None
-    remote_config: Optional[RemoteImplHydraConfig] = None
-
-    def __post_init__(self):
-        assert self.impl_type in ["inline", "remote"]
-        if self.impl_type == "inline":
-            assert self.inline_config is not None
-        if self.impl_type == "remote":
-            assert self.remote_config is not None
-
-    def convert_to_inference_config(self):
-        if self.impl_type == "inline":
-            inline_config = InlineImplHydraConfig(**self.inline_config)
-            return InferenceConfig(
-                impl_config=inline_config.convert_to_inline_impl_config()
-            )
-        elif self.impl_type == "remote":
-            remote_config = RemoteImplHydraConfig(**self.remote_config)
-            return InferenceConfig(
-                impl_config=remote_config.convert_to_remote_impl_config()
-            )
-
+InferenceHydraConfig = builds(InferenceConfig)
 
 cs = ConfigStore.instance()
 cs.store(name="inference_config", node=InferenceHydraConfig)
diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index 032e4f477..824274965 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -6,6 +6,7 @@
 
 import asyncio
 import json
+from termcolor import cprint
 from typing import AsyncGenerator
 from urllib.request import getproxies
 
@@ -65,6 +66,7 @@ async def run_main(host: str, port: int):
     client = InferenceClient(f"http://{host}:{port}")
 
     message = UserMessage(content="hello world, help me out here")
+    cprint(f"User>{message.content}", "green")
     req = ChatCompletionRequest(
         model=InstructModel.llama3_70b_chat,
         messages=[message],
diff --git a/llama_toolchain/inference/server.py b/llama_toolchain/inference/server.py
index c0bdfca03..f5790e74b 100644
--- a/llama_toolchain/inference/server.py
+++ b/llama_toolchain/inference/server.py
@@ -14,6 +14,7 @@ from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import StreamingResponse
+from hydra_zen import instantiate
 from omegaconf import OmegaConf
 
 from llama_toolchain.utils import get_default_config_dir, parse_config
@@ -49,11 +50,8 @@ async def startup():
     global InferenceApiInstance
 
     config = get_config()
-    hydra_config = InferenceHydraConfig(
-        **OmegaConf.to_container(config["inference_config"], resolve=True)
-    )
-    inference_config = hydra_config.convert_to_inference_config()
+    inference_config = instantiate(config["inference_config"])
     InferenceApiInstance = await get_inference_api_instance(
         inference_config,
     )
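
Note (not part of the diff): a minimal sketch of the hydra_zen builds/instantiate round trip this change relies on, so the deleted conversion code's replacement is easier to follow. ServerConfig and its fields are hypothetical stand-ins, not classes from this PR; only builds() and instantiate() are real hydra_zen helpers.

# Sketch only; ServerConfig is a made-up example class.
from dataclasses import dataclass

from hydra_zen import builds, instantiate


@dataclass
class ServerConfig:
    host: str = "localhost"
    port: int = 5000


# builds() generates a structured-config dataclass whose _target_ points back
# at ServerConfig, so it can be registered with Hydra's ConfigStore in place of
# a hand-written node like the removed InferenceHydraConfig.
ServerHydraConfig = builds(ServerConfig, populate_full_signature=True)

# instantiate() reads _target_ and calls ServerConfig(**fields), replacing the
# manual convert_to_*_config() plumbing; server.py now does the same thing with
# instantiate(config["inference_config"]).
cfg = instantiate(ServerHydraConfig, port=8080)
assert cfg == ServerConfig(host="localhost", port=8080)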