Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-07-29 15:23:51 +00:00

Merge remote-tracking branch 'origin/main' into RFC-0001-The-Llama-Stack

Commit 75bbe787b6
35 changed files with 309 additions and 90 deletions
@@ -110,6 +110,35 @@ class Session(BaseModel):
     started_at: datetime
 
 
+@json_schema_type
+class ToolPromptFormat(Enum):
+    """This Enum refers to the prompt format for calling zero shot tools
+
+    `json` --
+    Refers to the json format for calling tools.
+    The json format takes the form like
+    {
+        "type": "function",
+        "function" : {
+            "name": "function_name",
+            "description": "function_description",
+            "parameters": {...}
+        }
+    }
+
+    `function_tag` --
+    This is an example of how you could define
+    your own user defined format for making tool calls.
+    The function_tag format looks like this,
+    <function=function_name>(parameters)</function>
+
+    The detailed prompts for each of these formats are defined in `system_prompt.py`
+    """
+
+    json = "json"
+    function_tag = "function_tag"
+
+
 @json_schema_type
 class AgenticSystemInstanceConfig(BaseModel):
     instructions: str

@@ -127,6 +156,9 @@ class AgenticSystemInstanceConfig(BaseModel):
     # if you completely want to replace the messages prefixed by the system,
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.json
+    )
 
 
 class AgenticSystemTurnResponseEventType(Enum):
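As a reading aid, a minimal sketch of what the two ToolPromptFormat styles imply for a tool call, based only on the docstring added above; the tool name and arguments are illustrative placeholders, not part of this commit:

    # Illustrative only: tool-call shapes implied by the ToolPromptFormat docstring.
    # "get_boiling_point" and its arguments are made up for this sketch.
    json_style_call = '''
    {
        "type": "function",
        "function": {
            "name": "get_boiling_point",
            "description": "Get the boiling point of a liquid",
            "parameters": {"liquid_name": "polyjuice"}
        }
    }
    '''
    function_tag_style_call = '<function=get_boiling_point>({"liquid_name": "polyjuice"})</function>'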
@@ -13,8 +13,15 @@ import fire
 
 import httpx
 
-from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams
+from llama_models.llama3.api.datatypes import (
+    BuiltinTool,
+    SamplingParams,
+    ToolParamDefinition,
+    UserMessage,
+)
+from termcolor import cprint
 
+from llama_toolchain.agentic_system.event_logger import EventLogger
 from .api import (
     AgenticSystem,
     AgenticSystemCreateRequest,

@@ -25,6 +32,7 @@ from .api import (
     AgenticSystemToolDefinition,
     AgenticSystemTurnCreateRequest,
     AgenticSystemTurnResponseStreamChunk,
+    ToolPromptFormat,
 )
 
 
@@ -87,7 +95,7 @@ class AgenticSystemClient(AgenticSystem):
 
 async def run_main(host: str, port: int):
     # client to test remote impl of agentic system
-    api = await AgenticSystemClient(f"http://{host}:{port}")
+    api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
         AgenticSystemToolDefinition(

@@ -96,13 +104,28 @@ async def run_main(host: str, port: int):
         AgenticSystemToolDefinition(
             tool_name=BuiltinTool.wolfram_alpha,
         ),
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.photogen,
-        ),
         AgenticSystemToolDefinition(
             tool_name=BuiltinTool.code_interpreter,
         ),
     ]
+    tool_definitions += [
+        AgenticSystemToolDefinition(
+            tool_name="get_boiling_point",
+            description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
+            parameters={
+                "liquid_name": ToolParamDefinition(
+                    param_type="str",
+                    description="The name of the liquid",
+                    required=True,
+                ),
+                "celcius": ToolParamDefinition(
+                    param_type="str",
+                    description="Whether to return the boiling point in Celcius",
+                    required=False,
+                ),
+            },
+        ),
+    ]
 
     create_request = AgenticSystemCreateRequest(
         model="Meta-Llama3.1-8B-Instruct",

@@ -114,12 +137,44 @@ async def run_main(host: str, port: int):
             output_shields=[],
             quantization_config=None,
             debug_prefix_messages=[],
+            tool_prompt_format=ToolPromptFormat.json,
         ),
     )
 
     create_response = await api.create_agentic_system(create_request)
     print(create_response)
-    # TODO: Add chat session / turn apis to test e2e
+
+    session_response = await api.create_agentic_system_session(
+        AgenticSystemSessionCreateRequest(
+            system_id=create_response.system_id,
+            session_name="test_session",
+        )
+    )
+    print(session_response)
+
+    user_prompts = [
+        "Who are you?",
+        "what is the 100th prime number?",
+        "Search web for who was 44th President of USA?",
+        "Write code to check if a number is prime. Use that to check if 7 is prime",
+        "What is the boiling point of polyjuicepotion ?",
+    ]
+    for content in user_prompts:
+        cprint(f"User> {content}", color="blue")
+        iterator = api.create_agentic_system_turn(
+            AgenticSystemTurnCreateRequest(
+                system_id=create_response.system_id,
+                session_id=session_response.session_id,
+                messages=[
+                    UserMessage(content=content),
+                ],
+                stream=True,
+            )
+        )
+
+        async for event, log in EventLogger().log(iterator):
+            if log is not None:
+                log.print()
 
 
 def main(host: str, port: int):
@@ -6,16 +6,16 @@
 
 from typing import Optional
 
-from llama_models.llama3_1.api.datatypes import ToolResponseMessage
-from llama_models.llama3_1.api.tool_utils import ToolUtils
+from llama_models.llama3.api.datatypes import ToolResponseMessage
+from llama_models.llama3.api.tool_utils import ToolUtils
 
+from termcolor import cprint
+
 from llama_toolchain.agentic_system.api import (
     AgenticSystemTurnResponseEventType,
     StepType,
 )
 
-from termcolor import cprint
-
 
 class LogEvent:
     def __init__(
@@ -10,6 +10,8 @@ import uuid
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional
 
+from termcolor import cprint
+
 from llama_toolchain.agentic_system.api.datatypes import (
     AgenticSystemInstanceConfig,
     AgenticSystemTurnResponseEvent,

@@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import (
     ShieldCallStep,
     StepType,
     ToolExecutionStep,
+    ToolPromptFormat,
     Turn,
 )
 

@@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import (
     ShieldDefinition,
     ShieldResponse,
 )
-from termcolor import cprint
 from llama_toolchain.agentic_system.api.endpoints import * # noqa
 
 from .safety import SafetyException, ShieldRunnerMixin

@@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
         output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
         prefix_messages: Optional[List[Message]] = None,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
         self.system_id = system_id
         self.instance_config = instance_config

@@ -86,7 +89,9 @@ class AgentInstance(ShieldRunnerMixin):
             self.prefix_messages = prefix_messages
         else:
             self.prefix_messages = get_agentic_prefix_messages(
-                builtin_tools, custom_tool_definitions
+                builtin_tools,
+                custom_tool_definitions,
+                tool_prompt_format,
             )
 
         for m in self.prefix_messages:
@@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             input_shields=cfg.input_shields,
             output_shields=cfg.output_shields,
             prefix_messages=cfg.debug_prefix_messages,
+            tool_prompt_format=cfg.tool_prompt_format,
         )
 
         return AgenticSystemCreateResponse(
@@ -6,14 +6,15 @@
 
 from typing import List
 
-from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3.api.datatypes import Message, Role, UserMessage
+from termcolor import cprint
 
 from llama_toolchain.safety.api.datatypes import (
     OnViolationAction,
     ShieldDefinition,
     ShieldResponse,
 )
 from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
-from termcolor import cprint
 
 
 class SafetyException(Exception): # noqa: N818

@@ -36,12 +37,11 @@ class ShieldRunnerMixin:
     async def run_shields(
         self, messages: List[Message], shields: List[ShieldDefinition]
     ) -> List[ShieldResponse]:
+        messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
         if len(messages) > 0 and messages[0].role != Role.user.value:
-            # TODO(ashwin): we need to change the type of the message, this kind of modification
-            # is no longer appropriate
-            messages[0].role = Role.user.value
+            messages[0] = UserMessage(content=messages[0].content)
 
         res = await self.safety_api.run_shields(
             RunShieldRequest(
@@ -5,21 +5,27 @@
 # the root directory of this source tree.
 
 import json
+import textwrap
 from datetime import datetime
 from typing import List
 
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
+
 from llama_toolchain.inference.api import (
     BuiltinTool,
     Message,
     SystemMessage,
     ToolDefinition,
+    UserMessage,
 )
 
 from .tools.builtin import SingleMessageBuiltinTool
 
 
 def get_agentic_prefix_messages(
-    builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
+    builtin_tools: List[SingleMessageBuiltinTool],
+    custom_tools: List[ToolDefinition],
+    tool_prompt_format: ToolPromptFormat,
 ) -> List[Message]:
     messages = []
     content = ""

@@ -34,28 +40,52 @@ def get_agentic_prefix_messages(
         ]
     )
     if tool_str:
-        content += f"Tools: {tool_str}\n"
+        content += f"Tools: {tool_str}"
 
     current_date = datetime.now()
     formatted_date = current_date.strftime("%d %B %Y")
     date_str = f"""
 Cutting Knowledge Date: December 2023
-Today Date: {formatted_date}\n\n"""
+Today Date: {formatted_date}\n"""
     content += date_str
+    messages.append(SystemMessage(content=content))
+
     if custom_tools:
-        custom_message = get_system_prompt_for_custom_tools(custom_tools)
-        content += custom_message
+        if tool_prompt_format == ToolPromptFormat.function_tag:
+            text = prompt_for_function_tag(custom_tools)
+            messages.append(UserMessage(content=text))
+        elif tool_prompt_format == ToolPromptFormat.json:
+            text = prompt_for_json(custom_tools)
+            messages.append(UserMessage(content=text))
+        else:
+            raise NotImplementedError(
+                f"Tool prompt format {tool_prompt_format} is not supported"
+            )
+    else:
+        messages.append(SystemMessage(content=content))
 
-    # TODO: Replace this hard coded message with instructions coming in the request
-    if False:
-        content += "You are a helpful Assistant."
-
-    messages.append(SystemMessage(content=content))
     return messages
 
 
-def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
+def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
+    tool_defs = "\n".join(
+        translate_custom_tool_definition_to_json(t) for t in custom_tools
+    )
+    content = textwrap.dedent(
+        """
+        Answer the user's question by making use of the following functions if needed.
+        If none of the function can be used, please say so.
+        Here is a list of functions in JSON format:
+        {tool_defs}
+
+        Return function calls in JSON format.
+        """
+    )
+    content = content.lstrip("\n").format(tool_defs=tool_defs)
+    return content
+
+
+def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
     custom_tool_params = ""
     for t in custom_tools:
         custom_tool_params += get_instruction_string(t) + "\n"

@@ -76,7 +106,6 @@ Reminder:
 - Required parameters MUST be specified
 - Only call one function at a time
 - Put the entire function call reply on one line
-
 """
     return content
 

@@ -98,7 +127,6 @@ def get_parameters_string(custom_tool_definition) -> str:
     )
 
 
-# NOTE: Unused right now
 def translate_custom_tool_definition_to_json(tool_def):
     """Translates ToolDefinition to json as expected by model
     eg. output for a function

@@ -149,4 +177,4 @@ def translate_custom_tool_definition_to_json(tool_def):
     else:
         func_def["function"]["parameters"] = {}
 
-    return json.dumps(func_def)
+    return json.dumps(func_def, indent=4)
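Going by the textwrap template above, the prompt that prompt_for_json renders for a single custom tool has roughly this shape; the JSON line stands in for whatever translate_custom_tool_definition_to_json emits, which these hunks only show in part:

    Answer the user's question by making use of the following functions if needed.
    If none of the function can be used, please say so.
    Here is a list of functions in JSON format:
    {"type": "function", "function": {"name": "get_boiling_point", "description": "...", "parameters": {...}}}

    Return function calls in JSON format.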
@@ -9,7 +9,7 @@ import json
 from abc import abstractmethod
 from typing import Dict, List
 
-from llama_models.llama3_1.api.datatypes import * # noqa: F403
+from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_toolchain.agentic_system.api import * # noqa: F403
 
 # TODO: this is symptomatic of us needing to pull more tooling related utilities
@@ -6,7 +6,7 @@
 
 from typing import Any, AsyncGenerator, List
 
-from llama_models.llama3_1.api.datatypes import StopReason, ToolResponseMessage
+from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage
 
 from llama_toolchain.agentic_system.api import (
     AgenticSystem,
@@ -7,7 +7,7 @@
 import uuid
 from typing import Any, List, Optional
 
-from llama_models.llama3_1.api.datatypes import BuiltinTool, Message, SamplingParams
+from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams
 
 from llama_toolchain.agentic_system.api import (
     AgenticSystemCreateRequest,

@@ -15,6 +15,7 @@ from llama_toolchain.agentic_system.api import (
     AgenticSystemSessionCreateRequest,
     AgenticSystemToolDefinition,
 )
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
 from llama_toolchain.agentic_system.client import AgenticSystemClient
 
 from llama_toolchain.agentic_system.tools.custom.execute import (

@@ -64,6 +65,7 @@ async def get_agent_system_instance(
     custom_tools: Optional[List[Any]] = None,
     disable_safety: bool = False,
     model: str = "Meta-Llama3.1-8B-Instruct",
+    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
 ) -> AgenticSystemClientWrapper:
     custom_tools = custom_tools or []
 

@@ -113,6 +115,7 @@ async def get_agent_system_instance(
                 ]
             ),
             sampling_params=SamplingParams(),
+            tool_prompt_format=tool_prompt_format,
         ),
     )
     create_response = await api.create_agentic_system(create_request)
@@ -6,18 +6,22 @@
 
 import argparse
 import asyncio
+import json
 import os
 import shutil
 import time
+from datetime import datetime
 from functools import partial
 from pathlib import Path
+from typing import Dict, List
 
 import httpx
-from llama_toolchain.cli.subcommand import Subcommand
+from pydantic import BaseModel
 
 from termcolor import cprint
 
+from llama_toolchain.cli.subcommand import Subcommand
+
 
 class Download(Subcommand):
     """Llama cli for downloading llama toolchain assets"""

@@ -45,7 +49,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--model-id",
         choices=[x.descriptor() for x in models],
-        required=True,
+        required=False,
     )
     parser.add_argument(
         "--hf-token",

@@ -70,6 +74,12 @@ For source=huggingface, files matching any of the patterns are not downloaded. D
 safetensors files to avoid downloading duplicate weights.
 """,
     )
+    parser.add_argument(
+        "--manifest-file",
+        type=str,
+        help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
+        required=False,
+    )
     parser.set_defaults(func=partial(run_download_cmd, parser=parser))
 
 

@@ -88,7 +98,7 @@ def _hf_download(
     if repo_id is None:
         raise ValueError(f"No repo id found for model {model.descriptor()}")
 
-    output_dir = model_local_dir(model)
+    output_dir = model_local_dir(model.descriptor())
     os.makedirs(output_dir, exist_ok=True)
     try:
         true_output_dir = snapshot_download(

@@ -118,7 +128,7 @@ def _meta_download(model: "Model", meta_url: str):
 
     from llama_toolchain.common.model_utils import model_local_dir
 
-    output_dir = Path(model_local_dir(model))
+    output_dir = Path(model_local_dir(model.descriptor()))
     os.makedirs(output_dir, exist_ok=True)
 
     info = llama_meta_net_info(model)

@@ -139,6 +149,14 @@ def _meta_download(model: "Model", meta_url: str):
 def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     from llama_models.sku_list import resolve_model
 
+    if args.manifest_file:
+        _download_from_manifest(args.manifest_file)
+        return
+
+    if args.model_id is None:
+        parser.error("Please provide a model id")
+        return
+
     model = resolve_model(args.model_id)
     if model is None:
         parser.error(f"Model {args.model_id} not found")

@@ -156,6 +174,54 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
         _meta_download(model, meta_url)
 
 
+class ModelEntry(BaseModel):
+    model_id: str
+    files: Dict[str, str]
+
+
+class Manifest(BaseModel):
+    models: List[ModelEntry]
+    expires_on: datetime
+
+
+def _download_from_manifest(manifest_file: str):
+    from llama_toolchain.common.model_utils import model_local_dir
+
+    with open(manifest_file, "r") as f:
+        d = json.load(f)
+        manifest = Manifest(**d)
+
+    if datetime.now() > manifest.expires_on:
+        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
+
+    for entry in manifest.models:
+        print(f"Downloading model {entry.model_id}...")
+        output_dir = Path(model_local_dir(entry.model_id))
+        os.makedirs(output_dir, exist_ok=True)
+
+        if any(output_dir.iterdir()):
+            cprint(f"Output directory {output_dir} is not empty.", "red")
+
+            while True:
+                resp = input(
+                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
+                )
+                if resp.lower() == "restart" or resp.lower() == "r":
+                    shutil.rmtree(output_dir)
+                    os.makedirs(output_dir, exist_ok=True)
+                    break
+                elif resp.lower() == "continue" or resp.lower() == "c":
+                    print("Continuing download...")
+                    break
+                else:
+                    cprint("Invalid response. Please try again.", "red")
+
+        for fname, url in entry.files.items():
+            output_file = str(output_dir / fname)
+            downloader = ResumableDownloader(url, output_file)
+            asyncio.run(downloader.download())
+
+
 class ResumableDownloader:
     def __init__(
         self,
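For reference, a manifest accepted by the new --manifest-file path would, per the Manifest and ModelEntry models above, look roughly like the following; the model id, file names, URLs and expiry date are placeholders, not values from this commit:

    # Hypothetical manifest contents matching the Manifest/ModelEntry models above.
    # Pydantic parses the ISO-format string into the expires_on datetime field.
    example_manifest = {
        "models": [
            {
                "model_id": "Meta-Llama3.1-8B-Instruct",
                "files": {
                    "consolidated.00.pth": "https://example.com/signed-url-1",
                    "tokenizer.model": "https://example.com/signed-url-2",
                },
            }
        ],
        "expires_on": "2024-08-15T00:00:00",
    }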
@@ -190,7 +256,7 @@ class ResumableDownloader:
 
     async def download(self) -> None:
         self.start_time = time.time()
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(follow_redirects=True) as client:
             await self.get_file_info(client)
 
             if os.path.exists(self.output_file):

@@ -222,7 +288,7 @@ class ResumableDownloader:
         headers = {
             "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
         }
-        # print(f"Downloading `{self.output_file}`....{headers}")
+        print(f"Downloading `{self.output_file}`....{headers}")
         try:
             async with client.stream(
                 "GET", self.url, headers=headers
@@ -7,10 +7,10 @@
 import argparse
 import textwrap
 
-from llama_toolchain.cli.subcommand import Subcommand
-
 from termcolor import colored
 
+from llama_toolchain.cli.subcommand import Subcommand
+
 
 class ModelTemplate(Subcommand):
     """Llama model cli for describe a model template (message formats)"""

@@ -48,10 +48,11 @@ class ModelTemplate(Subcommand):
         )
 
     def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
-        from llama_models.llama3_1.api.interface import (
+        from llama_models.llama3.api.interface import (
             list_jinja_templates,
             render_jinja_template,
         )
+
         from llama_toolchain.cli.table import print_table
 
         if args.name:
@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import Dict, Optional
 
-from llama_models.llama3_1.api.datatypes import URL
+from llama_models.llama3.api.datatypes import URL
 
 from llama_models.schema_utils import json_schema_type
 
@@ -1,9 +1,13 @@
-import os
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
 
-from llama_models.datatypes import Model
 
 from .config_dirs import DEFAULT_CHECKPOINT_DIR
 
 
-def model_local_dir(model: Model) -> str:
-    return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())
+def model_local_dir(descriptor: str) -> str:
+    return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_models.llama3_1.api.datatypes import URL
+from llama_models.llama3.api.datatypes import URL
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
-from llama_models.llama3_1.api.datatypes import URL
+from llama_models.llama3.api.datatypes import URL
 
 from llama_models.schema_utils import json_schema_type
 
@@ -10,7 +10,7 @@ from llama_models.schema_utils import webmethod
 
 from pydantic import BaseModel
 
-from llama_models.llama3_1.api.datatypes import * # noqa: F403
+from llama_models.llama3.api.datatypes import * # noqa: F403
 from .datatypes import * # noqa: F403
 from llama_toolchain.dataset.api.datatypes import * # noqa: F403
 from llama_toolchain.common.training_types import * # noqa: F403
@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
-from llama_models.llama3_1.api.datatypes import * # noqa: F403
+from llama_models.llama3.api.datatypes import * # noqa: F403
 
 
 class LogProbConfig(BaseModel):
@@ -22,22 +22,22 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from llama_models.llama3_1.api.args import ModelArgs
-from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3_1.api.datatypes import Message
-from llama_models.llama3_1.api.tokenizer import Tokenizer
-from llama_models.llama3_1.reference_impl.model import Transformer
+from llama_models.llama3.api.args import ModelArgs
+from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
+from termcolor import cprint
 
 from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType
-from termcolor import cprint
 
 from .config import MetaReferenceImplConfig
 
 
 def model_checkpoint_dir(model) -> str:
-    checkpoint_dir = Path(model_local_dir(model))
+    checkpoint_dir = Path(model_local_dir(model.descriptor()))
     if not Path(checkpoint_dir / "consolidated.00.pth").exists():
         checkpoint_dir = checkpoint_dir / "original"
 
@@ -8,7 +8,7 @@ import asyncio
 
 from typing import AsyncIterator, Dict, Union
 
-from llama_models.llama3_1.api.datatypes import StopReason
+from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model
 
 from llama_toolchain.distribution.datatypes import Api, ProviderSpec
@@ -10,9 +10,9 @@ from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
 
-from llama_models.llama3_1.api.chat_format import ChatFormat
-from llama_models.llama3_1.api.datatypes import Message
-from llama_models.llama3_1.api.tokenizer import Tokenizer
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
 from .config import MetaReferenceImplConfig
@@ -9,15 +9,17 @@ from typing import AsyncGenerator, Dict
 
 import httpx
 
-from llama_models.llama3_1.api.datatypes import (
+from llama_models.llama3.api.datatypes import (
     BuiltinTool,
     CompletionMessage,
     Message,
     StopReason,
     ToolCall,
 )
-from llama_models.llama3_1.api.tool_utils import ToolUtils
+from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
+from ollama import AsyncClient
 
 from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,

@@ -30,7 +32,6 @@ from llama_toolchain.inference.api import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
-from ollama import AsyncClient
 
 from .config import OllamaImplConfig
 

@@ -64,10 +65,10 @@ class OllamaInference(Inference):
     async def initialize(self) -> None:
         try:
             await self.client.ps()
-        except httpx.ConnectError:
+        except httpx.ConnectError as e:
             raise RuntimeError(
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
-            )
+            ) from e
 
     async def shutdown(self) -> None:
         pass
@@ -13,7 +13,7 @@ from typing import Optional
 import torch
 
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.llama3_1.api.model import Transformer, TransformerBlock
+from llama_models.llama3.api.model import Transformer, TransformerBlock
 
 from llama_toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 
 from pydantic import BaseModel, Field
 
-from llama_models.llama3_1.api.datatypes import * # noqa: F403
+from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_toolchain.dataset.api.datatypes import * # noqa: F403
 from llama_toolchain.common.training_types import * # noqa: F403
 from .datatypes import * # noqa: F403
@@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type
 
 from pydantic import BaseModel
 
-from llama_models.llama3_1.api.datatypes import * # noqa: F403
+from llama_models.llama3.api.datatypes import * # noqa: F403
 
 
 @json_schema_type
@@ -7,13 +7,12 @@
 from enum import Enum
 from typing import Dict, Optional, Union
 
-from llama_models.llama3_1.api.datatypes import ToolParamDefinition
+from llama_models.llama3.api.datatypes import ToolParamDefinition
 
 from llama_models.schema_utils import json_schema_type
 
-from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
-
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
 
+from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
 
 
 @json_schema_type

@@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
     on_violation_action: OnViolationAction = OnViolationAction.RAISE
     execution_config: Optional[RestAPIExecutionConfig] = None
 
+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v
+
 
 @json_schema_type
 class ShieldResponse(BaseModel):

@@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
     is_violation: bool
     violation_type: Optional[str] = None
     violation_return_message: Optional[str] = None
+
+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v
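The two identical validators added above appear intended to let shield_type round-trip through JSON: a plain string that names a builtin shield is coerced back onto the BuiltinShield enum, while any other string is passed through so user-defined shield types keep working. A hedged restatement of that coercion as a standalone helper, for readers skimming the diff (illustrative only, not part of the commit):

    # Sketch of the coercion the shield_type validators perform; assumes BuiltinShield
    # is the string-valued Enum defined elsewhere in this module.
    def coerce_shield_type(v):
        if isinstance(v, str):
            try:
                return BuiltinShield(v)   # e.g. "llama_guard" -> BuiltinShield.llama_guard
            except ValueError:
                return v                  # custom shield types stay plain strings
        return v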
@@ -7,7 +7,7 @@
 from .datatypes import * # noqa: F403
 from typing import List, Protocol
 
-from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message
 
 # this dependency is annoying and we need a forked up version anyway
 from llama_models.schema_utils import webmethod
@@ -9,7 +9,7 @@ import asyncio
 import fire
 import httpx
 
-from llama_models.llama3_1.api.datatypes import UserMessage
+from llama_models.llama3.api.datatypes import UserMessage
 from termcolor import cprint
 
 from .api import (
@@ -36,7 +36,7 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]
 def resolve_and_get_path(model_name: str) -> str:
     model = resolve_model(model_name)
     assert model is not None, f"Could not resolve model {model_name}"
-    model_dir = model_local_dir(model)
+    model_dir = model_local_dir(model.descriptor())
     return model_dir
 
 

@@ -73,30 +73,34 @@ class MetaReferenceSafetyImpl(Safety):
         return RunShieldResponse(responses=responses)
 
 
+def shield_type_equals(a: ShieldType, b: ShieldType):
+    return a == b or a == b.value
+
+
 def shield_config_to_shield(
     sc: ShieldDefinition, safety_config: SafetyConfig
 ) -> ShieldBase:
-    if sc.shield_type == BuiltinShield.llama_guard:
+    if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
         assert (
             safety_config.llama_guard_shield is not None
         ), "Cannot use LlamaGuardShield since not present in config"
         model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
         return LlamaGuardShield.instance(model_dir=model_dir)
-    elif sc.shield_type == BuiltinShield.jailbreak_shield:
+    elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
         assert (
             safety_config.prompt_guard_shield is not None
         ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
         model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
         return JailbreakShield.instance(model_dir)
-    elif sc.shield_type == BuiltinShield.injection_shield:
+    elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
         assert (
             safety_config.prompt_guard_shield is not None
         ), "Cannot use PromptGuardShield since not present in config"
         model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
         return InjectionShield.instance(model_dir)
-    elif sc.shield_type == BuiltinShield.code_scanner_guard:
+    elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
         return CodeScannerShield.instance()
-    elif sc.shield_type == BuiltinShield.third_party_shield:
+    elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
         return ThirdPartyShield.instance()
     else:
         raise ValueError(f"Unknown shield type: {sc.shield_type}")
@@ -7,7 +7,7 @@
 from abc import ABC, abstractmethod
 from typing import List, Union
 
-from llama_models.llama3_1.api.datatypes import Attachment, Message
+from llama_models.llama3.api.datatypes import Attachment, Message
 from llama_toolchain.safety.api.datatypes import * # noqa: F403
 
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
@@ -6,7 +6,7 @@
 
 from typing import List
 
-from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message
 
 from llama_toolchain.safety.meta_reference.shields.base import (
     OnViolationAction,
@@ -10,7 +10,7 @@ from string import Template
 from typing import List, Optional
 
 import torch
-from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -9,7 +9,7 @@ from typing import List
 
 import torch
 
-from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
@@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 
 from pydantic import BaseModel
 
-from llama_models.llama3_1.api.datatypes import * # noqa: F403
+from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403
 from .datatypes import * # noqa: F403
 
setup.py

@@ -16,7 +16,7 @@ def read_requirements():
 
 setup(
     name="llama_toolchain",
-    version="0.0.5",
+    version="0.0.8",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama toolchain",