drop custom classes to manage hydra

Hardik Shah 2024-07-22 20:40:50 -07:00
parent 86fff23a9e
commit aca6bfe0df
4 changed files with 17 additions and 82 deletions

View file

@@ -1,9 +1,14 @@
 inference_config:
-  impl_type: "inline"
-  inline_config:
-    checkpoint_type: "pytorch"
-    checkpoint_dir: {checkpoint_dir}/
-    tokenizer_path: {checkpoint_dir}/tokenizer.model
-    model_parallel_size: {model_parallel_size}
+  impl_config:
+    impl_type: "inline"
+    checkpoint_config:
+      checkpoint:
+        checkpoint_type: "pytorch"
+        checkpoint_dir: {checkpoint_dir}/
+        tokenizer_path: {checkpoint_dir}/tokenizer.model
+        model_parallel_size: {model_parallel_size}
+        quantization_format: bf16
+    quantization: null
+    torch_seed: null
     max_seq_len: 2048
     max_batch_size: 1
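
For orientation (not part of this commit): the new nesting mirrors the Pydantic config models referenced later in this diff, where impl_config maps to InlineImplConfig, which wraps a ModelCheckpointConfig holding a PytorchCheckpoint. A minimal sketch with stand-in pydantic models, assuming only the field names visible above; the real classes live in llama_toolchain and may differ:

from typing import Optional

from pydantic import BaseModel


# Stand-in models that mirror the field names in the new YAML above.
# Illustrative only; not the repo's actual classes.
class PytorchCheckpoint(BaseModel):
    checkpoint_type: str
    checkpoint_dir: str
    tokenizer_path: str
    model_parallel_size: int
    quantization_format: str = "bf16"


class ModelCheckpointConfig(BaseModel):
    checkpoint: PytorchCheckpoint


class InlineImplConfig(BaseModel):
    impl_type: str
    checkpoint_config: ModelCheckpointConfig
    quantization: Optional[dict] = None
    torch_seed: Optional[int] = None
    max_seq_len: int = 2048
    max_batch_size: int = 1


# The "impl_config" mapping from the YAML (with {checkpoint_dir} and
# {model_parallel_size} already substituted) validates directly:
impl = InlineImplConfig(
    impl_type="inline",
    checkpoint_config={
        "checkpoint": {
            "checkpoint_type": "pytorch",
            "checkpoint_dir": "/tmp/ckpt/",
            "tokenizer_path": "/tmp/ckpt/tokenizer.model",
            "model_parallel_size": 1,
        }
    },
    max_seq_len=2048,
)
print(impl.checkpoint_config.checkpoint.tokenizer_path)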

View file

@@ -8,6 +8,7 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Literal, Optional, Union
+from hydra_zen import builds
 from hydra.core.config_store import ConfigStore
 from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
@@ -78,78 +79,7 @@ class InferenceConfig(BaseModel):
     ]
-# Hydra does not like unions of containers and
-# Pydantic does not like Literals
-# Adding a simple dataclass with custom conversion
-# to config classes
-@dataclass
-class InlineImplHydraConfig:
-    checkpoint_type: str  # "pytorch" / "HF"
-    # pytorch checkpoint required args
-    checkpoint_dir: str
-    tokenizer_path: str
-    model_parallel_size: int
-    max_seq_len: int
-    max_batch_size: int = 1
-    quantization: Optional[QuantizationConfig] = None
-    # TODO: huggingface checkpoint required args
-
-    def convert_to_inline_impl_config(self):
-        if self.checkpoint_type == "pytorch":
-            return InlineImplConfig(
-                checkpoint_config=ModelCheckpointConfig(
-                    checkpoint=PytorchCheckpoint(
-                        checkpoint_type=CheckpointType.pytorch.value,
-                        checkpoint_dir=self.checkpoint_dir,
-                        tokenizer_path=self.tokenizer_path,
-                        model_parallel_size=self.model_parallel_size,
-                    )
-                ),
-                quantization=self.quantization,
-                max_seq_len=self.max_seq_len,
-                max_batch_size=self.max_batch_size,
-            )
-        else:
-            raise NotImplementedError("HF Checkpoint not supported yet")
-
-
-@dataclass
-class RemoteImplHydraConfig:
-    url: str
-
-    def convert_to_remote_impl_config(self):
-        return RemoteImplConfig(
-            url=self.url,
-        )
-
-
-@dataclass
-class InferenceHydraConfig:
-    impl_type: str
-    inline_config: Optional[InlineImplHydraConfig] = None
-    remote_config: Optional[RemoteImplHydraConfig] = None
-
-    def __post_init__(self):
-        assert self.impl_type in ["inline", "remote"]
-        if self.impl_type == "inline":
-            assert self.inline_config is not None
-        if self.impl_type == "remote":
-            assert self.remote_config is not None
-
-    def convert_to_inference_config(self):
-        if self.impl_type == "inline":
-            inline_config = InlineImplHydraConfig(**self.inline_config)
-            return InferenceConfig(
-                impl_config=inline_config.convert_to_inline_impl_config()
-            )
-        elif self.impl_type == "remote":
-            remote_config = RemoteImplHydraConfig(**self.remote_config)
-            return InferenceConfig(
-                impl_config=remote_config.convert_to_remote_impl_config()
-            )
+InferenceHydraConfig = builds(InferenceConfig)
 
 cs = ConfigStore.instance()
 cs.store(name="inference_config", node=InferenceHydraConfig)
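
The removed comment block captures the original problem: Hydra's structured configs do not accept unions of containers, and the Pydantic models use Literal discriminators, so every Pydantic config class had a hand-maintained *HydraConfig dataclass plus a convert_to_*() method. The replacement leans on hydra_zen instead: builds() generates a Hydra-compatible dataclass for a target type, and instantiate() turns a config node back into the real object at runtime. A self-contained sketch of that pattern with a toy target class; the commit itself calls builds(InferenceConfig) bare, so populate_full_signature here is an assumption for illustration:

from dataclasses import dataclass

from hydra.core.config_store import ConfigStore
from hydra_zen import builds, instantiate


@dataclass
class ServerConfig:  # toy stand-in for InferenceConfig
    host: str = "localhost"
    port: int = 5000


# Auto-generate a structured config whose _target_ points at ServerConfig
# and whose fields mirror its signature.
ServerHydraConfig = builds(ServerConfig, populate_full_signature=True)

cs = ConfigStore.instance()
cs.store(name="server_config", node=ServerHydraConfig)

# instantiate() reverses builds(), replacing hand-written convert_to_*() methods.
cfg = ServerHydraConfig(port=8000)
server_config = instantiate(cfg)
assert isinstance(server_config, ServerConfig) and server_config.port == 8000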

View file

@@ -6,6 +6,7 @@
 import asyncio
 import json
+from termcolor import cprint
 from typing import AsyncGenerator
 from urllib.request import getproxies
@@ -65,6 +66,7 @@ async def run_main(host: str, port: int):
     client = InferenceClient(f"http://{host}:{port}")
     message = UserMessage(content="hello world, help me out here")
+    cprint(f"User>{message.content}", "green")
     req = ChatCompletionRequest(
         model=InstructModel.llama3_70b_chat,
         messages=[message],
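
The only functional change to the client is a colored echo of the outgoing user message. termcolor's cprint(text, color) prints the text in the given color, so the added line is equivalent to this standalone snippet (message text copied from the example request above):

from termcolor import cprint

# Green "User>" echo, matching the line added in run_main() above.
cprint("User>hello world, help me out here", "green")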

View file

@@ -14,6 +14,7 @@ from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import StreamingResponse
+from hydra_zen import instantiate
 from omegaconf import OmegaConf
 from llama_toolchain.utils import get_default_config_dir, parse_config
@@ -49,11 +50,8 @@ async def startup():
     global InferenceApiInstance
     config = get_config()
-    hydra_config = InferenceHydraConfig(
-        **OmegaConf.to_container(config["inference_config"], resolve=True)
-    )
-    inference_config = hydra_config.convert_to_inference_config()
+    inference_config = instantiate(config["inference_config"])
     InferenceApiInstance = await get_inference_api_instance(
         inference_config,
     )
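
With the custom classes gone, the server no longer converts an OmegaConf container into InferenceHydraConfig and then into InferenceConfig by hand; a single instantiate() call does the construction, driven by the _target_ that a builds()-generated node carries. A rough sketch of the mechanism with a toy target class; get_config() is stood in for by an inline OmegaConf tree, and the _target_ path is illustrative:

from dataclasses import dataclass

from hydra_zen import instantiate
from omegaconf import OmegaConf


@dataclass
class ToyInferenceConfig:  # toy stand-in for InferenceConfig
    max_seq_len: int = 2048
    max_batch_size: int = 1


# Stand-in for config = get_config(): the "inference_config" node carries a
# _target_, just as the builds()-generated node registered earlier would.
config = OmegaConf.create({
    "inference_config": {
        "_target_": "__main__.ToyInferenceConfig",
        "max_seq_len": 4096,
    }
})

# New startup path: one call replaces the old to_container + convert chain.
inference_config = instantiate(config["inference_config"])
print(inference_config)  # ToyInferenceConfig(max_seq_len=4096, max_batch_size=1)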