Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-03 09:21:45 +00:00
Commit b239c57c54 ("fix")
Parent: 63cf5dda50
8 changed files with 25 additions and 30 deletions
@@ -135,7 +135,7 @@ class Llama4:
         if print_model_input:
             cprint("Input to model:\n", "yellow")
             for inp in llm_inputs:
-                cprint(self.tokenizer.decode(inp.tokens.tolist()), "grey")
+                cprint(self.tokenizer.decode(inp.tokens), "grey")
         prompt_tokens = [inp.tokens for inp in llm_inputs]
 
         bsz = len(llm_inputs)
@ -5,19 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
|
||||
class TokenResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
|
||||
def model_checkpoint_dir(model_id) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model_id))
|
||||
|
||||
|
|
|
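
For reference, the removed TokenResult carried per-token output (token id, decoded text, optional logprobs); a later hunk in this commit switches TaskResponse over to GenerationResult from llama_stack.models.llama.datatypes instead. A minimal stand-in that mirrors the removed class (hypothetical — the real GenerationResult may define additional batching fields):

from typing import List, Optional

from pydantic import BaseModel


# Hypothetical stand-in mirroring the removed TokenResult; the actual
# GenerationResult lives in llama_stack.models.llama.datatypes.
class GenerationResultSketch(BaseModel):
    token: int                               # generated token id
    text: str                                # decoded text for this token
    logprobs: Optional[List[float]] = None   # per-token log probabilities, if requested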
@ -51,7 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
model: str = "Llama3.2-3B-Instruct",
|
||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
||||
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
|
||||
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
|
||||
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -113,6 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
|||
return get_default_tool_prompt_format(request.model)
|
||||
|
||||
|
||||
# TODO: combine Llama3 and Llama4 generators since they are almost identical now
|
||||
class Llama4Generator:
|
||||
def __init__(
|
||||
self,
|
||||
|
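
The TODO added above notes that Llama3Generator and Llama4Generator now differ only in how they encode a request. One possible shape of that consolidation, sketched here as an assumption rather than code from this commit, is a single helper parameterized by the encoded input:

from typing import Any, Iterator

# Hypothetical consolidation sketch: both generators funnel one encoded input
# through the batched inner generator and unwrap the batch-of-one results.
def run_single(inner_generator: Any, encoded_input: Any, max_gen_len: int,
               temperature: float, top_p: float, logits_processor: Any = None) -> Iterator[Any]:
    for result in inner_generator.generate(
        llm_inputs=[encoded_input],           # batch of one
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
        logits_processor=logits_processor,
    ):
        yield result[0]                       # unwrap the single-entry batch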
@ -165,8 +166,8 @@ class Llama4Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_input=self.formatter.encode_content(request.content),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -177,7 +178,8 @@ class Llama4Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
|
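
This hunk and the matching ones below adapt the single-request code paths to a batched generator interface: generate now takes llm_inputs (a list) and yields a list of per-input results at every decoding step, so each caller wraps its one input in a list and takes result[0]. A toy illustration of that convention (generic sketch, not the inner generator's real implementation):

from typing import Iterator, List

# Toy batched generator: yields one result per input at every decoding step.
def generate(llm_inputs: List[str], max_gen_len: int) -> Iterator[List[str]]:
    for step in range(max_gen_len):
        yield [f"{inp}:token{step}" for inp in llm_inputs]

# Single-request caller, mirroring the pattern in the diff above.
def completion(prompt: str) -> Iterator[str]:
    for result in generate(llm_inputs=[prompt], max_gen_len=3):
        yield result[0]  # batch of one -> take the only entry

print(list(completion("hi")))  # ['hi:token0', 'hi:token1', 'hi:token2']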
@ -189,8 +191,8 @@ class Llama4Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -201,7 +203,8 @@ class Llama4Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
|
||||
class Llama3Generator:
|
||||
|
@ -255,8 +258,8 @@ class Llama3Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
model_input=self.formatter.encode_content(request.content),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -267,7 +270,8 @@ class Llama3Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
|
@ -279,8 +283,8 @@ class Llama3Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -291,4 +295,5 @@ class Llama3Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
|
|
@ -149,7 +149,7 @@ class MetaReferenceInferenceImpl(
|
|||
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(
|
||||
model_parallel_size=llama_model.pth_file_count,
|
||||
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
|
||||
builder_fn=builder_fn,
|
||||
builder_params=builder_params,
|
||||
formatter=(
|
||||
|
|
|
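
This works together with the default change elsewhere in the commit: model_parallel_size now defaults to 0 instead of null, and 0 is falsy in Python, so the `or` expression falls back to the checkpoint's pth_file_count unless an explicit positive value is configured. A small illustration (function and variable names are hypothetical):

# Hypothetical illustration of the fallback expression used above.
def resolve_model_parallel_size(configured_size, pth_file_count):
    # Both None and the new default of 0 are falsy, so `or` falls through
    # to the number of checkpoint shards.
    return configured_size or pth_file_count

assert resolve_model_parallel_size(0, 8) == 8   # default: follow the checkpoint
assert resolve_model_parallel_size(2, 8) == 2   # explicit override wins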
@ -32,13 +32,12 @@ from pydantic import BaseModel, Field
|
|||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .common import TokenResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -75,7 +74,7 @@ class TaskRequest(BaseModel):
|
|||
|
||||
class TaskResponse(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
|
||||
result: TokenResult
|
||||
result: GenerationResult
|
||||
|
||||
|
||||
class ExceptionResponse(BaseModel):
|
||||
|
|
|
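
TaskResponse crosses the model-parallel worker boundary as serialized pydantic data, so changing the result field's type only requires both ends to agree on the schema. A minimal round-trip sketch with pydantic v2 (Payload here is a hypothetical stand-in, not llama-stack's GenerationResult):

from typing import List, Literal, Optional

from pydantic import BaseModel


# Hypothetical payload type standing in for GenerationResult.
class Payload(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None


class TaskResponseSketch(BaseModel):
    type: Literal["task_response"] = "task_response"
    result: Payload


# Round trip, as it would happen across a pipe between ranks.
wire = TaskResponseSketch(result=Payload(token=42, text=" hello")).model_dump_json()
decoded = TaskResponseSketch.model_validate_json(wire)
assert decoded.result.token == 42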
@ -20,7 +20,7 @@ providers:
|
|||
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
|
||||
quantization:
|
||||
type: ${env.QUANTIZATION_TYPE:bf16}
|
||||
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
|
||||
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
|
@ -32,7 +32,7 @@ providers:
|
|||
checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
|
||||
quantization:
|
||||
type: ${env.QUANTIZATION_TYPE:bf16}
|
||||
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
|
||||
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
|
|
|
@ -20,7 +20,7 @@ providers:
|
|||
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
|
||||
quantization:
|
||||
type: ${env.QUANTIZATION_TYPE:bf16}
|
||||
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
|
||||
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
|
|
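
All three run configurations use the same ${env.VAR:default} placeholder convention, so with the default changed to 0, leaving MODEL_PARALLEL_SIZE unset resolves to 0 and the provider derives the parallel size from the checkpoint instead. A rough sketch of how such placeholders resolve (illustrative only, not llama-stack's actual substitution code):

import os
import re

# Illustrative resolver for "${env.VAR:default}" placeholders.
_PLACEHOLDER = re.compile(r"\$\{env\.([A-Z_][A-Z0-9_]*):([^}]*)\}")

def resolve(value: str) -> str:
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

print(resolve("${env.MODEL_PARALLEL_SIZE:0}"))  # "0" unless the variable is set
os.environ["MODEL_PARALLEL_SIZE"] = "2"
print(resolve("${env.MODEL_PARALLEL_SIZE:0}"))  # "2"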