From 7890921e5cd9882a6ad680e4929c4134ac4be451 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 5 Aug 2024 09:24:45 -0700
Subject: [PATCH] move straggler files and fix some important existing bugs

---
 .../agentic_system/agent_instance.py    |  25 ++-
 .../agentic_system/event_logger.py      | 166 ++++++++++++++++++
 .../agentic_system/tools/execute.py     |   1 -
 llama_toolchain/agentic_system/utils.py | 115 ++++++++++++
 llama_toolchain/distribution/server.py  |   7 +-
 5 files changed, 303 insertions(+), 11 deletions(-)
 create mode 100644 llama_toolchain/agentic_system/event_logger.py
 create mode 100644 llama_toolchain/agentic_system/utils.py

diff --git a/llama_toolchain/agentic_system/agent_instance.py b/llama_toolchain/agentic_system/agent_instance.py
index e736394de..afb00655e 100644
--- a/llama_toolchain/agentic_system/agent_instance.py
+++ b/llama_toolchain/agentic_system/agent_instance.py
@@ -4,15 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+
+import copy
+import uuid
+from datetime import datetime
+from typing import AsyncGenerator, List, Optional
+
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.safety.api import Safety
 
 from .api.endpoints import *  # noqa
 
-import uuid
-from datetime import datetime
-from typing import AsyncGenerator, List, Optional
-
 from llama_toolchain.inference.api import ChatCompletionRequest
 
 from llama_toolchain.inference.api.datatypes import (
@@ -219,13 +221,14 @@ class AgentInstance(ShieldRunnerMixin):
         )
         session.turns.append(turn)
 
-        yield AgenticSystemTurnResponseStreamChunk(
+        chunk = AgenticSystemTurnResponseStreamChunk(
             event=AgenticSystemTurnResponseEvent(
                 payload=AgenticSystemTurnResponseTurnCompletePayload(
                     turn=turn,
                 )
             )
         )
+        yield chunk
 
     async def run_shields_wrapper(
         self,
@@ -388,7 +391,10 @@ class AgentInstance(ShieldRunnerMixin):
         stop_reason = None
         async for chunk in self.inference_api.chat_completion(req):
             event = chunk.event
-            if event.event_type != ChatCompletionResponseEventType.progress:
+            if event.event_type == ChatCompletionResponseEventType.start:
+                continue
+            elif event.event_type == ChatCompletionResponseEventType.complete:
+                stop_reason = StopReason.end_of_turn
                 continue
 
             delta = event.delta
@@ -439,7 +445,12 @@ class AgentInstance(ShieldRunnerMixin):
                     step_type=StepType.inference.value,
                     step_id=step_id,
                     step_details=InferenceStep(
-                        step_id=step_id, turn_id=turn_id, model_response=message
+                        # Somewhere deep, we are re-assigning `message` or closing over
+                        # a variable that causes `message` to mutate later on. Fix with
+                        # a `deepcopy` for now; this is symptomatic of a deeper issue.
+                        step_id=step_id,
+                        turn_id=turn_id,
+                        model_response=copy.deepcopy(message),
                     ),
                 )
             )
diff --git a/llama_toolchain/agentic_system/event_logger.py b/llama_toolchain/agentic_system/event_logger.py
new file mode 100644
index 000000000..1bf669a0a
--- /dev/null
+++ b/llama_toolchain/agentic_system/event_logger.py
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.llama3_1.api.datatypes import ToolResponseMessage
+from llama_models.llama3_1.api.tool_utils import ToolUtils
+
+from llama_toolchain.agentic_system.api import (
+    AgenticSystemTurnResponseEventType,
+    StepType,
+)
+
+from termcolor import cprint
+
+
+class LogEvent:
+    def __init__(
+        self,
+        role: Optional[str] = None,
+        content: str = "",
+        end: str = "\n",
+        color: str = "white",
+    ):
+        self.role = role
+        self.content = content
+        self.color = color
+        self.end = "\n" if end is None else end
+
+    def __str__(self):
+        if self.role is not None:
+            return f"{self.role}> {self.content}"
+        else:
+            return f"{self.content}"
+
+    def print(self, flush=True):
+        cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
+
+
+EventType = AgenticSystemTurnResponseEventType
+
+
+class EventLogger:
+    async def log(self, event_generator, stream=True):
+        previous_event_type = None
+        previous_step_type = None
+
+        async for chunk in event_generator:
+            if not hasattr(chunk, "event"):
+                # Check for a custom tool response first, since a custom tool
+                # does not produce an event but a plain Message.
+                if isinstance(chunk, ToolResponseMessage):
+                    yield chunk, LogEvent(
+                        role="CustomTool", content=chunk.content, color="grey"
+                    )
+                continue
+
+            event = chunk.event
+            event_type = event.payload.event_type
+            if event_type in {
+                EventType.turn_start.value,
+                EventType.turn_complete.value,
+            }:
+                # Currently not logging any turn-related info
+                yield event, None
+                continue
+
+            step_type = event.payload.step_type
+            # handle safety
+            if (
+                step_type == StepType.shield_call
+                and event_type == EventType.step_complete.value
+            ):
+                response = event.payload.step_details.response
+                if not response.is_violation:
+                    yield event, LogEvent(
+                        role=step_type, content="No Violation", color="magenta"
+                    )
+                else:
+                    yield event, LogEvent(
+                        role=step_type,
+                        content=f"{response.violation_type} {response.violation_return_message}",
+                        color="red",
+                    )
+
+            # handle inference
+            if step_type == StepType.inference:
+                if stream:
+                    if event_type == EventType.step_start.value:
+                        # TODO: Currently this event is never received
+                        yield event, LogEvent(
+                            role=step_type, content="", end="", color="yellow"
+                        )
+                    elif event_type == EventType.step_progress.value:
+                        # HACK: if the previous event was not inference's
+                        # step_progress, this is the first chunk of the model's
+                        # response (the de facto step_start for inference), so
+                        # start the line with "Model>".
+                        if (
+                            previous_event_type != EventType.step_progress.value
+                            and previous_step_type != StepType.inference
+                        ):
+                            yield event, LogEvent(
+                                role=step_type, content="", end="", color="yellow"
+                            )
+
+                        if event.payload.tool_call_delta:
+                            if isinstance(event.payload.tool_call_delta.content, str):
+                                yield event, LogEvent(
+                                    role=None,
+                                    content=event.payload.tool_call_delta.content,
+                                    end="",
+                                    color="cyan",
+                                )
+                        else:
+                            yield event, LogEvent(
+                                role=None,
+                                content=event.payload.model_response_text_delta,
+                                end="",
+                                color="yellow",
+                            )
+                    else:
+                        # step_complete
+                        yield event, LogEvent(role=None, content="")
+
+                else:
+                    # Not streaming
+                    if event_type == EventType.step_complete.value:
+                        response = event.payload.step_details.model_response
+                        if response.tool_calls:
+                            content = ToolUtils.encode_tool_call(response.tool_calls[0])
+                        else:
+                            content = response.content
+                        yield event, LogEvent(
+                            role=step_type,
+                            content=content,
+                            color="yellow",
+                        )
+
+            # handle tool_execution
+            if (
+                step_type == StepType.tool_execution
+                and
+                # Only print tool calls and responses at the step_complete event
+                event_type == EventType.step_complete.value
+            ):
+                details = event.payload.step_details
+                for t in details.tool_calls:
+                    yield event, LogEvent(
+                        role=step_type,
+                        content=f"Tool:{t.tool_name} Args:{t.arguments}",
+                        color="green",
+                    )
+                for r in details.tool_responses:
+                    yield event, LogEvent(
+                        role=step_type,
+                        content=f"Tool:{r.tool_name} Response:{r.content}",
+                        color="green",
+                    )
+
+            previous_event_type = event_type
+            previous_step_type = step_type
diff --git a/llama_toolchain/agentic_system/tools/execute.py b/llama_toolchain/agentic_system/tools/execute.py
index 2a7625f65..987aee4e2 100644
--- a/llama_toolchain/agentic_system/tools/execute.py
+++ b/llama_toolchain/agentic_system/tools/execute.py
@@ -47,7 +47,6 @@ async def execute_with_custom_tools(
             yield chunk
         else:
             turn = chunk.event.payload.turn
-            break
 
     message = turn.output_message
     if len(message.tool_calls) == 0:
diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py
new file mode 100644
index 000000000..293d98944
--- /dev/null
+++ b/llama_toolchain/agentic_system/utils.py
@@ -0,0 +1,115 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import uuid
+from typing import Any, List, Optional
+
+from llama_models.llama3_1.api.datatypes import BuiltinTool, Message, SamplingParams
+
+from llama_toolchain.agentic_system.api import (
+    AgenticSystemCreateRequest,
+    AgenticSystemInstanceConfig,
+    AgenticSystemSessionCreateRequest,
+    AgenticSystemToolDefinition,
+)
+from llama_toolchain.agentic_system.client import AgenticSystemClient
+
+from llama_toolchain.agentic_system.tools.execute import execute_with_custom_tools
+from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
+
+
+class AgenticSystemClientWrapper:
+    def __init__(self, api, system_id, custom_tools):
+        self.api = api
+        self.system_id = system_id
+        self.custom_tools = custom_tools
+        self.session_id = None
+
+    async def create_session(self, name: Optional[str] = None):
+        if name is None:
+            name = f"Session-{uuid.uuid4()}"
+
+        response = await self.api.create_agentic_system_session(
+            AgenticSystemSessionCreateRequest(
+                system_id=self.system_id,
+                session_name=name,
+            )
+        )
+        self.session_id = response.session_id
+        return self.session_id
+
+    async def run(self, messages: List[Message], stream: bool = True):
+        async for chunk in execute_with_custom_tools(
+            self.api,
+            self.system_id,
+            self.session_id,
+            messages,
+            self.custom_tools,
+            stream=stream,
+        ):
+            yield chunk
+
+
+async def get_agent_system_instance(
+    host: str,
+    port: int,
+    custom_tools: Optional[List[Any]] = None,
+    disable_safety: bool = False,
+    model: str = "Meta-Llama3.1-8B-Instruct",
+) -> AgenticSystemClientWrapper:
+    custom_tools = custom_tools or []
+
+    api = AgenticSystemClient(base_url=f"http://{host}:{port}")
+
+    tool_definitions = [
+        AgenticSystemToolDefinition(
+            tool_name=BuiltinTool.brave_search,
+        ),
+        AgenticSystemToolDefinition(
+            tool_name=BuiltinTool.wolfram_alpha,
+        ),
+        AgenticSystemToolDefinition(
+            tool_name=BuiltinTool.photogen,
+        ),
+        AgenticSystemToolDefinition(
+            tool_name=BuiltinTool.code_interpreter,
+        ),
+    ] + [t.get_tool_definition() for t in custom_tools]
+
+    if not disable_safety:
+        for t in tool_definitions:
+            t.input_shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
+            t.output_shields = [
+                ShieldDefinition(shield_type=BuiltinShield.llama_guard),
+                ShieldDefinition(shield_type=BuiltinShield.injection_shield),
+            ]
+
+    create_request = AgenticSystemCreateRequest(
+        model=model,
+        instance_config=AgenticSystemInstanceConfig(
+            instructions="You are a helpful assistant",
+            available_tools=tool_definitions,
+            input_shields=(
+                []
+                if disable_safety
+                else [
+                    ShieldDefinition(shield_type=BuiltinShield.llama_guard),
+                    ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
+                ]
+            ),
+            output_shields=(
+                []
+                if disable_safety
+                else [
+                    ShieldDefinition(shield_type=BuiltinShield.llama_guard),
+                ]
+            ),
+            sampling_params=SamplingParams(),
+        ),
+    )
+    create_response = await api.create_agentic_system(create_request)
+    return AgenticSystemClientWrapper(api, create_response.system_id, custom_tools)
diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py
index cfbe8b207..5bcabf343 100644
--- a/llama_toolchain/distribution/server.py
+++ b/llama_toolchain/distribution/server.py
@@ -233,9 +233,10 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
 
     def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
         visited.add(a.api_surface)
 
-        for surface in a.adapter_dependencies:
-            if surface not in visited:
-                dfs(by_id[surface], visited, stack)
+        if not isinstance(a, PassthroughApiAdapter):
+            for surface in a.adapter_dependencies:
+                if surface not in visited:
+                    dfs(by_id[surface], visited, stack)
 
         stack.append(a.api_surface)
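
Note on the `copy.deepcopy(message)` change in `agent_instance.py`: the inference
step was previously recording a bare reference to a message object that keeps
being mutated as the stream progresses, so the recorded step silently changed
after the fact. The sketch below illustrates that aliasing hazard and the
deep-copy guard; it uses hypothetical stand-in dataclasses, not the toolchain's
real `Message`/`InferenceStep` types.

    import copy
    from dataclasses import dataclass, field
    from typing import List

    # Hypothetical stand-ins for the real Message / InferenceStep classes.
    @dataclass
    class Message:
        content: str = ""
        tool_calls: List[str] = field(default_factory=list)

    @dataclass
    class InferenceStep:
        model_response: Message

    streamed = Message()

    # Recording a bare reference: later mutations of `streamed` rewrite
    # what the step appears to have captured.
    aliased = InferenceStep(model_response=streamed)

    # Recording a deep copy freezes the snapshot at record time.
    frozen = InferenceStep(model_response=copy.deepcopy(streamed))

    streamed.content += "tokens appended after the step was recorded"

    assert aliased.model_response.content == streamed.content  # mutated too
    assert frozen.model_response.content == ""  # unchanged

The deep copy trades a small allocation for a stable snapshot; as the in-code
comment says, the underlying re-assignment or closed-over variable still
deserves a root-cause fix.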