Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 10:13:05 +00:00)

Commit b239c57c54 ("fix"), parent 63cf5dda50
8 changed files with 25 additions and 30 deletions

@@ -135,7 +135,7 @@ class Llama4:
         if print_model_input:
             cprint("Input to model:\n", "yellow")
             for inp in llm_inputs:
-                cprint(self.tokenizer.decode(inp.tokens.tolist()), "grey")
+                cprint(self.tokenizer.decode(inp.tokens), "grey")
         prompt_tokens = [inp.tokens for inp in llm_inputs]

         bsz = len(llm_inputs)

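Note on the hunk above: the only change is dropping `.tolist()`, which suggests `inp.tokens` is already a plain Python list at this point rather than a tensor. A minimal sketch (hypothetical token values) of why the old call would fail under that assumption:

# Hypothetical repro, assuming inp.tokens is a plain list of ints:
tokens = [101, 2023, 2003]   # stand-in token ids
try:
    tokens.tolist()          # lists have no .tolist(); only tensors/ndarrays do
except AttributeError as err:
    print(err)               # 'list' object has no attribute 'tolist'
# decoding then works directly on the list, e.g. tokenizer.decode(tokens)
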
@@ -5,19 +5,10 @@
 # the root directory of this source tree.

 from pathlib import Path
-from typing import List, Optional
-
-from pydantic import BaseModel

 from llama_stack.distribution.utils.model_utils import model_local_dir


-class TokenResult(BaseModel):
-    token: int
-    text: str
-    logprobs: Optional[List[float]] = None
-
-
 def model_checkpoint_dir(model_id) -> str:
     checkpoint_dir = Path(model_local_dir(model_id))

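For context on the deletion above: the module-local TokenResult wrapper is dropped here, and later hunks in this commit point its remaining consumer at the shared GenerationResult type instead. A minimal before/after sketch, reconstructed only from the hunks in this commit (not from the full files):

# Before (removed above): a local pydantic wrapper for one generated token.
from typing import List, Optional

from pydantic import BaseModel


class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None


# After: the parallel-utils hunk further down imports the shared type instead.
# from llama_stack.models.llama.datatypes import GenerationResult
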
@@ -51,7 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
         quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
-        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
+        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {

@@ -113,6 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
     return get_default_tool_prompt_format(request.model)


+# TODO: combine Llama3 and Llama4 generators since they are almost identical now
 class Llama4Generator:
     def __init__(
         self,

@@ -165,8 +166,8 @@ class Llama4Generator:
             max_gen_len = self.args.max_seq_len - 1

         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,

@@ -177,7 +178,8 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]

     def chat_completion(
         self,

@@ -189,8 +191,8 @@ class Llama4Generator:
             max_gen_len = self.args.max_seq_len - 1

         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,

@@ -201,7 +203,8 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]


 class Llama3Generator:

@@ -255,8 +258,8 @@ class Llama3Generator:
             max_gen_len = self.args.max_seq_len - 1

         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,

@@ -267,7 +270,8 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]

     def chat_completion(
         self,

@@ -279,8 +283,8 @@ class Llama3Generator:
             max_gen_len = self.args.max_seq_len - 1

         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,

@@ -291,4 +295,5 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]

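Across all four generator methods above, the change has the same shape: inner_generator.generate() is now called with a batch (llm_inputs=[...]) instead of a single llm_input/model_input, and each step yields one result per batch element, so the wrapper loops and forwards only element 0 rather than using `yield from`. A minimal sketch of that unwrapping pattern, with hypothetical names, assuming each per-step result is an indexable batch:

from typing import Any, Generator, Iterable, List


def unwrap_batch_of_one(step_results: Iterable[List[Any]]) -> Generator[Any, None, None]:
    # Each item is the per-step output for the whole batch; with a batch of
    # size one, only index 0 matters, mirroring `yield result[0]` above.
    for batch in step_results:
        yield batch[0]


# Hypothetical usage mirroring the hunks above:
# for token_result in unwrap_batch_of_one(
#     generator.generate(llm_inputs=[encoded_content], max_gen_len=max_gen_len,
#                        temperature=temperature, top_p=top_p)
# ):
#     handle(token_result)
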
@@ -149,7 +149,7 @@ class MetaReferenceInferenceImpl(

         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
-                model_parallel_size=llama_model.pth_file_count,
+                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
                 builder_fn=builder_fn,
                 builder_params=builder_params,
                 formatter=(

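Why the template default flips from null to 0 alongside the fallback above: these run configs appear to use ${env.VAR:default} substitution, and 0 is falsy in Python, so when MODEL_PARALLEL_SIZE is unset the added `or` expression still falls back to the checkpoint's .pth file count. A small sketch of that fallback, using a hypothetical helper name:

from typing import Optional


def resolve_model_parallel_size(configured: Optional[int], pth_file_count: int) -> int:
    # Both None (old default) and 0 (new template default) are falsy,
    # so either falls back to the number of checkpoint shards.
    return configured or pth_file_count


assert resolve_model_parallel_size(None, 8) == 8
assert resolve_model_parallel_size(0, 8) == 8
assert resolve_model_parallel_size(2, 8) == 2
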
@@ -32,13 +32,12 @@ from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 from typing_extensions import Annotated

+from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )

-from .common import TokenResult
-
 log = logging.getLogger(__name__)


@@ -75,7 +74,7 @@ class TaskRequest(BaseModel):

 class TaskResponse(BaseModel):
     type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-    result: TokenResult
+    result: GenerationResult


 class ExceptionResponse(BaseModel):

@@ -20,7 +20,7 @@ providers:
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
-      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}

@@ -32,7 +32,7 @@ providers:
       checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
-      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss

@@ -20,7 +20,7 @@ providers:
      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
-      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}