mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
rename toolchain/ --> llama_toolchain/
This commit is contained in:
parent
d95f5f863d
commit
f9111652ef
73 changed files with 36 additions and 37 deletions
159
llama_toolchain/inference/api/config.py
Normal file
159
llama_toolchain/inference/api/config.py
Normal file
|
@ -0,0 +1,159 @@
|
|||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from hydra.core.config_store import ConfigStore
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .datatypes import QuantizationConfig
|
||||
|
||||
|
||||
class ImplType(Enum):
|
||||
inline = "inline"
|
||||
remote = "remote"
|
||||
|
||||
|
||||
class CheckpointType(Enum):
|
||||
pytorch = "pytorch"
|
||||
huggingface = "huggingface"
|
||||
|
||||
|
||||
# This enum represents the format in which weights are specified
|
||||
# This does not necessarily always equal what quantization is desired
|
||||
# at runtime since there can be on-the-fly conversions done
|
||||
class CheckpointQuantizationFormat(Enum):
|
||||
# default format
|
||||
bf16 = "bf16"
|
||||
|
||||
# used for enabling fp8_rowwise inference, some weights are bf16
|
||||
fp8_mixed = "fp8_mixed"
|
||||
|
||||
|
||||
class PytorchCheckpoint(BaseModel):
|
||||
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
|
||||
CheckpointType.pytorch.value
|
||||
)
|
||||
checkpoint_dir: str
|
||||
tokenizer_path: str
|
||||
model_parallel_size: int
|
||||
quantization_format: CheckpointQuantizationFormat = (
|
||||
CheckpointQuantizationFormat.bf16
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceCheckpoint(BaseModel):
|
||||
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
|
||||
CheckpointType.huggingface.value
|
||||
)
|
||||
repo_id: str # or model_name ?
|
||||
model_parallel_size: int
|
||||
quantization_format: CheckpointQuantizationFormat = (
|
||||
CheckpointQuantizationFormat.bf16
|
||||
)
|
||||
|
||||
|
||||
class ModelCheckpointConfig(BaseModel):
|
||||
checkpoint: Annotated[
|
||||
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
|
||||
Field(discriminator="checkpoint_type"),
|
||||
]
|
||||
|
||||
|
||||
class InlineImplConfig(BaseModel):
|
||||
impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
|
||||
checkpoint_config: ModelCheckpointConfig
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int
|
||||
max_batch_size: int = 1
|
||||
|
||||
|
||||
class RemoteImplConfig(BaseModel):
|
||||
impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
|
||||
url: str = Field(..., description="The URL of the remote module")
|
||||
|
||||
|
||||
class InferenceConfig(BaseModel):
|
||||
impl_config: Annotated[
|
||||
Union[InlineImplConfig, RemoteImplConfig],
|
||||
Field(discriminator="impl_type"),
|
||||
]
|
||||
|
||||
|
||||
# Hydra does not like unions of containers and
|
||||
# Pydantic does not like Literals
|
||||
# Adding a simple dataclass with custom coversion
|
||||
# to config classes
|
||||
|
||||
|
||||
@dataclass
|
||||
class InlineImplHydraConfig:
|
||||
checkpoint_type: str # "pytorch" / "HF"
|
||||
# pytorch checkpoint required args
|
||||
checkpoint_dir: str
|
||||
tokenizer_path: str
|
||||
model_parallel_size: int
|
||||
max_seq_len: int
|
||||
max_batch_size: int = 1
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
# TODO: huggingface checkpoint required args
|
||||
|
||||
def convert_to_inline_impl_config(self):
|
||||
if self.checkpoint_type == "pytorch":
|
||||
return InlineImplConfig(
|
||||
checkpoint_config=ModelCheckpointConfig(
|
||||
checkpoint=PytorchCheckpoint(
|
||||
checkpoint_type=CheckpointType.pytorch.value,
|
||||
checkpoint_dir=self.checkpoint_dir,
|
||||
tokenizer_path=self.tokenizer_path,
|
||||
model_parallel_size=self.model_parallel_size,
|
||||
)
|
||||
),
|
||||
quantization=self.quantization,
|
||||
max_seq_len=self.max_seq_len,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("HF Checkpoint not supported yet")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteImplHydraConfig:
|
||||
url: str
|
||||
|
||||
def convert_to_remote_impl_config(self):
|
||||
return RemoteImplConfig(
|
||||
url=self.url,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceHydraConfig:
|
||||
impl_type: str
|
||||
inline_config: Optional[InlineImplHydraConfig] = None
|
||||
remote_config: Optional[RemoteImplHydraConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.impl_type in ["inline", "remote"]
|
||||
if self.impl_type == "inline":
|
||||
assert self.inline_config is not None
|
||||
if self.impl_type == "remote":
|
||||
assert self.remote_config is not None
|
||||
|
||||
def convert_to_inference_config(self):
|
||||
if self.impl_type == "inline":
|
||||
inline_config = InlineImplHydraConfig(**self.inline_config)
|
||||
return InferenceConfig(
|
||||
impl_config=inline_config.convert_to_inline_impl_config()
|
||||
)
|
||||
elif self.impl_type == "remote":
|
||||
remote_config = RemoteImplHydraConfig(**self.remote_config)
|
||||
return InferenceConfig(
|
||||
impl_config=remote_config.convert_to_remote_impl_config()
|
||||
)
|
||||
|
||||
|
||||
cs = ConfigStore.instance()
|
||||
cs.store(name="inference_config", node=InferenceHydraConfig)
|
Loading…
Add table
Add a link
Reference in a new issue