fixes after rebase

Dinesh Yeduguru 2024-11-12 15:37:07 -08:00
parent 948f6ece6e
commit 919d421bcf
11 changed files with 72 additions and 70 deletions


@@ -86,6 +86,7 @@ class Llama:
         and loads the pre-trained model and tokenizer.
         """
         model = resolve_model(config.model)
+        llama_model = model.core_model_id.value

         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
@@ -186,13 +187,20 @@ class Llama:
         model.load_state_dict(state_dict, strict=False)

         print(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args)
+        return Llama(model, tokenizer, model_args, llama_model)

-    def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
+    def __init__(
+        self,
+        model: Transformer,
+        tokenizer: Tokenizer,
+        args: ModelArgs,
+        llama_model: str,
+    ):
         self.args = args
         self.model = model
         self.tokenizer = tokenizer
         self.formatter = ChatFormat(tokenizer)
+        self.llama_model = llama_model

     @torch.inference_mode()
     def generate(
@@ -369,7 +377,7 @@ class Llama:
         self,
         request: ChatCompletionRequest,
     ) -> Generator:
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, self.llama_model)

         sampling_params = request.sampling_params
         max_gen_len = sampling_params.max_tokens
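
Taken together, the hunks thread the resolved core model id through the Llama class: it is captured in build() from resolve_model(), stored on the instance via the new llama_model constructor argument, and passed to chat_completion_request_to_messages() so prompt conversion can be model-specific. Below is a minimal, self-contained sketch of that pattern; the stub types, the registry stand-in, and the "Llama3.1-8B-Instruct" id are illustrative assumptions, not the actual llama-stack implementations.

# Minimal sketch (not the real llama-stack code) of how the resolved core
# model id flows after this commit: captured in build(), stored on the
# instance, and forwarded to the message-conversion helper.
from dataclasses import dataclass


@dataclass
class CoreModelId:
    value: str


@dataclass
class ResolvedModel:
    core_model_id: CoreModelId


def resolve_model(name: str) -> ResolvedModel:
    # stand-in for the real model registry lookup
    return ResolvedModel(CoreModelId(value=name))


def chat_completion_request_to_messages(request: dict, llama_model: str) -> list:
    # stand-in: the real helper uses the model id to pick a model-specific
    # prompt template; here we just record which model was requested
    return [{"role": "user", "content": request["prompt"], "model": llama_model}]


class Llama:
    @classmethod
    def build(cls, model_name: str) -> "Llama":
        model = resolve_model(model_name)
        llama_model = model.core_model_id.value  # canonical id string
        return cls(llama_model)

    def __init__(self, llama_model: str):
        self.llama_model = llama_model  # kept for later request handling

    def chat_completion(self, request: dict) -> list:
        # the stored id now reaches prompt conversion
        return chat_completion_request_to_messages(request, self.llama_model)


llama = Llama.build("Llama3.1-8B-Instruct")
print(llama.chat_completion({"prompt": "hello"}))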