chore(api): add mypy coverage to interface

Author: Mustafa Elbehery <melbeher@redhat.com>
Date:   2025-07-08 19:56:32 +02:00
Commit: c67097ffeb (parent: 51b179e1c5)

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>

3 changed files with 11 additions and 10 deletions

llama_stack/models/llama/llama3/interface.py

@@ -162,7 +162,9 @@ class LLama31Interface:
         tool_template = None
         if builtin_tools or custom_tools:
             tool_gen = BuiltinToolGenerator()
-            tool_template = tool_gen.gen(builtin_tools + custom_tools)
+            # Convert BuiltinTool to ToolDefinition for the gen method
+            builtin_tool_defs = [ToolDefinition(tool_name=tool) for tool in builtin_tools]
+            tool_template = tool_gen.gen(builtin_tool_defs + custom_tools)

             sys_content += tool_template.render()
             sys_content += "\n"
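The hunk above resolves an argument-type error: gen is typed against a list of ToolDefinition, while builtin_tools holds BuiltinTool enum members, so concatenating the two lists mixes element types. A minimal sketch of the mismatch, using simplified stand-ins for the real llama-stack classes:

from dataclasses import dataclass
from enum import Enum


class BuiltinTool(Enum):
    code_interpreter = "code_interpreter"
    brave_search = "brave_search"


@dataclass
class ToolDefinition:
    tool_name: BuiltinTool | str


def gen(tools: list[ToolDefinition]) -> None:
    # Stand-in for BuiltinToolGenerator.gen, typed against ToolDefinition only.
    ...


builtin_tools = [BuiltinTool.code_interpreter]
# gen(builtin_tools)  # mypy: list[BuiltinTool] is not list[ToolDefinition]
gen([ToolDefinition(tool_name=tool) for tool in builtin_tools])  # accepted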
@@ -177,14 +179,15 @@ class LLama31Interface:
         messages.append(RawMessage(role="system", content=sys_content))

         if custom_tools:
+            custom_tool_gen: JsonCustomToolGenerator | FunctionTagCustomToolGenerator
             if self.tool_prompt_format == ToolPromptFormat.json:
-                tool_gen = JsonCustomToolGenerator()
+                custom_tool_gen = JsonCustomToolGenerator()
             elif self.tool_prompt_format == ToolPromptFormat.function_tag:
-                tool_gen = FunctionTagCustomToolGenerator()
+                custom_tool_gen = FunctionTagCustomToolGenerator()
             else:
                 raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")

-            custom_template = tool_gen.gen(custom_tools)
+            custom_template = custom_tool_gen.gen(custom_tools)
             messages.append(RawMessage(role="user", content=custom_template.render()))

         return messages
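Here tool_gen is already bound to a BuiltinToolGenerator earlier in the method, so re-binding it to a custom-tool generator would trip mypy's assignment check; the new custom_tool_gen name carries an explicit union annotation instead. A sketch of the pattern, with empty stand-in classes:

class JsonCustomToolGenerator: ...
class FunctionTagCustomToolGenerator: ...


def pick(use_json: bool) -> JsonCustomToolGenerator | FunctionTagCustomToolGenerator:
    # Annotating the variable up front lets each branch assign a different
    # class; otherwise mypy pins the type from the first assignment and
    # rejects the second branch with an [assignment] error.
    custom_tool_gen: JsonCustomToolGenerator | FunctionTagCustomToolGenerator
    if use_json:
        custom_tool_gen = JsonCustomToolGenerator()
    else:
        custom_tool_gen = FunctionTagCustomToolGenerator()
    return custom_tool_gen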
@@ -212,7 +215,7 @@ class LLama31Interface:
     def display_message_as_tokens(self, message: RawMessage) -> None:
         """Util to print tokenized string to shell"""
-        tokens = self.formatter.encode_message(message, self.tool_prompt_format)
+        tokens, _ = self.formatter.encode_message(message, self.tool_prompt_format)
         on_colors = [
             "on_red",
             "on_green",
@@ -251,5 +254,5 @@ def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
     tokens = interface.get_tokens(messages)
     special_tokens = list(interface.tokenizer.special_tokens.values())
-    tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
-    return template, tokens
+    token_tuples = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
+    return template, token_tuples
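Re-binding tokens from list[int] to a list of (text, is_special) tuples is an [assignment] error under mypy, which fixes a variable's type at its first binding; renaming sidesteps it. Illustrative values:

tokens: list[int] = [128000, 9906]
# tokens = [("<|begin_of_text|>", True)]   # mypy: incompatible types in assignment
token_tuples: list[tuple[str, bool]] = [("<|begin_of_text|>", True)]  # new name, new type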

llama_stack/models/llama/llama3/tokenizer.py

@@ -171,7 +171,7 @@ class Tokenizer:
             str: The decoded string.
         """
         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
-        return self.model.decode(cast(list[int], t))
+        return cast(str, self.model.decode(cast(list[int], t)))

     @staticmethod
     def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:

pyproject.toml

@@ -243,8 +243,6 @@ exclude = [
     "^llama_stack/distribution/utils/exec\\.py$",
     "^llama_stack/distribution/utils/prompt_for_config\\.py$",
     "^llama_stack/models/llama/llama3/chat_format\\.py$",
-    "^llama_stack/models/llama/llama3/interface\\.py$",
-    "^llama_stack/models/llama/llama3/tokenizer\\.py$",
     "^llama_stack/models/llama/llama3/tool_utils\\.py$",
     "^llama_stack/providers/inline/agents/meta_reference/",
     "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",