Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 20:14:13 +00:00
Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly
parent decbbc127b
commit 69d9655ecd

4 changed files with 21 additions and 5 deletions
@@ -11,7 +11,7 @@ from functools import partial
 from typing import Generator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
@@ -27,6 +27,7 @@ class InferenceArgs:
     top_p: float
     max_gen_len: int
     logprobs: bool
+    tool_prompt_format: ToolPromptFormat
 
 
 class ModelRunner:
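For context, ToolPromptFormat lives in llama_models.llama3.api.datatypes and selects how tool definitions and tool calls are rendered into the prompt that ChatFormat.encode_message produces. A minimal self-contained sketch of the idea follows; the json member is confirmed by the default used later in this diff, while function_tag and the exact rendered strings are illustrative assumptions, not the library's actual output:

import json
from enum import Enum


class ToolPromptFormat(Enum):
    json = "json"                  # confirmed: used as the default later in this diff
    function_tag = "function_tag"  # assumed member, for illustration only


def render_tool_call(name: str, arguments: dict, fmt: ToolPromptFormat) -> str:
    # Hypothetical helper: shows how the chosen format changes the way a
    # single tool call is encoded into the prompt text.
    if fmt is ToolPromptFormat.json:
        return json.dumps({"name": name, "parameters": arguments})
    return f"<function={name}>{json.dumps(arguments)}</function>"


print(render_tool_call("get_weather", {"city": "Paris"}, ToolPromptFormat.json))
print(render_tool_call("get_weather", {"city": "Paris"}, ToolPromptFormat.function_tag))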
@@ -41,6 +42,7 @@ class ModelRunner:
             task.top_p,
             task.max_gen_len,
             task.logprobs,
+            task.tool_prompt_format,
         )
 
 
@@ -93,6 +95,7 @@ class LlamaModelParallelGenerator:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         req_obj = InferenceArgs(
             messages=deepcopy(messages),
@@ -100,6 +103,7 @@ class LlamaModelParallelGenerator:
             top_p=top_p,
             max_gen_len=max_gen_len,
             logprobs=logprobs,
+            tool_prompt_format=tool_prompt_format,
         )
 
         gen = self.group.run_inference(req_obj)
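Taken together, the hunks thread the new option from the public entry point (LlamaModelParallelGenerator.__call__) through InferenceArgs and ModelRunner down to the model's chat_completion. A runnable sketch of that data flow, with a stub standing in for the real model; StubLlama and the literal argument values are hypothetical, while the class and field names mirror the diff:

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class ToolPromptFormat(Enum):
    json = "json"


@dataclass
class InferenceArgs:
    messages: List[str]  # the real code carries List[Message]; strings suffice here
    temperature: float
    top_p: float
    max_gen_len: Optional[int]
    logprobs: bool
    tool_prompt_format: ToolPromptFormat  # the field added by this commit


class StubLlama:
    # Stand-in for the real Llama model: just reports what it received.
    def chat_completion(self, messages, temperature, top_p, max_gen_len,
                        logprobs, tool_prompt_format):
        return f"encoding tools as {tool_prompt_format.value}"


class ModelRunner:
    def __init__(self, llama):
        self.llama = llama

    def __call__(self, task: InferenceArgs):
        # Mirrors the diff: the new field is forwarded positionally.
        return self.llama.chat_completion(
            task.messages,
            task.temperature,
            task.top_p,
            task.max_gen_len,
            task.logprobs,
            task.tool_prompt_format,
        )


# Caller side, mirroring how LlamaModelParallelGenerator.__call__ builds the task:
args = InferenceArgs(
    messages=deepcopy(["hello"]),
    temperature=0.7,
    top_p=0.9,
    max_gen_len=None,
    logprobs=False,
    tool_prompt_format=ToolPromptFormat.json,
)
print(ModelRunner(StubLlama())(args))  # -> encoding tools as json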
|
Loading…
Add table
Add a link
Reference in a new issue