chore(api): add mypy coverage to interface

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-08 19:56:32 +02:00
parent 51b179e1c5
commit c67097ffeb
3 changed files with 11 additions and 10 deletions

View file

@ -162,7 +162,9 @@ class LLama31Interface:
tool_template = None
if builtin_tools or custom_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools + custom_tools)
# Convert BuiltinTool to ToolDefinition for the gen method
builtin_tool_defs = [ToolDefinition(tool_name=tool) for tool in builtin_tools]
tool_template = tool_gen.gen(builtin_tool_defs + custom_tools)
sys_content += tool_template.render()
sys_content += "\n"
@ -177,14 +179,15 @@ class LLama31Interface:
messages.append(RawMessage(role="system", content=sys_content))
if custom_tools:
custom_tool_gen: JsonCustomToolGenerator | FunctionTagCustomToolGenerator
if self.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
custom_tool_gen = JsonCustomToolGenerator()
elif self.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
custom_tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")
custom_template = tool_gen.gen(custom_tools)
custom_template = custom_tool_gen.gen(custom_tools)
messages.append(RawMessage(role="user", content=custom_template.render()))
return messages
@ -212,7 +215,7 @@ class LLama31Interface:
def display_message_as_tokens(self, message: RawMessage) -> None:
"""Util to print tokenized string to shell"""
tokens = self.formatter.encode_message(message, self.tool_prompt_format)
tokens, _ = self.formatter.encode_message(message, self.tool_prompt_format)
on_colors = [
"on_red",
"on_green",
@ -251,5 +254,5 @@ def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
tokens = interface.get_tokens(messages)
special_tokens = list(interface.tokenizer.special_tokens.values())
tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
return template, tokens
token_tuples = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
return template, token_tuples

View file

@ -171,7 +171,7 @@ class Tokenizer:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(list[int], t))
return cast(str, self.model.decode(cast(list[int], t)))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:

View file

@ -243,8 +243,6 @@ exclude = [
"^llama_stack/distribution/utils/exec\\.py$",
"^llama_stack/distribution/utils/prompt_for_config\\.py$",
"^llama_stack/models/llama/llama3/chat_format\\.py$",
"^llama_stack/models/llama/llama3/interface\\.py$",
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",