Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
Inference to use provider resource id to register and validate (#428)
This PR changes how a model id gets translated into the final model name that is passed to the provider. Major changes:

1) Providers are responsible for registering an object and, as part of registration, returning the object with the correct provider-specific model name set in provider_resource_id.
2) To help with the common lookups across the different names, a new ModelLookup class is created.

Tested all inference providers, including together, fireworks, vllm, ollama, meta reference, and bedrock.
This commit is contained in: fdff24e77a (parent e51107e019)
21 changed files with 460 additions and 290 deletions
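For context, here is a minimal sketch of the registration contract described in the PR message: the provider validates the requested model and hands the object back with provider_resource_id filled in. The Model dataclass, alias table, and adapter class below are illustrative stand-ins, not the actual llama-stack types or the ModelLookup helper.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class Model:
    # Hypothetical stand-in for the registered model resource:
    # `identifier` is the user-facing name, `provider_resource_id` is the
    # name the provider actually serves the model under.
    identifier: str
    provider_resource_id: Optional[str] = None


# Illustrative alias table; a real provider would derive this from its own catalog.
PROVIDER_MODEL_ALIASES = {
    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
}


class ExampleInferenceAdapter:
    async def register_model(self, model: Model) -> Model:
        # The provider owns the translation: it validates the requested model
        # and returns the object with the provider-specific id filled in.
        provider_id = PROVIDER_MODEL_ALIASES.get(model.identifier)
        if provider_id is None:
            raise ValueError(f"Model {model.identifier} is not served by this provider")
        return replace(model, provider_resource_id=provider_id)
```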
Excerpt from the diff (`class Llama`):

```diff
@@ -86,6 +86,7 @@ class Llama:
         and loads the pre-trained model and tokenizer.
         """
         model = resolve_model(config.model)
+        llama_model = model.core_model_id.value
 
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
@@ -186,13 +187,20 @@ class Llama:
         model.load_state_dict(state_dict, strict=False)
 
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args)
+        return Llama(model, tokenizer, model_args, llama_model)
 
-    def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
+    def __init__(
+        self,
+        model: Transformer,
+        tokenizer: Tokenizer,
+        args: ModelArgs,
+        llama_model: str,
+    ):
         self.args = args
         self.model = model
         self.tokenizer = tokenizer
         self.formatter = ChatFormat(tokenizer)
+        self.llama_model = llama_model
 
     @torch.inference_mode()
     def generate(
@@ -369,7 +377,7 @@ class Llama:
         self,
         request: ChatCompletionRequest,
     ) -> Generator:
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, self.llama_model)
 
         sampling_params = request.sampling_params
         max_gen_len = sampling_params.max_tokens
```
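The resolved core_model_id.value is kept on the generator and forwarded to chat_completion_request_to_messages because prompt construction can differ by model family. The helper below is a hypothetical illustration of that idea only; it is not the actual implementation of chat_completion_request_to_messages.

```python
# Hypothetical sketch: message construction branching on the Llama model family.
# The real chat_completion_request_to_messages implements the actual rules.
def build_messages(request_messages: list[dict], llama_model: str) -> list[dict]:
    messages = list(request_messages)
    if llama_model.startswith("Llama3.1") and not any(
        m.get("role") == "system" for m in messages
    ):
        # Assumed example: a family-specific default system prompt.
        messages.insert(0, {"role": "system", "content": "Environment: ipython"})
    return messages
```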