Ashwin Bharambe 2025-04-07 11:57:20 -07:00
parent 63cf5dda50
commit b239c57c54
8 changed files with 25 additions and 30 deletions

View file

@@ -135,7 +135,7 @@ class Llama4:
         if print_model_input:
             cprint("Input to model:\n", "yellow")
             for inp in llm_inputs:
-                cprint(self.tokenizer.decode(inp.tokens.tolist()), "grey")
+                cprint(self.tokenizer.decode(inp.tokens), "grey")
         prompt_tokens = [inp.tokens for inp in llm_inputs]
         bsz = len(llm_inputs)

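Note on the hunk above: the only change is dropping .tolist(), on the assumption that inp.tokens now arrives as a plain list of token ids rather than a tensor. A toy illustration of that calling convention (hypothetical tokenizer, not the one in the repo):

# Hypothetical toy tokenizer: decode() takes a plain list[int] of token ids,
# so no .tolist() conversion is needed on the caller's side.
class ToyTokenizer:
    vocab = {1: "<s>", 2: "hello", 3: "world"}

    def decode(self, tokens: list) -> str:
        return " ".join(self.vocab.get(t, "<unk>") for t in tokens)

tokens = [1, 2, 3]                    # already a list of ints
print(ToyTokenizer().decode(tokens))  # "<s> hello world"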
@@ -5,19 +5,10 @@
 # the root directory of this source tree.
 from pathlib import Path
-from typing import List, Optional
-from pydantic import BaseModel
 from llama_stack.distribution.utils.model_utils import model_local_dir
-class TokenResult(BaseModel):
-    token: int
-    text: str
-    logprobs: Optional[List[float]] = None
 def model_checkpoint_dir(model_id) -> str:
     checkpoint_dir = Path(model_local_dir(model_id))

@@ -51,7 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
         quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
-        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
+        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {

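The default for MODEL_PARALLEL_SIZE changes from null to 0 here and in the run.yaml files below, so the value always resolves to an integer; 0 is then treated as "unset" by the provider (see the pth_file_count fallback further down). A rough sketch of the ${env.NAME:default} substitution convention these strings rely on (an illustration only, not the project's actual resolver):

import os
import re

# Matches ${env.NAME:default} and substitutes the environment value or the default.
_ENV_TEMPLATE = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}")

def resolve_env_template(value: str) -> str:
    return _ENV_TEMPLATE.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

print(resolve_env_template("${env.MODEL_PARALLEL_SIZE:0}"))   # "0" unless MODEL_PARALLEL_SIZE is set
print(resolve_env_template("${env.QUANTIZATION_TYPE:bf16}"))  # "bf16" unless QUANTIZATION_TYPE is set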
@@ -113,6 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
     return get_default_tool_prompt_format(request.model)
+# TODO: combine Llama3 and Llama4 generators since they are almost identical now
 class Llama4Generator:
     def __init__(
         self,
@@ -165,8 +166,8 @@ class Llama4Generator:
         max_gen_len = self.args.max_seq_len - 1
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -177,7 +178,8 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
     def chat_completion(
         self,
@@ -189,8 +191,8 @@ class Llama4Generator:
         max_gen_len = self.args.max_seq_len - 1
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -201,7 +203,8 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
 class Llama3Generator:
@@ -255,8 +258,8 @@ class Llama3Generator:
         max_gen_len = self.args.max_seq_len - 1
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -267,7 +270,8 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
     def chat_completion(
         self,
@@ -279,8 +283,8 @@ class Llama3Generator:
         max_gen_len = self.args.max_seq_len - 1
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -291,4 +295,5 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]

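The hunks above switch both generators from a single llm_input to a batched llm_inputs list: the inner generator now yields one list of results per decoding step, with one entry per prompt, so these single-prompt callers stream result[0]. A minimal, self-contained sketch of that calling convention (toy stand-in, not the real generator):

from typing import Iterator, List

# Toy stand-in for the batched inner generator: every step yields a list with
# one result per prompt passed in `llm_inputs`.
def batched_generate(llm_inputs: List[str], max_gen_len: int = 3) -> Iterator[List[str]]:
    for step in range(max_gen_len):
        yield [f"{prompt}/token{step}" for prompt in llm_inputs]

# Single-prompt callers (like completion/chat_completion above) wrap the prompt
# in a one-element batch and take the first entry of every step's results.
def single_prompt_stream(prompt: str) -> Iterator[str]:
    for result in batched_generate(llm_inputs=[prompt]):
        yield result[0]

print(list(single_prompt_stream("hello")))  # ['hello/token0', 'hello/token1', 'hello/token2']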
@@ -149,7 +149,7 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
-                model_parallel_size=llama_model.pth_file_count,
+                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
                 builder_fn=builder_fn,
                 builder_params=builder_params,
                 formatter=(

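With the env default now 0, the expression above relies on 0 being falsy: an explicit model_parallel_size wins, otherwise the shard count of the checkpoint is used. A tiny sketch of that fallback (hypothetical values):

def effective_model_parallel_size(configured: int, pth_file_count: int) -> int:
    # 0 (the new default) is falsy, so `or` falls through to the checkpoint's shard count.
    return configured or pth_file_count

print(effective_model_parallel_size(0, 8))  # 8 -> derived from the number of .pth shards
print(effective_model_parallel_size(4, 8))  # 4 -> explicit configuration takes precedence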
@@ -32,13 +32,12 @@ from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 from typing_extensions import Annotated
+from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )
-from .common import TokenResult
 log = logging.getLogger(__name__)
@@ -75,7 +74,7 @@ class TaskRequest(BaseModel):
 class TaskResponse(BaseModel):
     type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-    result: TokenResult
+    result: GenerationResult
 class ExceptionResponse(BaseModel):

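TaskResponse now carries a GenerationResult (imported from llama_stack.models.llama.datatypes) instead of the TokenResult deleted from common.py. A self-contained sketch of the message shape, using a hypothetical stand-in whose fields simply mirror the removed TokenResult; the real GenerationResult may define more:

from typing import List, Literal, Optional

from pydantic import BaseModel

# Hypothetical stand-in for GenerationResult; fields mirror the removed TokenResult.
class GenerationResultSketch(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None

class TaskResponseSketch(BaseModel):
    type: Literal["task_response"] = "task_response"
    result: GenerationResultSketch

msg = TaskResponseSketch(result=GenerationResultSketch(token=42, text=" hello"))
print(msg.model_dump_json())  # e.g. what would be serialized between worker processes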
@@ -20,7 +20,7 @@ providers:
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
-      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
@@ -32,7 +32,7 @@ providers:
       checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
-      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss

@@ -20,7 +20,7 @@ providers:
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
-      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}