diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py
index b63ba4847..72f90f512 100644
--- a/llama_stack/models/llama/llama3/interface.py
+++ b/llama_stack/models/llama/llama3/interface.py
@@ -162,7 +162,9 @@ class LLama31Interface:
         tool_template = None
         if builtin_tools or custom_tools:
             tool_gen = BuiltinToolGenerator()
-            tool_template = tool_gen.gen(builtin_tools + custom_tools)
+            # Convert BuiltinTool to ToolDefinition for the gen method
+            builtin_tool_defs = [ToolDefinition(tool_name=tool) for tool in builtin_tools]
+            tool_template = tool_gen.gen(builtin_tool_defs + custom_tools)
             sys_content += tool_template.render()
             sys_content += "\n"
 
@@ -177,14 +179,15 @@ class LLama31Interface:
         messages.append(RawMessage(role="system", content=sys_content))
 
         if custom_tools:
+            custom_tool_gen: JsonCustomToolGenerator | FunctionTagCustomToolGenerator
             if self.tool_prompt_format == ToolPromptFormat.json:
-                tool_gen = JsonCustomToolGenerator()
+                custom_tool_gen = JsonCustomToolGenerator()
             elif self.tool_prompt_format == ToolPromptFormat.function_tag:
-                tool_gen = FunctionTagCustomToolGenerator()
+                custom_tool_gen = FunctionTagCustomToolGenerator()
             else:
                 raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")
 
-            custom_template = tool_gen.gen(custom_tools)
+            custom_template = custom_tool_gen.gen(custom_tools)
             messages.append(RawMessage(role="user", content=custom_template.render()))
 
         return messages
@@ -212,7 +215,7 @@ class LLama31Interface:
 
     def display_message_as_tokens(self, message: RawMessage) -> None:
         """Util to print tokenized string to shell"""
-        tokens = self.formatter.encode_message(message, self.tool_prompt_format)
+        tokens, _ = self.formatter.encode_message(message, self.tool_prompt_format)
         on_colors = [
             "on_red",
             "on_green",
@@ -251,5 +254,5 @@ def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
     tokens = interface.get_tokens(messages)
     special_tokens = list(interface.tokenizer.special_tokens.values())
 
-    tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
-    return template, tokens
+    token_tuples = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
+    return template, token_tuples
diff --git a/llama_stack/models/llama/llama3/tokenizer.py b/llama_stack/models/llama/llama3/tokenizer.py
index e47b579e3..1c272e9fa 100644
--- a/llama_stack/models/llama/llama3/tokenizer.py
+++ b/llama_stack/models/llama/llama3/tokenizer.py
@@ -171,7 +171,7 @@ class Tokenizer:
             str: The decoded string.
         """
         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
-        return self.model.decode(cast(list[int], t))
+        return cast(str, self.model.decode(cast(list[int], t)))
 
     @staticmethod
     def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
diff --git a/pyproject.toml b/pyproject.toml
index 72f3a323f..bbe6834e3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -243,8 +243,6 @@ exclude = [
     "^llama_stack/distribution/utils/exec\\.py$",
     "^llama_stack/distribution/utils/prompt_for_config\\.py$",
     "^llama_stack/models/llama/llama3/chat_format\\.py$",
-    "^llama_stack/models/llama/llama3/interface\\.py$",
-    "^llama_stack/models/llama/llama3/tokenizer\\.py$",
     "^llama_stack/models/llama/llama3/tool_utils\\.py$",
     "^llama_stack/providers/inline/agents/meta_reference/",
     "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
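
The diff above consists of type-checking fixes so that interface.py and tokenizer.py can be dropped from the mypy exclude list in pyproject.toml. The sketch below illustrates the two recurring patterns (a union-annotated variable for branch-dependent assignment, and unpacking a tuple return value) in isolation; `JsonGen`, `FunctionTagGen`, and `Formatter` are hypothetical stand-ins, not the real llama_stack classes.

```python
# Minimal sketch of the typing patterns applied in the diff above.
# JsonGen, FunctionTagGen, and Formatter are hypothetical stand-ins,
# not the real llama_stack types.


class JsonGen:
    def gen(self, tools: list[str]) -> str:
        return "json:" + ",".join(tools)


class FunctionTagGen:
    def gen(self, tools: list[str]) -> str:
        return "ftag:" + ",".join(tools)


class Formatter:
    # Assumed to return a (tokens, extra) tuple, mirroring the two-element
    # return value that the diff unpacks from encode_message.
    def encode_message(self, message: str) -> tuple[list[int], str]:
        return [ord(c) for c in message], "eot"


def render(fmt: str, tools: list[str]) -> str:
    # Declaring the variable with a union annotation lets mypy accept either
    # branch; reusing a name already inferred as one concrete class would be
    # flagged as an incompatible assignment.
    gen: JsonGen | FunctionTagGen
    if fmt == "json":
        gen = JsonGen()
    elif fmt == "function_tag":
        gen = FunctionTagGen()
    else:
        raise ValueError(f"Unsupported format {fmt}")
    return gen.gen(tools)


def token_ids(formatter: Formatter, message: str) -> list[int]:
    # encode_message returns a tuple, so unpack it instead of treating the
    # whole return value as the token list.
    tokens, _ = formatter.encode_message(message)
    return tokens


if __name__ == "__main__":
    print(render("json", ["get_weather"]))  # json:get_weather
    print(token_ids(Formatter(), "hi"))     # [104, 105]
```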