drop custom classes to manage hydra

This commit is contained in:
Hardik Shah 2024-07-22 20:40:50 -07:00
parent 86fff23a9e
commit aca6bfe0df
4 changed files with 17 additions and 82 deletions

View file

@ -8,6 +8,7 @@ from dataclasses import dataclass
from enum import Enum
from typing import Literal, Optional, Union
from hydra_zen import builds
from hydra.core.config_store import ConfigStore
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
@ -78,78 +79,7 @@ class InferenceConfig(BaseModel):
]
# Hydra does not like unions of containers and
# Pydantic does not like Literals
# Adding a simple dataclass with custom coversion
# to config classes
@dataclass
class InlineImplHydraConfig:
checkpoint_type: str # "pytorch" / "HF"
# pytorch checkpoint required args
checkpoint_dir: str
tokenizer_path: str
model_parallel_size: int
max_seq_len: int
max_batch_size: int = 1
quantization: Optional[QuantizationConfig] = None
# TODO: huggingface checkpoint required args
def convert_to_inline_impl_config(self):
if self.checkpoint_type == "pytorch":
return InlineImplConfig(
checkpoint_config=ModelCheckpointConfig(
checkpoint=PytorchCheckpoint(
checkpoint_type=CheckpointType.pytorch.value,
checkpoint_dir=self.checkpoint_dir,
tokenizer_path=self.tokenizer_path,
model_parallel_size=self.model_parallel_size,
)
),
quantization=self.quantization,
max_seq_len=self.max_seq_len,
max_batch_size=self.max_batch_size,
)
else:
raise NotImplementedError("HF Checkpoint not supported yet")
@dataclass
class RemoteImplHydraConfig:
url: str
def convert_to_remote_impl_config(self):
return RemoteImplConfig(
url=self.url,
)
@dataclass
class InferenceHydraConfig:
impl_type: str
inline_config: Optional[InlineImplHydraConfig] = None
remote_config: Optional[RemoteImplHydraConfig] = None
def __post_init__(self):
assert self.impl_type in ["inline", "remote"]
if self.impl_type == "inline":
assert self.inline_config is not None
if self.impl_type == "remote":
assert self.remote_config is not None
def convert_to_inference_config(self):
if self.impl_type == "inline":
inline_config = InlineImplHydraConfig(**self.inline_config)
return InferenceConfig(
impl_config=inline_config.convert_to_inline_impl_config()
)
elif self.impl_type == "remote":
remote_config = RemoteImplHydraConfig(**self.remote_config)
return InferenceConfig(
impl_config=remote_config.convert_to_remote_impl_config()
)
InferenceHydraConfig = builds(InferenceConfig)
cs = ConfigStore.instance()
cs.store(name="inference_config", node=InferenceHydraConfig)