From b8fc4d4deefa59ffed04af9901295cb78af8a7ee Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 15 Aug 2024 13:23:51 -0700 Subject: [PATCH 1/8] Updates to prompt for tool calls (#29) * update system prompts to drop new line * Add tool prompt formats * support json format * JSON in caps * function_tag system prompt is also added as a user message * added docstrings for ToolPromptFormat --------- Co-authored-by: Hardik Shah --- .../agentic_system/api/datatypes.py | 32 +++++++++ llama_toolchain/agentic_system/client.py | 67 +++++++++++++++++-- .../meta_reference/agent_instance.py | 9 ++- .../meta_reference/agentic_system.py | 1 + .../agentic_system/meta_reference/safety.py | 10 +-- .../meta_reference/system_prompt.py | 56 ++++++++++++---- llama_toolchain/agentic_system/utils.py | 3 + llama_toolchain/safety/api/datatypes.py | 25 ++++++- 8 files changed, 173 insertions(+), 30 deletions(-) diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py index 1dda64834..db4e40c4b 100644 --- a/llama_toolchain/agentic_system/api/datatypes.py +++ b/llama_toolchain/agentic_system/api/datatypes.py @@ -110,6 +110,35 @@ class Session(BaseModel): started_at: datetime +@json_schema_type +class ToolPromptFormat(Enum): + """This Enum refers to the prompt format for calling zero shot tools + + `json` -- + Refers to the json format for calling tools. + The json format takes the form like + { + "type": "function", + "function" : { + "name": "function_name", + "description": "function_description", + "parameters": {...} + } + } + + `function_tag` -- + This is an example of how you could define + your own user defined format for making tool calls. 
+ The function_tag format looks like this, + <function=function_name>(parameters)</function> + + The detailed prompts for each of these formats are defined in `system_prompt.py` + """ + + json = "json" + function_tag = "function_tag" + + @json_schema_type class AgenticSystemInstanceConfig(BaseModel): instructions: str @@ -127,6 +156,9 @@ class AgenticSystemInstanceConfig(BaseModel): # if you completely want to replace the messages prefixed by the system, # this is debug only debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list) + tool_prompt_format: Optional[ToolPromptFormat] = Field( + default=ToolPromptFormat.json + ) class AgenticSystemTurnResponseEventType(Enum): diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index 71c578e2f..5b8053af9 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -13,8 +13,15 @@ import fire import httpx -from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams +from llama_models.llama3_1.api.datatypes import ( + BuiltinTool, + SamplingParams, + ToolParamDefinition, + UserMessage, +) +from termcolor import cprint +from llama_toolchain.agentic_system.event_logger import EventLogger from .api import ( AgenticSystem, AgenticSystemCreateRequest, @@ -25,6 +32,7 @@ from .api import ( AgenticSystemToolDefinition, AgenticSystemTurnCreateRequest, AgenticSystemTurnResponseStreamChunk, + ToolPromptFormat, ) @@ -87,7 +95,7 @@ class AgenticSystemClient(AgenticSystem): async def run_main(host: str, port: int): # client to test remote impl of agentic system - api = await AgenticSystemClient(f"http://{host}:{port}") + api = AgenticSystemClient(f"http://{host}:{port}") tool_definitions = [ AgenticSystemToolDefinition( @@ -96,13 +104,28 @@ async def run_main(host: str, port: int): AgenticSystemToolDefinition( tool_name=BuiltinTool.wolfram_alpha, ), - AgenticSystemToolDefinition( - tool_name=BuiltinTool.photogen, - ), AgenticSystemToolDefinition( 
tool_name=BuiltinTool.code_interpreter, ), ] + tool_definitions += [ + AgenticSystemToolDefinition( + tool_name="get_boiling_point", + description="Get the boiling point of a imaginary liquids (eg. polyjuice)", + parameters={ + "liquid_name": ToolParamDefinition( + param_type="str", + description="The name of the liquid", + required=True, + ), + "celcius": ToolParamDefinition( + param_type="str", + description="Whether to return the boiling point in Celcius", + required=False, + ), + }, + ), + ] create_request = AgenticSystemCreateRequest( model="Meta-Llama3.1-8B-Instruct", @@ -114,12 +137,44 @@ async def run_main(host: str, port: int): output_shields=[], quantization_config=None, debug_prefix_messages=[], + tool_prompt_format=ToolPromptFormat.json, ), ) create_response = await api.create_agentic_system(create_request) print(create_response) - # TODO: Add chat session / turn apis to test e2e + + session_response = await api.create_agentic_system_session( + AgenticSystemSessionCreateRequest( + system_id=create_response.system_id, + session_name="test_session", + ) + ) + print(session_response) + + user_prompts = [ + "Who are you?", + "what is the 100th prime number?", + "Search web for who was 44th President of USA?", + "Write code to check if a number is prime. 
Use that to check if 7 is prime", + "What is the boiling point of polyjuicepotion ?", + ] + for content in user_prompts: + cprint(f"User> {content}", color="blue") + iterator = api.create_agentic_system_turn( + AgenticSystemTurnCreateRequest( + system_id=create_response.system_id, + session_id=session_response.session_id, + messages=[ + UserMessage(content=content), + ], + stream=True, + ) + ) + + async for event, log in EventLogger().log(iterator): + if log is not None: + log.print() def main(host: str, port: int): diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 8e4555cb4..5be9f8bb6 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -10,6 +10,8 @@ import uuid from datetime import datetime from typing import AsyncGenerator, List, Optional +from termcolor import cprint + from llama_toolchain.agentic_system.api.datatypes import ( AgenticSystemInstanceConfig, AgenticSystemTurnResponseEvent, @@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import ( ShieldCallStep, StepType, ToolExecutionStep, + ToolPromptFormat, Turn, ) @@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import ( ShieldDefinition, ShieldResponse, ) -from termcolor import cprint from llama_toolchain.agentic_system.api.endpoints import * # noqa from .safety import SafetyException, ShieldRunnerMixin @@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin): output_shields: List[ShieldDefinition], max_infer_iters: int = 10, prefix_messages: Optional[List[Message]] = None, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, ): self.system_id = system_id self.instance_config = instance_config @@ -86,7 +89,9 @@ class AgentInstance(ShieldRunnerMixin): self.prefix_messages = prefix_messages else: self.prefix_messages = get_agentic_prefix_messages( - builtin_tools, 
custom_tool_definitions + builtin_tools, + custom_tool_definitions, + tool_prompt_format, ) for m in self.prefix_messages: diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py index 5db8d6168..ae1d282aa 100644 --- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py +++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py @@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem): input_shields=cfg.input_shields, output_shields=cfg.output_shields, prefix_messages=cfg.debug_prefix_messages, + tool_prompt_format=cfg.tool_prompt_format, ) return AgenticSystemCreateResponse( diff --git a/llama_toolchain/agentic_system/meta_reference/safety.py b/llama_toolchain/agentic_system/meta_reference/safety.py index c78cb3028..ff3633f18 100644 --- a/llama_toolchain/agentic_system/meta_reference/safety.py +++ b/llama_toolchain/agentic_system/meta_reference/safety.py @@ -6,14 +6,15 @@ from typing import List -from llama_models.llama3_1.api.datatypes import Message, Role +from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage +from termcolor import cprint + from llama_toolchain.safety.api.datatypes import ( OnViolationAction, ShieldDefinition, ShieldResponse, ) from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety -from termcolor import cprint class SafetyException(Exception): # noqa: N818 @@ -36,12 +37,11 @@ class ShieldRunnerMixin: async def run_shields( self, messages: List[Message], shields: List[ShieldDefinition] ) -> List[ShieldResponse]: + messages = messages.copy() # some shields like llama-guard require the first message to be a user message # since this might be a tool call, first role might not be user if len(messages) > 0 and messages[0].role != Role.user.value: - # TODO(ashwin): we need to change the type of the message, this kind of modification - # is no longer appropriate - messages[0].role = 
Role.user.value + messages[0] = UserMessage(content=messages[0].content) res = await self.safety_api.run_shields( RunShieldRequest( diff --git a/llama_toolchain/agentic_system/meta_reference/system_prompt.py b/llama_toolchain/agentic_system/meta_reference/system_prompt.py index c8c616285..9db3218c1 100644 --- a/llama_toolchain/agentic_system/meta_reference/system_prompt.py +++ b/llama_toolchain/agentic_system/meta_reference/system_prompt.py @@ -5,21 +5,27 @@ # the root directory of this source tree. import json +import textwrap from datetime import datetime from typing import List +from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat + from llama_toolchain.inference.api import ( BuiltinTool, Message, SystemMessage, ToolDefinition, + UserMessage, ) from .tools.builtin import SingleMessageBuiltinTool def get_agentic_prefix_messages( - builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition] + builtin_tools: List[SingleMessageBuiltinTool], + custom_tools: List[ToolDefinition], + tool_prompt_format: ToolPromptFormat, ) -> List[Message]: messages = [] content = "" @@ -34,28 +40,52 @@ def get_agentic_prefix_messages( ] ) if tool_str: - content += f"Tools: {tool_str}\n" + content += f"Tools: {tool_str}" current_date = datetime.now() formatted_date = current_date.strftime("%d %B %Y") date_str = f""" Cutting Knowledge Date: December 2023 -Today Date: {formatted_date}\n\n""" +Today Date: {formatted_date}\n""" content += date_str + messages.append(SystemMessage(content=content)) if custom_tools: - custom_message = get_system_prompt_for_custom_tools(custom_tools) - content += custom_message + if tool_prompt_format == ToolPromptFormat.function_tag: + text = prompt_for_function_tag(custom_tools) + messages.append(UserMessage(content=text)) + elif tool_prompt_format == ToolPromptFormat.json: + text = prompt_for_json(custom_tools) + messages.append(UserMessage(content=text)) + else: + raise NotImplementedError( + f"Tool prompt format 
{tool_prompt_format} is not supported" + ) + else: + messages.append(SystemMessage(content=content)) - # TODO: Replace this hard coded message with instructions coming in the request - if False: - content += "You are a helpful Assistant." - - messages.append(SystemMessage(content=content)) return messages -def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str: +def prompt_for_json(custom_tools: List[ToolDefinition]) -> str: + tool_defs = "\n".join( + translate_custom_tool_definition_to_json(t) for t in custom_tools + ) + content = textwrap.dedent( + """ + Answer the user's question by making use of the following functions if needed. + If none of the function can be used, please say so. + Here is a list of functions in JSON format: + {tool_defs} + + Return function calls in JSON format. + """ + ) + content = content.lstrip("\n").format(tool_defs=tool_defs) + return content + + +def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str: custom_tool_params = "" for t in custom_tools: custom_tool_params += get_instruction_string(t) + "\n" @@ -76,7 +106,6 @@ Reminder: - Required parameters MUST be specified - Only call one function at a time - Put the entire function call reply on one line - """ return content @@ -98,7 +127,6 @@ def get_parameters_string(custom_tool_definition) -> str: ) -# NOTE: Unused right now def translate_custom_tool_definition_to_json(tool_def): """Translates ToolDefinition to json as expected by model eg. 
output for a function @@ -149,4 +177,4 @@ def translate_custom_tool_definition_to_json(tool_def): else: func_def["function"]["parameters"] = {} - return json.dumps(func_def) + return json.dumps(func_def, indent=4) diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py index bc1639b3d..3ae5c67b6 100644 --- a/llama_toolchain/agentic_system/utils.py +++ b/llama_toolchain/agentic_system/utils.py @@ -15,6 +15,7 @@ from llama_toolchain.agentic_system.api import ( AgenticSystemSessionCreateRequest, AgenticSystemToolDefinition, ) +from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat from llama_toolchain.agentic_system.client import AgenticSystemClient from llama_toolchain.agentic_system.tools.custom.execute import ( @@ -64,6 +65,7 @@ async def get_agent_system_instance( custom_tools: Optional[List[Any]] = None, disable_safety: bool = False, model: str = "Meta-Llama3.1-8B-Instruct", + tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, ) -> AgenticSystemClientWrapper: custom_tools = custom_tools or [] @@ -113,6 +115,7 @@ async def get_agent_system_instance( ] ), sampling_params=SamplingParams(), + tool_prompt_format=tool_prompt_format, ), ) create_response = await api.create_agentic_system(create_request) diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/datatypes.py index c5734da99..c0d23f589 100644 --- a/llama_toolchain/safety/api/datatypes.py +++ b/llama_toolchain/safety/api/datatypes.py @@ -8,12 +8,11 @@ from enum import Enum from typing import Dict, Optional, Union from llama_models.llama3_1.api.datatypes import ToolParamDefinition - from llama_models.schema_utils import json_schema_type -from llama_toolchain.common.deployment_types import RestAPIExecutionConfig +from pydantic import BaseModel, validator -from pydantic import BaseModel +from llama_toolchain.common.deployment_types import RestAPIExecutionConfig @json_schema_type @@ -43,6 +42,16 @@ class 
ShieldDefinition(BaseModel): on_violation_action: OnViolationAction = OnViolationAction.RAISE execution_config: Optional[RestAPIExecutionConfig] = None + @validator("shield_type", pre=True) + @classmethod + def validate_field(cls, v): + if isinstance(v, str): + try: + return BuiltinShield(v) + except ValueError: + return v + return v + @json_schema_type class ShieldResponse(BaseModel): @@ -51,3 +60,13 @@ class ShieldResponse(BaseModel): is_violation: bool violation_type: Optional[str] = None violation_return_message: Optional[str] = None + + @validator("shield_type", pre=True) + @classmethod + def validate_field(cls, v): + if isinstance(v, str): + try: + return BuiltinShield(v) + except ValueError: + return v + return v From 5e072d0780f1fea4c0bdb1c15a04c941cbc0d5a1 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 17 Aug 2024 10:08:00 -0700 Subject: [PATCH 2/8] Add a `--manifest-file` option to `llama download` --- llama_toolchain/cli/download.py | 74 +++++++++++++++++-- llama_toolchain/common/model_utils.py | 12 ++- .../inference/meta_reference/generation.py | 4 +- .../safety/meta_reference/safety.py | 2 +- 4 files changed, 78 insertions(+), 14 deletions(-) diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index 2a1c79220..f7365b7b4 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -6,18 +6,22 @@ import argparse import asyncio +import json import os import shutil import time +from datetime import datetime from functools import partial from pathlib import Path +from typing import Dict, List import httpx - -from llama_toolchain.cli.subcommand import Subcommand +from pydantic import BaseModel from termcolor import cprint +from llama_toolchain.cli.subcommand import Subcommand + class Download(Subcommand): """Llama cli for downloading llama toolchain assets""" @@ -45,7 +49,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--model-id", 
choices=[x.descriptor() for x in models], - required=True, + required=False, ) parser.add_argument( "--hf-token", @@ -88,7 +92,7 @@ def _hf_download( if repo_id is None: raise ValueError(f"No repo id found for model {model.descriptor()}") - output_dir = model_local_dir(model) + output_dir = model_local_dir(model.descriptor()) os.makedirs(output_dir, exist_ok=True) try: true_output_dir = snapshot_download( @@ -118,7 +122,7 @@ def _meta_download(model: "Model", meta_url: str): from llama_toolchain.common.model_utils import model_local_dir - output_dir = Path(model_local_dir(model)) + output_dir = Path(model_local_dir(model.descriptor())) os.makedirs(output_dir, exist_ok=True) info = llama_meta_net_info(model) @@ -139,6 +143,14 @@ def _meta_download(model: "Model", meta_url: str): def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): from llama_models.sku_list import resolve_model + if args.manifest_file: + _download_from_manifest(args.manifest_file) + return + + if args.model_id is None: + parser.error("Please provide a model id") + return + model = resolve_model(args.model_id) if model is None: parser.error(f"Model {args.model_id} not found") @@ -156,6 +168,54 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): _meta_download(model, meta_url) +class ModelEntry(BaseModel): + model_id: str + files: Dict[str, str] + + +class Manifest(BaseModel): + models: List[ModelEntry] + expires_on: datetime + + +def _download_from_manifest(manifest_file: str): + from llama_toolchain.common.model_utils import model_local_dir + + with open(manifest_file, "r") as f: + d = json.load(f) + manifest = Manifest(**d) + + if datetime.now() > manifest.expires_on: + raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}") + + for entry in manifest.models: + print(f"Downloading model {entry.model_id}...") + output_dir = Path(model_local_dir(entry.model_id)) + os.makedirs(output_dir, exist_ok=True) + + if 
any(output_dir.iterdir()): + cprint(f"Output directory {output_dir} is not empty.", "red") + + while True: + resp = input( + "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): " + ) + if resp.lower() == "restart" or resp.lower() == "r": + shutil.rmtree(output_dir) + os.makedirs(output_dir, exist_ok=True) + break + elif resp.lower() == "continue" or resp.lower() == "c": + print("Continuing download...") + break + else: + cprint("Invalid response. Please try again.", "red") + + for fname, url in entry.files.items(): + output_file = str(output_dir / fname) + downloader = ResumableDownloader(url, output_file) + asyncio.run(downloader.download()) + + class ResumableDownloader: def __init__( self, @@ -190,7 +250,7 @@ class ResumableDownloader: async def download(self) -> None: self.start_time = time.time() - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(follow_redirects=True) as client: await self.get_file_info(client) if os.path.exists(self.output_file): @@ -222,7 +282,7 @@ class ResumableDownloader: headers = { "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}" } - # print(f"Downloading `{self.output_file}`....{headers}") + print(f"Downloading `{self.output_file}`....{headers}") try: async with client.stream( "GET", self.url, headers=headers diff --git a/llama_toolchain/common/model_utils.py b/llama_toolchain/common/model_utils.py index 282e02ea8..9e0c3f034 100644 --- a/llama_toolchain/common/model_utils.py +++ b/llama_toolchain/common/model_utils.py @@ -1,9 +1,13 @@ -import os +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
-from llama_models.datatypes import Model +import os from .config_dirs import DEFAULT_CHECKPOINT_DIR -def model_local_dir(model: Model) -> str: - return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor()) +def model_local_dir(descriptor: str) -> str: + return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor) diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py index dfbaf1a3e..f4d3c210b 100644 --- a/llama_toolchain/inference/meta_reference/generation.py +++ b/llama_toolchain/inference/meta_reference/generation.py @@ -28,16 +28,16 @@ from llama_models.llama3_1.api.datatypes import Message from llama_models.llama3_1.api.tokenizer import Tokenizer from llama_models.llama3_1.reference_impl.model import Transformer from llama_models.sku_list import resolve_model +from termcolor import cprint from llama_toolchain.common.model_utils import model_local_dir from llama_toolchain.inference.api import QuantizationType -from termcolor import cprint from .config import MetaReferenceImplConfig def model_checkpoint_dir(model) -> str: - checkpoint_dir = Path(model_local_dir(model)) + checkpoint_dir = Path(model_local_dir(model.descriptor())) if not Path(checkpoint_dir / "consolidated.00.pth").exists(): checkpoint_dir = checkpoint_dir / "original" diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_toolchain/safety/meta_reference/safety.py index 426376c2d..c669eed2f 100644 --- a/llama_toolchain/safety/meta_reference/safety.py +++ b/llama_toolchain/safety/meta_reference/safety.py @@ -36,7 +36,7 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec] def resolve_and_get_path(model_name: str) -> str: model = resolve_model(model_name) assert model is not None, f"Could not resolve model {model_name}" - model_dir = model_local_dir(model) + model_dir = model_local_dir(model.descriptor()) return model_dir From f502716cf73e80c09da4f4eb4befa11b00b26e46 Mon Sep 17 00:00:00 
2001 From: dltn <6599399+dltn@users.noreply.github.com> Date: Sun, 18 Aug 2024 19:13:15 -0700 Subject: [PATCH 3/8] Fix ShieldType Union equality bug --- llama_toolchain/safety/meta_reference/safety.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_toolchain/safety/meta_reference/safety.py index c669eed2f..8f63b14f2 100644 --- a/llama_toolchain/safety/meta_reference/safety.py +++ b/llama_toolchain/safety/meta_reference/safety.py @@ -73,30 +73,34 @@ class MetaReferenceSafetyImpl(Safety): return RunShieldResponse(responses=responses) +def shield_type_equals(a: ShieldType, b: ShieldType): + return a == b or a == b.value + + def shield_config_to_shield( sc: ShieldDefinition, safety_config: SafetyConfig ) -> ShieldBase: - if sc.shield_type == BuiltinShield.llama_guard: + if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard): assert ( safety_config.llama_guard_shield is not None ), "Cannot use LlamaGuardShield since not present in config" model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model) return LlamaGuardShield.instance(model_dir=model_dir) - elif sc.shield_type == BuiltinShield.jailbreak_shield: + elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield): assert ( safety_config.prompt_guard_shield is not None ), "Cannot use Jailbreak Shield since Prompt Guard not present in config" model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) return JailbreakShield.instance(model_dir) - elif sc.shield_type == BuiltinShield.injection_shield: + elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield): assert ( safety_config.prompt_guard_shield is not None ), "Cannot use PromptGuardShield since not present in config" model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) return InjectionShield.instance(model_dir) - elif sc.shield_type == BuiltinShield.code_scanner_guard: + elif 
shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard): return CodeScannerShield.instance() - elif sc.shield_type == BuiltinShield.third_party_shield: + elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield): return ThirdPartyShield.instance() else: raise ValueError(f"Unknown shield type: {sc.shield_type}") From 38244c316156424041931ea10ebc8942f1bc606c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 19 Aug 2024 10:55:37 -0700 Subject: [PATCH 4/8] llama_models.llama3_1 -> llama_models.llama3 --- llama_toolchain/agentic_system/client.py | 2 +- llama_toolchain/agentic_system/event_logger.py | 8 ++++---- .../agentic_system/meta_reference/safety.py | 2 +- .../agentic_system/tools/custom/datatypes.py | 2 +- .../agentic_system/tools/custom/execute.py | 2 +- llama_toolchain/agentic_system/utils.py | 2 +- llama_toolchain/cli/model/template.py | 7 ++++--- llama_toolchain/common/deployment_types.py | 2 +- llama_toolchain/common/training_types.py | 2 +- llama_toolchain/dataset/api/datatypes.py | 2 +- llama_toolchain/evaluations/api/endpoints.py | 2 +- llama_toolchain/inference/api/datatypes.py | 2 +- .../inference/meta_reference/generation.py | 10 +++++----- llama_toolchain/inference/meta_reference/inference.py | 2 +- .../inference/meta_reference/model_parallel.py | 6 +++--- llama_toolchain/inference/ollama/ollama.py | 11 ++++++----- llama_toolchain/inference/quantization/loader.py | 2 +- llama_toolchain/post_training/api/endpoints.py | 2 +- llama_toolchain/reward_scoring/api/datatypes.py | 2 +- llama_toolchain/safety/api/datatypes.py | 2 +- llama_toolchain/safety/api/endpoints.py | 2 +- llama_toolchain/safety/client.py | 2 +- llama_toolchain/safety/meta_reference/shields/base.py | 2 +- .../shields/contrib/third_party_shield.py | 2 +- .../safety/meta_reference/shields/llama_guard.py | 2 +- .../safety/meta_reference/shields/prompt_guard.py | 2 +- .../synthetic_data_generation/api/endpoints.py | 2 +- 27 files changed, 44 insertions(+), 
42 deletions(-) diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index 5b8053af9..154bca614 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -13,7 +13,7 @@ import fire import httpx -from llama_models.llama3_1.api.datatypes import ( +from llama_models.llama3.api.datatypes import ( BuiltinTool, SamplingParams, ToolParamDefinition, diff --git a/llama_toolchain/agentic_system/event_logger.py b/llama_toolchain/agentic_system/event_logger.py index 1bf669a0a..22d961a10 100644 --- a/llama_toolchain/agentic_system/event_logger.py +++ b/llama_toolchain/agentic_system/event_logger.py @@ -6,16 +6,16 @@ from typing import Optional -from llama_models.llama3_1.api.datatypes import ToolResponseMessage -from llama_models.llama3_1.api.tool_utils import ToolUtils +from llama_models.llama3.api.datatypes import ToolResponseMessage +from llama_models.llama3.api.tool_utils import ToolUtils + +from termcolor import cprint from llama_toolchain.agentic_system.api import ( AgenticSystemTurnResponseEventType, StepType, ) -from termcolor import cprint - class LogEvent: def __init__( diff --git a/llama_toolchain/agentic_system/meta_reference/safety.py b/llama_toolchain/agentic_system/meta_reference/safety.py index ff3633f18..683ae622d 100644 --- a/llama_toolchain/agentic_system/meta_reference/safety.py +++ b/llama_toolchain/agentic_system/meta_reference/safety.py @@ -6,7 +6,7 @@ from typing import List -from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage +from llama_models.llama3.api.datatypes import Message, Role, UserMessage from termcolor import cprint from llama_toolchain.safety.api.datatypes import ( diff --git a/llama_toolchain/agentic_system/tools/custom/datatypes.py b/llama_toolchain/agentic_system/tools/custom/datatypes.py index ee46114e8..174b55241 100644 --- a/llama_toolchain/agentic_system/tools/custom/datatypes.py +++ 
b/llama_toolchain/agentic_system/tools/custom/datatypes.py @@ -9,7 +9,7 @@ import json from abc import abstractmethod from typing import Dict, List -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.agentic_system.api import * # noqa: F403 # TODO: this is symptomatic of us needing to pull more tooling related utilities diff --git a/llama_toolchain/agentic_system/tools/custom/execute.py b/llama_toolchain/agentic_system/tools/custom/execute.py index 987aee4e2..4729d35a7 100644 --- a/llama_toolchain/agentic_system/tools/custom/execute.py +++ b/llama_toolchain/agentic_system/tools/custom/execute.py @@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, List -from llama_models.llama3_1.api.datatypes import StopReason, ToolResponseMessage +from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage from llama_toolchain.agentic_system.api import ( AgenticSystem, diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py index 3ae5c67b6..9613b45df 100644 --- a/llama_toolchain/agentic_system/utils.py +++ b/llama_toolchain/agentic_system/utils.py @@ -7,7 +7,7 @@ import uuid from typing import Any, List, Optional -from llama_models.llama3_1.api.datatypes import BuiltinTool, Message, SamplingParams +from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams from llama_toolchain.agentic_system.api import ( AgenticSystemCreateRequest, diff --git a/llama_toolchain/cli/model/template.py b/llama_toolchain/cli/model/template.py index 58b245035..1915e87d3 100644 --- a/llama_toolchain/cli/model/template.py +++ b/llama_toolchain/cli/model/template.py @@ -7,10 +7,10 @@ import argparse import textwrap -from llama_toolchain.cli.subcommand import Subcommand - from termcolor import colored +from llama_toolchain.cli.subcommand import Subcommand + class ModelTemplate(Subcommand): """Llama model cli for describe a model 
template (message formats)""" @@ -48,10 +48,11 @@ class ModelTemplate(Subcommand): ) def _run_model_template_cmd(self, args: argparse.Namespace) -> None: - from llama_models.llama3_1.api.interface import ( + from llama_models.llama3.api.interface import ( list_jinja_templates, render_jinja_template, ) + from llama_toolchain.cli.table import print_table if args.name: diff --git a/llama_toolchain/common/deployment_types.py b/llama_toolchain/common/deployment_types.py index e5117cf2c..8b67eff0d 100644 --- a/llama_toolchain/common/deployment_types.py +++ b/llama_toolchain/common/deployment_types.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Dict, Optional -from llama_models.llama3_1.api.datatypes import URL +from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type diff --git a/llama_toolchain/common/training_types.py b/llama_toolchain/common/training_types.py index 9c8d786fd..fd74293eb 100644 --- a/llama_toolchain/common/training_types.py +++ b/llama_toolchain/common/training_types.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_models.llama3_1.api.datatypes import URL +from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type from pydantic import BaseModel diff --git a/llama_toolchain/dataset/api/datatypes.py b/llama_toolchain/dataset/api/datatypes.py index 5724023e9..32109b37c 100644 --- a/llama_toolchain/dataset/api/datatypes.py +++ b/llama_toolchain/dataset/api/datatypes.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Any, Dict, Optional -from llama_models.llama3_1.api.datatypes import URL +from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type diff --git a/llama_toolchain/evaluations/api/endpoints.py b/llama_toolchain/evaluations/api/endpoints.py index 39b9a28e0..fd5b68bbe 100644 --- a/llama_toolchain/evaluations/api/endpoints.py +++ b/llama_toolchain/evaluations/api/endpoints.py @@ -10,7 +10,7 @@ from llama_models.schema_utils import webmethod from pydantic import BaseModel -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 from .datatypes import * # noqa: F403 from llama_toolchain.dataset.api.datatypes import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 diff --git a/llama_toolchain/inference/api/datatypes.py b/llama_toolchain/inference/api/datatypes.py index 5b0bc7170..571ecc3ea 100644 --- a/llama_toolchain/inference/api/datatypes.py +++ b/llama_toolchain/inference/api/datatypes.py @@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field from typing_extensions import Annotated -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 class LogProbConfig(BaseModel): diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py index f4d3c210b..058874702 100644 --- 
a/llama_toolchain/inference/meta_reference/generation.py +++ b/llama_toolchain/inference/meta_reference/generation.py @@ -22,11 +22,11 @@ from fairscale.nn.model_parallel.initialize import ( initialize_model_parallel, model_parallel_is_initialized, ) -from llama_models.llama3_1.api.args import ModelArgs -from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput -from llama_models.llama3_1.api.datatypes import Message -from llama_models.llama3_1.api.tokenizer import Tokenizer -from llama_models.llama3_1.reference_impl.model import Transformer +from llama_models.llama3.api.args import ModelArgs +from llama_models.llama3.api.chat_format import ChatFormat, ModelInput +from llama_models.llama3.api.datatypes import Message +from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.llama3.reference_impl.model import Transformer from llama_models.sku_list import resolve_model from termcolor import cprint diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py index 4bd7a80bc..84caf1ecf 100644 --- a/llama_toolchain/inference/meta_reference/inference.py +++ b/llama_toolchain/inference/meta_reference/inference.py @@ -8,7 +8,7 @@ import asyncio from typing import AsyncIterator, Dict, Union -from llama_models.llama3_1.api.datatypes import StopReason +from llama_models.llama3.api.datatypes import StopReason from llama_models.sku_list import resolve_model from llama_toolchain.distribution.datatypes import Api, ProviderSpec diff --git a/llama_toolchain/inference/meta_reference/model_parallel.py b/llama_toolchain/inference/meta_reference/model_parallel.py index dee05d8d5..3de4a6381 100644 --- a/llama_toolchain/inference/meta_reference/model_parallel.py +++ b/llama_toolchain/inference/meta_reference/model_parallel.py @@ -10,9 +10,9 @@ from dataclasses import dataclass from functools import partial from typing import Generator, List, Optional -from llama_models.llama3_1.api.chat_format 
import ChatFormat -from llama_models.llama3_1.api.datatypes import Message -from llama_models.llama3_1.api.tokenizer import Tokenizer +from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.datatypes import Message +from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model from .config import MetaReferenceImplConfig diff --git a/llama_toolchain/inference/ollama/ollama.py b/llama_toolchain/inference/ollama/ollama.py index 64f24bee4..8901d5c02 100644 --- a/llama_toolchain/inference/ollama/ollama.py +++ b/llama_toolchain/inference/ollama/ollama.py @@ -9,15 +9,17 @@ from typing import AsyncGenerator, Dict import httpx -from llama_models.llama3_1.api.datatypes import ( +from llama_models.llama3.api.datatypes import ( BuiltinTool, CompletionMessage, Message, StopReason, ToolCall, ) -from llama_models.llama3_1.api.tool_utils import ToolUtils +from llama_models.llama3.api.tool_utils import ToolUtils from llama_models.sku_list import resolve_model +from ollama import AsyncClient + from llama_toolchain.distribution.datatypes import Api, ProviderSpec from llama_toolchain.inference.api import ( ChatCompletionRequest, @@ -30,7 +32,6 @@ from llama_toolchain.inference.api import ( ToolCallDelta, ToolCallParseStatus, ) -from ollama import AsyncClient from .config import OllamaImplConfig @@ -64,10 +65,10 @@ class OllamaInference(Inference): async def initialize(self) -> None: try: await self.client.ps() - except httpx.ConnectError: + except httpx.ConnectError as e: raise RuntimeError( "Ollama Server is not running, start it using `ollama serve` in a separate terminal" - ) + ) from e async def shutdown(self) -> None: pass diff --git a/llama_toolchain/inference/quantization/loader.py b/llama_toolchain/inference/quantization/loader.py index 583123df6..3645344aa 100644 --- a/llama_toolchain/inference/quantization/loader.py +++ b/llama_toolchain/inference/quantization/loader.py @@ -13,7 +13,7 @@ from typing 
import Optional import torch from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region -from llama_models.llama3_1.api.model import Transformer, TransformerBlock +from llama_models.llama3.api.model import Transformer, TransformerBlock from llama_toolchain.inference.api.config import ( CheckpointQuantizationFormat, diff --git a/llama_toolchain/post_training/api/endpoints.py b/llama_toolchain/post_training/api/endpoints.py index 0512003d3..e451def17 100644 --- a/llama_toolchain/post_training/api/endpoints.py +++ b/llama_toolchain/post_training/api/endpoints.py @@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.dataset.api.datatypes import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 from .datatypes import * # noqa: F403 diff --git a/llama_toolchain/reward_scoring/api/datatypes.py b/llama_toolchain/reward_scoring/api/datatypes.py index 3359d4fc9..2ce698d47 100644 --- a/llama_toolchain/reward_scoring/api/datatypes.py +++ b/llama_toolchain/reward_scoring/api/datatypes.py @@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 @json_schema_type diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/datatypes.py index c0d23f589..5deecc2b3 100644 --- a/llama_toolchain/safety/api/datatypes.py +++ b/llama_toolchain/safety/api/datatypes.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Dict, Optional, Union -from llama_models.llama3_1.api.datatypes import ToolParamDefinition +from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_models.schema_utils import json_schema_type 
from pydantic import BaseModel, validator diff --git a/llama_toolchain/safety/api/endpoints.py b/llama_toolchain/safety/api/endpoints.py index 11c1282a1..a282a7968 100644 --- a/llama_toolchain/safety/api/endpoints.py +++ b/llama_toolchain/safety/api/endpoints.py @@ -7,7 +7,7 @@ from .datatypes import * # noqa: F403 from typing import List, Protocol -from llama_models.llama3_1.api.datatypes import Message +from llama_models.llama3.api.datatypes import Message # this dependency is annoying and we need a forked up version anyway from llama_models.schema_utils import webmethod diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py index 2bceebc68..5d86f9291 100644 --- a/llama_toolchain/safety/client.py +++ b/llama_toolchain/safety/client.py @@ -9,7 +9,7 @@ import asyncio import fire import httpx -from llama_models.llama3_1.api.datatypes import UserMessage +from llama_models.llama3.api.datatypes import UserMessage from termcolor import cprint from .api import ( diff --git a/llama_toolchain/safety/meta_reference/shields/base.py b/llama_toolchain/safety/meta_reference/shields/base.py index ce19a3676..245373b13 100644 --- a/llama_toolchain/safety/meta_reference/shields/base.py +++ b/llama_toolchain/safety/meta_reference/shields/base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from typing import List, Union -from llama_models.llama3_1.api.datatypes import Attachment, Message +from llama_models.llama3.api.datatypes import Attachment, Message from llama_toolchain.safety.api.datatypes import * # noqa: F403 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" 
diff --git a/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py b/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py index 789fa5f07..61a5977ed 100644 --- a/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py +++ b/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py @@ -6,7 +6,7 @@ from typing import List -from llama_models.llama3_1.api.datatypes import Message +from llama_models.llama3.api.datatypes import Message from llama_toolchain.safety.meta_reference.shields.base import ( OnViolationAction, diff --git a/llama_toolchain/safety/meta_reference/shields/llama_guard.py b/llama_toolchain/safety/meta_reference/shields/llama_guard.py index 56126abde..a78b8127d 100644 --- a/llama_toolchain/safety/meta_reference/shields/llama_guard.py +++ b/llama_toolchain/safety/meta_reference/shields/llama_guard.py @@ -10,7 +10,7 @@ from string import Template from typing import List, Optional import torch -from llama_models.llama3_1.api.datatypes import Message, Role +from llama_models.llama3.api.datatypes import Message, Role from transformers import AutoModelForCausalLM, AutoTokenizer from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse diff --git a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py b/llama_toolchain/safety/meta_reference/shields/prompt_guard.py index 0acc1e488..b9f5dd5a5 100644 --- a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py +++ b/llama_toolchain/safety/meta_reference/shields/prompt_guard.py @@ -9,7 +9,7 @@ from typing import List import torch -from llama_models.llama3_1.api.datatypes import Message +from llama_models.llama3.api.datatypes import Message from termcolor import cprint from transformers import AutoModelForSequenceClassification, AutoTokenizer diff --git a/llama_toolchain/synthetic_data_generation/api/endpoints.py b/llama_toolchain/synthetic_data_generation/api/endpoints.py index 
8eada05cf..91585a943 100644 --- a/llama_toolchain/synthetic_data_generation/api/endpoints.py +++ b/llama_toolchain/synthetic_data_generation/api/endpoints.py @@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403 from .datatypes import * # noqa: F403 From 23de9414248c6093029e5b187a9b39791a9b27b6 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 19 Aug 2024 14:12:18 -0700 Subject: [PATCH 5/8] Bump version to 0.0.6 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ed8b61612..e34ef87bb 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def read_requirements(): setup( name="llama_toolchain", - version="0.0.5", + version="0.0.6", author="Meta Llama", author_email="llama-oss@meta.com", description="Llama toolchain", From b3da6b8afb4ec3a939960be38aef750fbce4d2d5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 19 Aug 2024 16:27:36 -0700 Subject: [PATCH 6/8] Bump version to 0.0.7 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e34ef87bb..3733fe040 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def read_requirements(): setup( name="llama_toolchain", - version="0.0.6", + version="0.0.7", author="Meta Llama", author_email="llama-oss@meta.com", description="Llama toolchain", From e08e963f861b9c44f80ac351c79cd9bc1c54000d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 19 Aug 2024 18:26:30 -0700 Subject: [PATCH 7/8] Add --manifest-file option to argparser --- llama_toolchain/cli/download.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index f7365b7b4..19a3ec535 100644 --- 
a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -74,6 +74,12 @@ For source=huggingface, files matching any of the patterns are not downloaded. D safetensors files to avoid downloading duplicate weights. """, ) + parser.add_argument( + "--manifest-file", + type=str, + help="For source=meta, you can download models from a manifest file containing a file => URL mapping", + required=False, + ) parser.set_defaults(func=partial(run_download_cmd, parser=parser)) From 57881c08c11d75e9bc52348d23526d79d1d181c4 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 19 Aug 2024 20:12:01 -0700 Subject: [PATCH 8/8] Bump version to 0.0.8 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3733fe040..45c4bcf32 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def read_requirements(): setup( name="llama_toolchain", - version="0.0.7", + version="0.0.8", author="Meta Llama", author_email="llama-oss@meta.com", description="Llama toolchain",