Merge branch 'main' into inference_refactor

This commit is contained in:
Botao Chen 2024-12-17 20:10:23 -08:00
commit fadb7deae5
79 changed files with 1547 additions and 2026 deletions

View file

@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
model_parallel_is_initialized,
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
@ -39,8 +39,8 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .config import (
@ -207,7 +207,7 @@ class Llama:
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
model_input: LLMInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
@ -344,7 +344,7 @@ class Llama:
def completion(
self,
request: CompletionRequest,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
@ -355,10 +355,7 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = self.formatter.encode_content(content)
model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
@ -375,10 +372,8 @@ class Llama:
def chat_completion(
self,
request: ChatCompletionRequest,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
messages = chat_completion_request_to_messages(request, self.llama_model)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
@ -390,7 +385,7 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
request.messages,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,

View file

@ -7,24 +7,50 @@
import asyncio
import logging
from typing import AsyncGenerator, List
from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import Model
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
)
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig
@ -44,7 +70,8 @@ class MetaReferenceInferenceImpl(
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
self.model = None
self.model_id = None
self.llama_model = None
async def initialize(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
@ -56,20 +83,21 @@ class MetaReferenceInferenceImpl(
else:
self.generator = Llama.build(self.config, model_id, llama_model)
self.model = model_id
self.model_id = model_id
self.llama_model = llama_model
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
def check_model(self, request) -> None:
if self.model is None:
if self.model_id or self.llama_model is None:
raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model:
elif request.model != self.model_id:
raise RuntimeError(
f"Model mismatch: request model: {request.model} != loaded model: {self.model}"
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
)
async def unregister_model(self, model_id: str) -> None:
@ -107,7 +135,7 @@ class MetaReferenceInferenceImpl(
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -116,6 +144,7 @@ class MetaReferenceInferenceImpl(
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content = augment_content_with_response_format_prompt(response_format, content)
request = CompletionRequest(
model=model_id,
content=content,
@ -125,7 +154,7 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
request = await convert_request_to_raw(request)
if request.stream:
return self._stream_completion(request)
@ -250,7 +279,13 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.llama_model.core_model_id.value
)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
@ -291,11 +326,15 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
raw_message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
return ChatCompletionResponse(
completion_message=message,
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=logprobs if request.logprobs else None,
)
@ -421,31 +460,3 @@ class MetaReferenceInferenceImpl(
else:
for x in impl():
yield x
async def request_with_localized_media(
request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequest, CompletionRequest]:
if not request_has_media(request):
return request
async def _convert_single_content(content):
if isinstance(content, ImageMedia):
url = await convert_image_media_to_url(content, download=True)
return ImageMedia(image=URL(uri=url))
else:
return content
async def _convert_content(content):
if isinstance(content, list):
return [await _convert_single_content(c) for c in content]
else:
return await _convert_single_content(content)
if isinstance(request, ChatCompletionRequest):
for m in request.messages:
m.content = await _convert_content(m.content)
else:
request.content = await _convert_content(request.content)
return request