Distribution server now functioning

Ashwin Bharambe 2024-08-02 13:37:40 -07:00
parent 041cafbee3
commit 2cf9915806
21 changed files with 635 additions and 266 deletions

View file

@@ -19,7 +19,7 @@ def available_inference_adapters() -> List[Adapter]:
"zmq",
],
module="llama_toolchain.inference.inference",
config_class="llama_toolchain.inference.inference.InlineImplConfig",
config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
),
SourceAdapter(
api_surface=ApiSurface.inference,

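A note on the entry above: the registry points at `module` and `config_class` by dotted string, which implies the distribution server resolves both at runtime via import machinery. A minimal sketch of that resolution; the helper name is hypothetical, since the diff does not show the actual loader:

```python
import importlib

def load_config_class(fully_qualified_name: str) -> type:
    # split "pkg.module.ClassName" into its module path and class name,
    # import the module, then pull the class off it
    module_name, class_name = fully_qualified_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# e.g. (assuming llama_toolchain is importable):
# config_cls = load_config_class(
#     "llama_toolchain.inference.inference.MetaReferenceImplConfig"
# )
# config = config_cls(**raw_config_dict)
```
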
View file

@@ -7,9 +7,6 @@
from enum import Enum
from typing import Literal, Optional, Union
from hydra.core.config_store import ConfigStore
from hydra_zen import builds
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
from pydantic import BaseModel, Field
@@ -19,13 +16,6 @@ from typing_extensions import Annotated
from .datatypes import QuantizationConfig
@json_schema_type
class ImplType(Enum):
inline = "inline"
remote = "remote"
ollama = "ollama"
@json_schema_type
class CheckpointType(Enum):
pytorch = "pytorch"
@@ -66,8 +56,8 @@ class ModelCheckpointConfig(BaseModel):
@json_schema_type
class InlineImplConfig(BaseModel):
impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
class MetaReferenceImplConfig(BaseModel):
model: str
checkpoint_config: ModelCheckpointConfig
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
@@ -75,28 +65,7 @@ class InlineImplConfig(BaseModel):
max_batch_size: int = 1
@json_schema_type
class RemoteImplConfig(BaseModel):
impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
url: str = Field(..., description="The URL of the remote module")
@json_schema_type
class OllamaImplConfig(BaseModel):
impl_type: Literal[ImplType.ollama.value] = ImplType.ollama.value
model: str = Field(..., description="The name of the model in ollama catalog")
url: str = Field(..., description="The URL for the ollama server")
@json_schema_type
class InferenceConfig(BaseModel):
impl_config: Annotated[
Union[InlineImplConfig, RemoteImplConfig, OllamaImplConfig],
Field(discriminator="impl_type"),
]
InferenceHydraConfig = builds(InferenceConfig)
cs = ConfigStore.instance()
cs.store(name="inference_config", node=InferenceHydraConfig)

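For context on the deletions above: the removed block implemented pydantic's discriminated-union dispatch, where each variant carries an `impl_type` tag declared as a `Literal` and `Field(discriminator="impl_type")` selects the matching model during validation. A trimmed, self-contained sketch of the pattern being dropped (plain string literals stand in for the original `ImplType` enum values):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field

class InlineImplConfig(BaseModel):
    impl_type: Literal["inline"] = "inline"
    model: str

class RemoteImplConfig(BaseModel):
    impl_type: Literal["remote"] = "remote"
    url: str

class InferenceConfig(BaseModel):
    # pydantic reads "impl_type" from the incoming data and validates
    # against the matching variant only
    impl_config: Annotated[
        Union[InlineImplConfig, RemoteImplConfig],
        Field(discriminator="impl_type"),
    ]

cfg = InferenceConfig(
    impl_config={"impl_type": "remote", "url": "http://localhost:5000"}
)
assert isinstance(cfg.impl_config, RemoteImplConfig)
```

After this commit the inference config is the single concrete `MetaReferenceImplConfig`, so the tag, the union, and the hydra registration all go away.
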
View file

@@ -4,19 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api.config import ImplType, InferenceConfig
# from .api.config import ImplType, InferenceConfig
async def get_inference_api_instance(config: InferenceConfig):
if config.impl_config.impl_type == ImplType.inline.value:
from .inference import InferenceImpl
# async def get_inference_api_instance(config: InferenceConfig):
# if config.impl_config.impl_type == ImplType.inline.value:
# from .inference import InferenceImpl
return InferenceImpl(config.impl_config)
elif config.impl_config.impl_type == ImplType.ollama.value:
from .ollama import OllamaInference
# return InferenceImpl(config.impl_config)
# elif config.impl_config.impl_type == ImplType.ollama.value:
# from .ollama import OllamaInference
return OllamaInference(config.impl_config)
# return OllamaInference(config.impl_config)
from .client import InferenceClient
# from .client import InferenceClient
return InferenceClient(config.impl_config.url)
# return InferenceClient(config.impl_config.url)

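The factory is commented out rather than deleted. Reassembled from the commented lines for readability, the superseded dispatch read as follows (relative imports as in the original; the replacement code path is the per-adapter `get_adapter_impl` later in this commit):

```python
from .api.config import ImplType, InferenceConfig

async def get_inference_api_instance(config: InferenceConfig):
    if config.impl_config.impl_type == ImplType.inline.value:
        from .inference import InferenceImpl

        return InferenceImpl(config.impl_config)
    elif config.impl_config.impl_type == ImplType.ollama.value:
        from .ollama import OllamaInference

        return OllamaInference(config.impl_config)

    # anything else falls through to the remote client
    from .client import InferenceClient

    return InferenceClient(config.impl_config.url)
```
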
View file

@@ -29,7 +29,7 @@ from llama_models.llama3_1.api.model import Transformer
from llama_models.llama3_1.api.tokenizer import Tokenizer
from termcolor import cprint
from .api.config import CheckpointType, InlineImplConfig
from .api.config import CheckpointType, MetaReferenceImplConfig
from .api.datatypes import QuantizationType
@@ -42,7 +42,7 @@ class TokenResult:
class Llama:
@staticmethod
def build(config: InlineImplConfig):
def build(config: MetaReferenceImplConfig):
"""
Build a Llama instance by initializing and loading a model checkpoint.

View file

@@ -4,11 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
import asyncio
from typing import AsyncIterator, Union
from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from .api.config import InlineImplConfig
from .api.config import MetaReferenceImplConfig
from .api.datatypes import (
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
@@ -19,23 +22,35 @@ from .api.endpoints import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
)
from .model_parallel import LlamaModelParallelGenerator
def get_adapter_impl(config: InlineImplConfig) -> Inference:
async def get_adapter_impl(config: MetaReferenceImplConfig):
assert isinstance(
config, InlineImplConfig
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
return InferenceImpl(config)
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl
class InferenceImpl(Inference):
# there's a single model-parallel process running, serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
def __init__(self, config: InlineImplConfig) -> None:
class MetaReferenceInferenceImpl(Inference):
def __init__(self, config: MetaReferenceImplConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model
# verify that the checkpoint actually is for this model lol
async def initialize(self) -> None:
self.generator = LlamaModelParallelGenerator(self.config)
@@ -44,125 +59,144 @@ class InferenceImpl(Inference):
async def shutdown(self) -> None:
self.generator.stop()
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
# hm, when stream=False, we should not be doing SSE :/ which is what the
# top-level server is going to do. make the typing more specific here
async def chat_completion(
self, request: ChatCompletionRequest
) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
]:
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
tokens = []
logprobs = []
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
stop_reason = None
buffer = ""
ipython = False
for token_result in self.generator.chat_completion(
messages=request.messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
async with SEMAPHORE:
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprob)
tokens = []
logprobs = []
continue
stop_reason = None
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
buffer = ""
ipython = False
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
for token_result in self.generator.chat_completion(
messages=request.messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprob)
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
stop_reason = StopReason.out_of_tokens
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)

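The module-level `SEMAPHORE` introduced above implements a fail-fast admission check: a second in-flight request is rejected rather than queued behind the single model-parallel process. An isolated sketch of the pattern (everything except the semaphore idiom is illustrative):

```python
import asyncio

SEMAPHORE = asyncio.Semaphore(1)

async def generate(prompt: str) -> str:
    if SEMAPHORE.locked():
        # fail fast instead of silently serializing callers
        raise RuntimeError("Only one concurrent request is supported")
    async with SEMAPHORE:
        await asyncio.sleep(0.1)  # stand-in for token generation
        return f"echo: {prompt}"

async def main():
    first = asyncio.create_task(generate("hello"))
    await asyncio.sleep(0)  # let the first task acquire the semaphore
    try:
        await generate("world")
    except RuntimeError as err:
        print(err)  # -> Only one concurrent request is supported
    print(await first)  # -> echo: hello

asyncio.run(main())
```

Note that `locked()` followed by `async with` is not atomic; the check is advisory, which is fine here since the goal is a clear error rather than strict admission control.
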
View file

@@ -13,7 +13,7 @@ from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from .api.config import InlineImplConfig
from .api.config import MetaReferenceImplConfig
from .generation import Llama
from .parallel_utils import ModelParallelProcessGroup
@@ -42,7 +42,7 @@ class ModelRunner:
)
def init_model_cb(config: InlineImplConfig):
def init_model_cb(config: MetaReferenceImplConfig):
llama = Llama.build(config)
return ModelRunner(llama)
@@ -58,7 +58,7 @@ class LlamaModelParallelGenerator:
clear at the callsite why we need to use a context manager.
"""
def __init__(self, config: InlineImplConfig):
def __init__(self, config: MetaReferenceImplConfig):
self.config = config
# this is a hack because Agent's loop uses this to tokenize and check if input is too long

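The docstring fragment above notes that the context-manager form exists to make the lifetime clear at the callsite. A minimal sketch of that shape; the method bodies and signatures here are assumptions, since the diff only shows the class in passing:

```python
class LlamaModelParallelGenerator:
    def __init__(self, config):
        self.config = config

    def __enter__(self):
        self.start()  # e.g. spawn the model-parallel process group
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop()  # tear the workers down even if generation raised

    def start(self): ...
    def stop(self): ...

# callsite: the with-block makes the expensive start/stop lifetime explicit
# with LlamaModelParallelGenerator(config) as generator:
#     for token_result in generator.chat_completion(...):
#         ...
```
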
View file

@@ -17,7 +17,7 @@ from llama_models.llama3_1.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api.config import (
CheckpointQuantizationFormat,
InlineImplConfig,
MetaReferenceImplConfig,
)
from llama_toolchain.inference.api.datatypes import QuantizationType
@@ -46,7 +46,7 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
config: InlineImplConfig,
config: MetaReferenceImplConfig,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value: