forked from phoenix-oss/llama-stack-mirror
Introduce Llama stack distributions (#22)
* Add distribution CLI scaffolding
* More progress towards `llama distribution install`
* getting closer to a distro definition, distro install + configure works
* Distribution server now functioning
* read existing configuration, save enums properly
* Remove inference uvicorn server entrypoint and llama inference CLI command
* updated dependency and client model name
* Improved exception handling
* local imports for faster cli
* undo a typo, add a passthrough distribution
* implement full-passthrough in the server
* add safety adapters, configuration handling, server + clients
* cleanup, moving stuff to common, nuke utils
* Add a Path() wrapper at the earliest place
* fixes
* Bring agentic system api to toolchain; add adapter dependencies and resolve adapters using a topological sort (see the sketch after this message)
* refactor to reduce size of `agentic_system`
* move straggler files and fix some important existing bugs
* ApiSurface -> Api
* refactor a method out
* Adapter -> Provider
* Make each inference provider into its own subdirectory
* installation fixes
* Rename Distribution -> DistributionSpec, simplify RemoteProviders
* dict key instead of attr
* update inference config to take model and not model_dir
* Fix passthrough streaming, send headers properly not as part of the body :facepalm
* update safety to use model sku ids and not model dirs
* Update cli_reference.md
* minor fixes
* add DistributionConfig, fix a bug in model download
* Make install + start scripts do proper configuration automatically
* Update CLI_reference
* Nuke fp8_requirements, fold fbgemm into common requirements
* Update README, add newline between API surface configurations
* Refactor download functionality out of the Command so it can be reused
* Add `llama model download` alias for `llama download`
* Show message about checksum file so users can check themselves
* Simpler intro statements
* get ollama working
* Reduce a bunch of dependencies from toolchain; some improvements to the distribution install script
* Avoid using `conda run` since it buffers everything
* update dependencies and rely on LLAMA_TOOLCHAIN_DIR for dev purposes
* add validation for configuration input
* resort imports
* make optional subclasses default to yes for configuration
* Remove additional_pip_packages; move deps to providers
* for inline make 8b model the default
* Add scripts to MANIFEST
* allow installing from test.pypi.org
* Fix #2 to help with testing packages
* Must install llama-models at that same version first
* fix PIP_ARGS

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Hardik Shah <hjshah@meta.com>
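One bullet above mentions resolving adapters (later renamed providers) in dependency order via a topological sort. As a rough illustration only, not the actual llama-stack implementation, such a resolution step can be sketched with Python's standard-library graphlib; the API names and dependency edges below are assumptions for the example.

# Hypothetical sketch: order providers so that each one's dependencies
# are instantiated before it. The real code in this commit differs.
from graphlib import TopologicalSorter

# assumed example graph: each key depends on the APIs in its value set
deps = {
    "agentic_system": {"inference", "safety"},
    "safety": {"inference"},
    "inference": set(),
}

# static_order() yields nodes with all predecessors first,
# e.g. ['inference', 'safety', 'agentic_system']
resolution_order = list(TopologicalSorter(deps).static_order())
print(resolution_order)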
This commit is contained in:
parent da4645a27a
commit e830814399
115 changed files with 5839 additions and 1120 deletions
@@ -1,104 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional

from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer

from .api.config import InlineImplConfig
from .generation import Llama
from .parallel_utils import ModelParallelProcessGroup


@dataclass
class InferenceArgs:
    messages: List[Message]
    temperature: float
    top_p: float
    max_gen_len: int
    logprobs: bool


class ModelRunner:
    def __init__(self, llama):
        self.llama = llama

    # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
    def __call__(self, task: InferenceArgs):
        return self.llama.chat_completion(
            task.messages,
            task.temperature,
            task.top_p,
            task.max_gen_len,
            task.logprobs,
        )


def init_model_cb(config: InlineImplConfig):
    llama = Llama.build(config)
    return ModelRunner(llama)


class LlamaModelParallelGenerator:
    """
    This abstraction exists so
    - we can run model parallel code without needing to run the CLIs via torchrun
    - this also enables using model parallel code within a notebook context.

    A Context Manager is used to ensure that the model parallel process is started and stopped
    correctly. This does make the ergonomics a little awkward, because it isn't immediately
    clear at the callsite why we need to use a context manager.
    """

    def __init__(self, config: InlineImplConfig):
        self.config = config

        # this is a hack because the Agent's loop uses this to tokenize and check if input is too long
        # while the tool-use loop is going
        checkpoint = self.config.checkpoint_config.checkpoint
        self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))

    def start(self):
        self.__enter__()

    def stop(self):
        self.__exit__(None, None, None)

    def __enter__(self):
        checkpoint = self.config.checkpoint_config.checkpoint
        self.group = ModelParallelProcessGroup(
            checkpoint.model_parallel_size,
            init_model_cb=partial(init_model_cb, self.config),
        )
        self.group.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.group.stop()

    def chat_completion(
        self,
        messages: List[Message],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> Generator:
        req_obj = InferenceArgs(
            messages=deepcopy(messages),
            temperature=temperature,
            top_p=top_p,
            max_gen_len=max_gen_len,
            logprobs=logprobs,
        )

        gen = self.group.run_inference(req_obj)
        yield from gen
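For orientation, here is a minimal usage sketch of the deleted LlamaModelParallelGenerator; it is not part of the diff. It mirrors the start()/stop() and chat_completion() surface shown above, but the construction of `config` and the UserMessage import are assumptions about the surrounding llama_models/llama_toolchain API.

# Hypothetical usage sketch (not from this commit): drive the generator with
# explicit start()/stop() calls, which wrap the context-manager protocol above.
from llama_models.llama3_1.api.datatypes import UserMessage  # assumed message type

config = ...  # an InlineImplConfig with a valid checkpoint_config (assumed prebuilt)

generator = LlamaModelParallelGenerator(config)
generator.start()  # spawns the model-parallel worker group
try:
    # chat_completion() returns a generator of streamed results from the workers
    for result in generator.chat_completion(
        messages=[UserMessage(content="Hello")],
        temperature=0.6,
        max_gen_len=64,
    ):
        print(result)
finally:
    generator.stop()  # always tear down the worker group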