Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly
Commit 69d9655ecd (parent decbbc127b)
4 changed files with 21 additions and 5 deletions
File 1 of 4 (agentic system EventLogger):

```diff
@@ -44,7 +44,12 @@ EventType = AgenticSystemTurnResponseEventType
 
 
 class EventLogger:
-    async def log(self, event_generator, stream=True):
+    async def log(
+        self,
+        event_generator,
+        stream=True,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+    ):
         previous_event_type = None
         previous_step_type = None
 
@@ -132,7 +137,9 @@ class EventLogger:
             if event_type == EventType.step_complete.value:
                 response = event.payload.step_details.model_response
                 if response.tool_calls:
-                    content = ToolUtils.encode_tool_call(response.tool_calls[0])
+                    content = ToolUtils.encode_tool_call(
+                        response.tool_calls[0], tool_prompt_format
+                    )
                 else:
                     content = response.content
                 yield event, LogEvent(
```
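With this change, callers of the event logger choose how decoded tool calls are rendered. A minimal usage sketch, assuming `EventLogger` is already imported from the agentic-system utilities (its module path is not shown in this diff):

```python
from llama_models.llama3.api.datatypes import ToolPromptFormat


async def print_turn(turn_events, logger) -> None:
    """Render a stream of agentic-system events, decoding tool calls as JSON."""
    # `tool_prompt_format` defaults to ToolPromptFormat.json (see the new
    # signature above); pass the same format the system prompt was built with
    # so logged tool calls match what the model actually emitted.
    async for event, log_event in logger.log(
        turn_events,
        stream=True,
        tool_prompt_format=ToolPromptFormat.json,
    ):
        if log_event is not None:
            print(log_event)
```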
File 2 of 4 (Llama generation class):

```diff
@@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
@@ -279,6 +279,7 @@ class Llama:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         if (
             max_gen_len is None
@@ -288,7 +289,10 @@ class Llama:
             max_gen_len = self.model.params.max_seq_len - 1
 
         yield from self.generate(
-            model_input=self.formatter.encode_dialog_prompt(messages),
+            model_input=self.formatter.encode_dialog_prompt(
+                messages,
+                tool_prompt_format,
+            ),
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
```
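The substantive change here is that the caller's `tool_prompt_format` is threaded into `ChatFormat.encode_dialog_prompt`, so tool definitions in the prompt are encoded in the requested style rather than a hard-coded one. A rough sketch of that encoding step in isolation; the tokenizer path is a placeholder, and the two-argument `encode_dialog_prompt` call is taken from the new call site above:

```python
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import ToolPromptFormat, UserMessage
from llama_models.llama3.api.tokenizer import Tokenizer

# Placeholder path; point this at a real Llama 3 tokenizer.model checkpoint.
tokenizer = Tokenizer(model_path="/path/to/tokenizer.model")
formatter = ChatFormat(tokenizer)

messages = [UserMessage(content="What's the weather in Menlo Park?")]

# The second argument is the ToolPromptFormat forwarded from chat completion;
# ModelInput carries the resulting token ids that feed the Transformer.
model_input = formatter.encode_dialog_prompt(messages, ToolPromptFormat.json)
print(len(model_input.tokens))
```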
File 3 of 4 (MetaReferenceInferenceImpl):

```diff
@@ -104,6 +104,7 @@ class MetaReferenceInferenceImpl(Inference):
             top_p=request.sampling_params.top_p,
             max_gen_len=request.sampling_params.max_tokens,
             logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
         ):
             buffer += token_result.text
             tokens.append(token_result.token)
```
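The inference implementation simply forwards whatever format the incoming request specifies. A sketch of that flow using a hypothetical stand-in for the request type (the real request class is not shown in this diff; the only grounded detail is the `tool_prompt_format` attribute being read and passed through):

```python
from dataclasses import dataclass
from typing import List

from llama_models.llama3.api.datatypes import Message, ToolPromptFormat


@dataclass
class StandInChatRequest:
    # Hypothetical stand-in mirroring only the fields this hunk touches.
    messages: List[Message]
    logprobs: bool = False
    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json


def format_for(request: StandInChatRequest) -> ToolPromptFormat:
    # Mirrors the new keyword argument above: the client's choice of tool
    # encoding is passed to the generator instead of being assumed.
    return request.tool_prompt_format
```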
File 4 of 4 (model-parallel inference: InferenceArgs, ModelRunner, LlamaModelParallelGenerator):

```diff
@@ -11,7 +11,7 @@ from functools import partial
 from typing import Generator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
@@ -27,6 +27,7 @@ class InferenceArgs:
     top_p: float
     max_gen_len: int
     logprobs: bool
+    tool_prompt_format: ToolPromptFormat
 
 
 class ModelRunner:
@@ -41,6 +42,7 @@ class ModelRunner:
             task.top_p,
             task.max_gen_len,
             task.logprobs,
+            task.tool_prompt_format,
         )
 
 
@@ -93,6 +95,7 @@ class LlamaModelParallelGenerator:
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
+        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
         req_obj = InferenceArgs(
             messages=deepcopy(messages),
@@ -100,6 +103,7 @@ class LlamaModelParallelGenerator:
             top_p=top_p,
             max_gen_len=max_gen_len,
             logprobs=logprobs,
+            tool_prompt_format=tool_prompt_format,
        )
 
         gen = self.group.run_inference(req_obj)
```
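`InferenceArgs` is the payload shipped to the model-parallel worker, so the new field rides along with the rest of the request. A construction sketch, assuming `InferenceArgs` is in scope (its module path is not shown here) and noting that the `temperature` field name is an assumption for the elided line between `messages` and `top_p`:

```python
from copy import deepcopy

from llama_models.llama3.api.datatypes import ToolPromptFormat, UserMessage

messages = [UserMessage(content="Book a table for two at 7pm.")]

req_obj = InferenceArgs(
    messages=deepcopy(messages),  # copied so the worker can't mutate the caller's list
    temperature=0.6,              # assumed field name; not visible in this hunk
    top_p=0.9,
    max_gen_len=512,
    logprobs=False,
    tool_prompt_format=ToolPromptFormat.json,  # new field added by this commit
)
```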