Merge branch 'main' into inference_refactor

This commit is contained in:
Botao Chen 2024-12-17 20:10:23 -08:00
commit fadb7deae5
79 changed files with 1547 additions and 2026 deletions

View file

@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
model_parallel_is_initialized,
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
@ -39,8 +39,8 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .config import (
@ -207,7 +207,7 @@ class Llama:
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
model_input: LLMInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
@ -344,7 +344,7 @@ class Llama:
def completion(
self,
request: CompletionRequest,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
@ -355,10 +355,7 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = self.formatter.encode_content(content)
model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
@ -375,10 +372,8 @@ class Llama:
def chat_completion(
self,
request: ChatCompletionRequest,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
messages = chat_completion_request_to_messages(request, self.llama_model)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
@ -390,7 +385,7 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
request.messages,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,