mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:42:25 +00:00
chore(api): add mypy coverage to interface
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
51b179e1c5
commit
c67097ffeb
3 changed files with 11 additions and 10 deletions
|
|
@ -162,7 +162,9 @@ class LLama31Interface:
|
|||
tool_template = None
|
||||
if builtin_tools or custom_tools:
|
||||
tool_gen = BuiltinToolGenerator()
|
||||
tool_template = tool_gen.gen(builtin_tools + custom_tools)
|
||||
# Convert BuiltinTool to ToolDefinition for the gen method
|
||||
builtin_tool_defs = [ToolDefinition(tool_name=tool) for tool in builtin_tools]
|
||||
tool_template = tool_gen.gen(builtin_tool_defs + custom_tools)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
|
@ -177,14 +179,15 @@ class LLama31Interface:
|
|||
messages.append(RawMessage(role="system", content=sys_content))
|
||||
|
||||
if custom_tools:
|
||||
custom_tool_gen: JsonCustomToolGenerator | FunctionTagCustomToolGenerator
|
||||
if self.tool_prompt_format == ToolPromptFormat.json:
|
||||
tool_gen = JsonCustomToolGenerator()
|
||||
custom_tool_gen = JsonCustomToolGenerator()
|
||||
elif self.tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
tool_gen = FunctionTagCustomToolGenerator()
|
||||
custom_tool_gen = FunctionTagCustomToolGenerator()
|
||||
else:
|
||||
raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")
|
||||
|
||||
custom_template = tool_gen.gen(custom_tools)
|
||||
custom_template = custom_tool_gen.gen(custom_tools)
|
||||
messages.append(RawMessage(role="user", content=custom_template.render()))
|
||||
|
||||
return messages
|
||||
|
|
@ -212,7 +215,7 @@ class LLama31Interface:
|
|||
|
||||
def display_message_as_tokens(self, message: RawMessage) -> None:
|
||||
"""Util to print tokenized string to shell"""
|
||||
tokens = self.formatter.encode_message(message, self.tool_prompt_format)
|
||||
tokens, _ = self.formatter.encode_message(message, self.tool_prompt_format)
|
||||
on_colors = [
|
||||
"on_red",
|
||||
"on_green",
|
||||
|
|
@ -251,5 +254,5 @@ def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
|
|||
|
||||
tokens = interface.get_tokens(messages)
|
||||
special_tokens = list(interface.tokenizer.special_tokens.values())
|
||||
tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
|
||||
return template, tokens
|
||||
token_tuples = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
|
||||
return template, token_tuples
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class Tokenizer:
|
|||
str: The decoded string.
|
||||
"""
|
||||
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
||||
return self.model.decode(cast(list[int], t))
|
||||
return cast(str, self.model.decode(cast(list[int], t)))
|
||||
|
||||
@staticmethod
|
||||
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
|
||||
|
|
|
|||
|
|
@ -243,8 +243,6 @@ exclude = [
|
|||
"^llama_stack/distribution/utils/exec\\.py$",
|
||||
"^llama_stack/distribution/utils/prompt_for_config\\.py$",
|
||||
"^llama_stack/models/llama/llama3/chat_format\\.py$",
|
||||
"^llama_stack/models/llama/llama3/interface\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue