JSON serialization for parallel processing queue (#232)

* send/recv pydantic json over socket

* fixup

* address feedback

* bidirectional wrapper

* second round of feedback
This commit is contained in:
Dalton Flanagan 2024-10-09 17:24:12 -04:00 committed by GitHub
parent 0f66ae0f61
commit 7a8aa775e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 158 additions and 52 deletions

View file

@ -17,17 +17,7 @@ from llama_models.sku_list import resolve_model
from .config import MetaReferenceImplConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import ModelParallelProcessGroup
@dataclass
class InferenceArgs:
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
class ModelRunner:
@ -102,7 +92,7 @@ class LlamaModelParallelGenerator:
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
logprobs=logprobs or False,
tool_prompt_format=tool_prompt_format,
)