mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
move straggler files and fix some important existing bugs
This commit is contained in:
parent
5e972ece13
commit
7890921e5c
5 changed files with 303 additions and 11 deletions
|
@ -4,15 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import copy
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from llama_toolchain.inference.api import Inference
|
||||
from llama_toolchain.safety.api import Safety
|
||||
|
||||
from .api.endpoints import * # noqa
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from llama_toolchain.inference.api import ChatCompletionRequest
|
||||
|
||||
from llama_toolchain.inference.api.datatypes import (
|
||||
|
@ -219,13 +221,14 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
)
|
||||
session.turns.append(turn)
|
||||
|
||||
yield AgenticSystemTurnResponseStreamChunk(
|
||||
chunk = AgenticSystemTurnResponseStreamChunk(
|
||||
event=AgenticSystemTurnResponseEvent(
|
||||
payload=AgenticSystemTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
yield chunk
|
||||
|
||||
async def run_shields_wrapper(
|
||||
self,
|
||||
|
@ -388,7 +391,10 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
stop_reason = None
|
||||
async for chunk in self.inference_api.chat_completion(req):
|
||||
event = chunk.event
|
||||
if event.event_type != ChatCompletionResponseEventType.progress:
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
continue
|
||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
continue
|
||||
|
||||
delta = event.delta
|
||||
|
@ -439,7 +445,12 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
step_type=StepType.inference.value,
|
||||
step_id=step_id,
|
||||
step_details=InferenceStep(
|
||||
step_id=step_id, turn_id=turn_id, model_response=message
|
||||
# somewhere deep, we are re-assigning message or closing over some
|
||||
# variable which causes message to mutate later on. fix with a
|
||||
# `deepcopy` for now, but this is symptomatic of a deeper issue.
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
model_response=copy.deepcopy(message),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
166
llama_toolchain/agentic_system/event_logger.py
Normal file
166
llama_toolchain/agentic_system/event_logger.py
Normal file
|
@ -0,0 +1,166 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import ToolResponseMessage
|
||||
from llama_models.llama3_1.api.tool_utils import ToolUtils
|
||||
|
||||
from llama_toolchain.agentic_system.api import (
|
||||
AgenticSystemTurnResponseEventType,
|
||||
StepType,
|
||||
)
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
class LogEvent:
|
||||
def __init__(
|
||||
self,
|
||||
role: Optional[str] = None,
|
||||
content: str = "",
|
||||
end: str = "\n",
|
||||
color="white",
|
||||
):
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.color = color
|
||||
self.end = "\n" if end is None else end
|
||||
|
||||
def __str__(self):
|
||||
if self.role is not None:
|
||||
return f"{self.role}> {self.content}"
|
||||
else:
|
||||
return f"{self.content}"
|
||||
|
||||
def print(self, flush=True):
|
||||
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
|
||||
|
||||
|
||||
EventType = AgenticSystemTurnResponseEventType
|
||||
|
||||
|
||||
class EventLogger:
|
||||
async def log(self, event_generator, stream=True):
|
||||
previous_event_type = None
|
||||
previous_step_type = None
|
||||
|
||||
async for chunk in event_generator:
|
||||
if not hasattr(chunk, "event"):
|
||||
# Need to check for custom tool first
|
||||
# since it does not produce event but instead
|
||||
# a Message
|
||||
if isinstance(chunk, ToolResponseMessage):
|
||||
yield chunk, LogEvent(
|
||||
role="CustomTool", content=chunk.content, color="grey"
|
||||
)
|
||||
continue
|
||||
|
||||
event = chunk.event
|
||||
event_type = event.payload.event_type
|
||||
if event_type in {
|
||||
EventType.turn_start.value,
|
||||
EventType.turn_complete.value,
|
||||
}:
|
||||
# Currently not logging any turn realted info
|
||||
yield event, None
|
||||
continue
|
||||
|
||||
step_type = event.payload.step_type
|
||||
# handle safety
|
||||
if (
|
||||
step_type == StepType.shield_call
|
||||
and event_type == EventType.step_complete.value
|
||||
):
|
||||
response = event.payload.step_details.response
|
||||
if not response.is_violation:
|
||||
yield event, LogEvent(
|
||||
role=step_type, content="No Violation", color="magenta"
|
||||
)
|
||||
else:
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=f"{response.violation_type} {response.violation_return_message}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
# handle inference
|
||||
if step_type == StepType.inference:
|
||||
if stream:
|
||||
if event_type == EventType.step_start.value:
|
||||
# TODO: Currently this event is never received
|
||||
yield event, LogEvent(
|
||||
role=step_type, content="", end="", color="yellow"
|
||||
)
|
||||
elif event_type == EventType.step_progress.value:
|
||||
# HACK: if previous was not step/event was not inference's step_progress
|
||||
# this is the first time we are getting model inference response
|
||||
# aka equivalent to step_start for inference. Hence,
|
||||
# start with "Model>".
|
||||
if (
|
||||
previous_event_type != EventType.step_progress.value
|
||||
and previous_step_type != StepType.inference
|
||||
):
|
||||
yield event, LogEvent(
|
||||
role=step_type, content="", end="", color="yellow"
|
||||
)
|
||||
|
||||
if event.payload.tool_call_delta:
|
||||
if isinstance(event.payload.tool_call_delta.content, str):
|
||||
yield event, LogEvent(
|
||||
role=None,
|
||||
content=event.payload.tool_call_delta.content,
|
||||
end="",
|
||||
color="cyan",
|
||||
)
|
||||
else:
|
||||
yield event, LogEvent(
|
||||
role=None,
|
||||
content=event.payload.model_response_text_delta,
|
||||
end="",
|
||||
color="yellow",
|
||||
)
|
||||
else:
|
||||
# step_complete
|
||||
yield event, LogEvent(role=None, content="")
|
||||
|
||||
else:
|
||||
# Not streaming
|
||||
if event_type == EventType.step_complete.value:
|
||||
response = event.payload.step_details.model_response
|
||||
if response.tool_calls:
|
||||
content = ToolUtils.encode_tool_call(response.tool_calls[0])
|
||||
else:
|
||||
content = response.content
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=content,
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
# handle tool_execution
|
||||
if (
|
||||
step_type == StepType.tool_execution
|
||||
and
|
||||
# Only print tool calls and responses at the step_complete event
|
||||
event_type == EventType.step_complete.value
|
||||
):
|
||||
details = event.payload.step_details
|
||||
for t in details.tool_calls:
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=f"Tool:{t.tool_name} Args:{t.arguments}",
|
||||
color="green",
|
||||
)
|
||||
for r in details.tool_responses:
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=f"Tool:{r.tool_name} Response:{r.content}",
|
||||
color="green",
|
||||
)
|
||||
|
||||
preivous_event_type = event_type
|
||||
previous_step_type = step_type
|
|
@ -47,7 +47,6 @@ async def execute_with_custom_tools(
|
|||
yield chunk
|
||||
else:
|
||||
turn = chunk.event.payload.turn
|
||||
break
|
||||
|
||||
message = turn.output_message
|
||||
if len(message.tool_calls) == 0:
|
||||
|
|
115
llama_toolchain/agentic_system/utils.py
Normal file
115
llama_toolchain/agentic_system/utils.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import BuiltinTool, Message, SamplingParams
|
||||
|
||||
from llama_toolchain.agentic_system.api import (
|
||||
AgenticSystemCreateRequest,
|
||||
AgenticSystemInstanceConfig,
|
||||
AgenticSystemSessionCreateRequest,
|
||||
AgenticSystemToolDefinition,
|
||||
)
|
||||
from llama_toolchain.agentic_system.client import AgenticSystemClient
|
||||
|
||||
from llama_toolchain.agentic_system.tools.execute import execute_with_custom_tools
|
||||
from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
|
||||
|
||||
|
||||
class AgenticSystemClientWrapper:
|
||||
|
||||
def __init__(self, api, system_id, custom_tools):
|
||||
self.api = api
|
||||
self.system_id = system_id
|
||||
self.custom_tools = custom_tools
|
||||
self.session_id = None
|
||||
|
||||
async def create_session(self, name: str = None):
|
||||
if name is None:
|
||||
name = f"Session-{uuid.uuid4()}"
|
||||
|
||||
response = await self.api.create_agentic_system_session(
|
||||
AgenticSystemSessionCreateRequest(
|
||||
system_id=self.system_id,
|
||||
session_name=name,
|
||||
)
|
||||
)
|
||||
self.session_id = response.session_id
|
||||
return self.session_id
|
||||
|
||||
async def run(self, messages: List[Message], stream: bool = True):
|
||||
async for chunk in execute_with_custom_tools(
|
||||
self.api,
|
||||
self.system_id,
|
||||
self.session_id,
|
||||
messages,
|
||||
self.custom_tools,
|
||||
stream=stream,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
async def get_agent_system_instance(
|
||||
host: str,
|
||||
port: int,
|
||||
custom_tools: Optional[List[Any]] = None,
|
||||
disable_safety: bool = False,
|
||||
model: str = "Meta-Llama3.1-8B-Instruct",
|
||||
) -> AgenticSystemClientWrapper:
|
||||
custom_tools = custom_tools or []
|
||||
|
||||
api = AgenticSystemClient(base_url=f"http://{host}:{port}")
|
||||
|
||||
tool_definitions = [
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
),
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
),
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.photogen,
|
||||
),
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.code_interpreter,
|
||||
),
|
||||
] + [t.get_tool_definition() for t in custom_tools]
|
||||
|
||||
if not disable_safety:
|
||||
for t in tool_definitions:
|
||||
t.input_shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
|
||||
t.output_shields = [
|
||||
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
||||
ShieldDefinition(shield_type=BuiltinShield.injection_shield),
|
||||
]
|
||||
|
||||
create_request = AgenticSystemCreateRequest(
|
||||
model=model,
|
||||
instance_config=AgenticSystemInstanceConfig(
|
||||
instructions="You are a helpful assistant",
|
||||
available_tools=tool_definitions,
|
||||
input_shields=(
|
||||
[]
|
||||
if disable_safety
|
||||
else [
|
||||
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
||||
ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
|
||||
]
|
||||
),
|
||||
output_shields=(
|
||||
[]
|
||||
if disable_safety
|
||||
else [
|
||||
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
||||
]
|
||||
),
|
||||
sampling_params=SamplingParams(),
|
||||
),
|
||||
)
|
||||
create_response = await api.create_agentic_system(create_request)
|
||||
return AgenticSystemClientWrapper(api, create_response.system_id, custom_tools)
|
|
@ -233,9 +233,10 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
|||
def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
|
||||
visited.add(a.api_surface)
|
||||
|
||||
for surface in a.adapter_dependencies:
|
||||
if surface not in visited:
|
||||
dfs(by_id[surface], visited, stack)
|
||||
if not isinstance(a, PassthroughApiAdapter):
|
||||
for surface in a.adapter_dependencies:
|
||||
if surface not in visited:
|
||||
dfs(by_id[surface], visited, stack)
|
||||
|
||||
stack.append(a.api_surface)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue