Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
drop custom classes to manage hydra
This commit is contained in:
parent 86fff23a9e
commit aca6bfe0df
4 changed files with 17 additions and 82 deletions

@@ -1,9 +1,14 @@
 inference_config:
-  impl_type: "inline"
-  inline_config:
-    checkpoint_type: "pytorch"
-    checkpoint_dir: {checkpoint_dir}/
-    tokenizer_path: {checkpoint_dir}/tokenizer.model
-    model_parallel_size: {model_parallel_size}
+  impl_config:
+    impl_type: "inline"
+    checkpoint_config:
+      checkpoint:
+        checkpoint_type: "pytorch"
+        checkpoint_dir: {checkpoint_dir}/
+        tokenizer_path: {checkpoint_dir}/tokenizer.model
+        model_parallel_size: {model_parallel_size}
+        quantization_format: bf16
+    quantization: null
+    torch_seed: null
     max_seq_len: 2048
     max_batch_size: 1

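The extra nesting in the new YAML exists because the file now mirrors the Pydantic model tree (InferenceConfig, InlineImplConfig, ModelCheckpointConfig, PytorchCheckpoint) one-to-one, which is what lets hydra_zen consume it without a translation layer. A minimal sketch of that idea follows; the classes are reduced, hypothetical stand-ins for those models (field lists and paths are illustrative only), and PyYAML is assumed to be available.

# Hypothetical, reduced stand-ins (not the repo's full field lists) showing why the
# new YAML mirrors the Pydantic model tree: each mapping level corresponds to one
# nested model, so no flat "hydra" dataclass is needed to translate it.
from pydantic import BaseModel
import yaml  # assumes PyYAML is installed


class PytorchCheckpoint(BaseModel):
    checkpoint_type: str
    checkpoint_dir: str
    tokenizer_path: str
    model_parallel_size: int


class ModelCheckpointConfig(BaseModel):
    checkpoint: PytorchCheckpoint


class InlineImplConfig(BaseModel):
    impl_type: str
    checkpoint_config: ModelCheckpointConfig
    max_seq_len: int = 2048
    max_batch_size: int = 1


raw = yaml.safe_load("""
impl_type: inline
checkpoint_config:
  checkpoint:
    checkpoint_type: pytorch
    checkpoint_dir: /tmp/ckpt/
    tokenizer_path: /tmp/ckpt/tokenizer.model
    model_parallel_size: 1
max_seq_len: 2048
""")

cfg = InlineImplConfig(**raw)  # Pydantic validates the whole nested tree in one call
print(cfg.checkpoint_config.checkpoint.checkpoint_dir)
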
@@ -8,6 +8,7 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Literal, Optional, Union
 
+from hydra_zen import builds
 from hydra.core.config_store import ConfigStore
 from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
 
@@ -78,78 +79,7 @@ class InferenceConfig(BaseModel):
     ]
 
 
-# Hydra does not like unions of containers and
-# Pydantic does not like Literals
-# Adding a simple dataclass with custom coversion
-# to config classes
-
-
-@dataclass
-class InlineImplHydraConfig:
-    checkpoint_type: str  # "pytorch" / "HF"
-    # pytorch checkpoint required args
-    checkpoint_dir: str
-    tokenizer_path: str
-    model_parallel_size: int
-    max_seq_len: int
-    max_batch_size: int = 1
-    quantization: Optional[QuantizationConfig] = None
-    # TODO: huggingface checkpoint required args
-
-    def convert_to_inline_impl_config(self):
-        if self.checkpoint_type == "pytorch":
-            return InlineImplConfig(
-                checkpoint_config=ModelCheckpointConfig(
-                    checkpoint=PytorchCheckpoint(
-                        checkpoint_type=CheckpointType.pytorch.value,
-                        checkpoint_dir=self.checkpoint_dir,
-                        tokenizer_path=self.tokenizer_path,
-                        model_parallel_size=self.model_parallel_size,
-                    )
-                ),
-                quantization=self.quantization,
-                max_seq_len=self.max_seq_len,
-                max_batch_size=self.max_batch_size,
-            )
-        else:
-            raise NotImplementedError("HF Checkpoint not supported yet")
-
-
-@dataclass
-class RemoteImplHydraConfig:
-    url: str
-
-    def convert_to_remote_impl_config(self):
-        return RemoteImplConfig(
-            url=self.url,
-        )
-
-
-@dataclass
-class InferenceHydraConfig:
-    impl_type: str
-    inline_config: Optional[InlineImplHydraConfig] = None
-    remote_config: Optional[RemoteImplHydraConfig] = None
-
-    def __post_init__(self):
-        assert self.impl_type in ["inline", "remote"]
-        if self.impl_type == "inline":
-            assert self.inline_config is not None
-        if self.impl_type == "remote":
-            assert self.remote_config is not None
-
-    def convert_to_inference_config(self):
-        if self.impl_type == "inline":
-            inline_config = InlineImplHydraConfig(**self.inline_config)
-            return InferenceConfig(
-                impl_config=inline_config.convert_to_inline_impl_config()
-            )
-        elif self.impl_type == "remote":
-            remote_config = RemoteImplHydraConfig(**self.remote_config)
-            return InferenceConfig(
-                impl_config=remote_config.convert_to_remote_impl_config()
-            )
-
+InferenceHydraConfig = builds(InferenceConfig)
 
 cs = ConfigStore.instance()
 cs.store(name="inference_config", node=InferenceHydraConfig)

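The one-liner that replaces all of the removed dataclasses relies on hydra_zen's builds()/instantiate() pair: builds() emits a structured (dataclass-based) config whose _target_ points back at the class, ConfigStore can hold that node, and instantiate() later reconstructs the object, so no hand-written conversion code is needed. A minimal sketch under those assumptions, with ExampleConfig as a hypothetical stand-in for InferenceConfig:

# A minimal sketch of the builds()/instantiate() round trip (assumes hydra-zen and
# pydantic are installed; ExampleConfig is a hypothetical stand-in, not the repo's
# InferenceConfig).
from hydra.core.config_store import ConfigStore
from hydra_zen import builds, instantiate
from pydantic import BaseModel


class ExampleConfig(BaseModel):
    max_seq_len: int = 2048
    max_batch_size: int = 1


# builds() emits a dataclass-based structured config whose _target_ points back at
# ExampleConfig and whose fields are the keyword arguments given here.
ExampleHydraConfig = builds(ExampleConfig, max_seq_len=2048, max_batch_size=1)

cs = ConfigStore.instance()
cs.store(name="example_config", node=ExampleHydraConfig)

# instantiate() reads _target_ and calls ExampleConfig(**fields).
cfg_obj = instantiate(ExampleHydraConfig)
print(cfg_obj)  # max_seq_len=2048 max_batch_size=1
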
@@ -6,6 +6,7 @@
 
 import asyncio
 import json
+from termcolor import cprint
 from typing import AsyncGenerator
 
 from urllib.request import getproxies
@@ -65,6 +66,7 @@ async def run_main(host: str, port: int):
     client = InferenceClient(f"http://{host}:{port}")
 
     message = UserMessage(content="hello world, help me out here")
+    cprint(f"User>{message.content}", "green")
     req = ChatCompletionRequest(
         model=InstructModel.llama3_70b_chat,
         messages=[message],

@@ -14,6 +14,7 @@ from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import StreamingResponse
 
+from hydra_zen import instantiate
 from omegaconf import OmegaConf
 
 from llama_toolchain.utils import get_default_config_dir, parse_config
@@ -49,11 +50,8 @@ async def startup():
     global InferenceApiInstance
 
     config = get_config()
-    hydra_config = InferenceHydraConfig(
-        **OmegaConf.to_container(config["inference_config"], resolve=True)
-    )
-    inference_config = hydra_config.convert_to_inference_config()
 
+    inference_config = instantiate(config["inference_config"])
     InferenceApiInstance = await get_inference_api_instance(
         inference_config,
     )

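On the server side, instantiate() now does the work the removed convert_to_inference_config() used to do: it reads _target_ from the merged OmegaConf node and calls the target class with the remaining fields. A rough sketch of that flow, again with a hypothetical ExampleConfig in place of InferenceConfig and the YAML values inlined rather than loaded from the real config file:

# Sketch only: how an OmegaConf node carrying a _target_ becomes a config object via
# hydra_zen.instantiate(). ExampleConfig is a hypothetical stand-in for InferenceConfig;
# in the real server the node comes from the parsed YAML file rather than being built here.
from hydra_zen import builds, instantiate
from omegaconf import OmegaConf
from pydantic import BaseModel


class ExampleConfig(BaseModel):
    max_seq_len: int = 2048
    max_batch_size: int = 1


# Roughly what the stored node looks like after builds(); Hydra merges the YAML
# file's values on top of it before the app code ever sees it.
node = OmegaConf.structured(builds(ExampleConfig, max_seq_len=2048, max_batch_size=1))
merged = OmegaConf.merge(node, {"max_seq_len": 4096})

inference_config = instantiate(merged)  # ExampleConfig(max_seq_len=4096, max_batch_size=1)
print(inference_config)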