mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

move straggler files and fix some important existing bugs

This commit is contained in:
parent 5e972ece13
commit 7890921e5c

5 changed files with 303 additions and 11 deletions
@@ -4,15 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import copy
+import uuid
+from datetime import datetime
+from typing import AsyncGenerator, List, Optional
+
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.safety.api import Safety

 from .api.endpoints import *  # noqa

-import uuid
-from datetime import datetime
-from typing import AsyncGenerator, List, Optional
-
 from llama_toolchain.inference.api import ChatCompletionRequest
 from llama_toolchain.inference.api.datatypes import (
@@ -219,13 +221,14 @@ class AgentInstance(ShieldRunnerMixin):
             )
             session.turns.append(turn)

-            yield AgenticSystemTurnResponseStreamChunk(
+            chunk = AgenticSystemTurnResponseStreamChunk(
                 event=AgenticSystemTurnResponseEvent(
                     payload=AgenticSystemTurnResponseTurnCompletePayload(
                         turn=turn,
                     )
                 )
             )
+            yield chunk

     async def run_shields_wrapper(
         self,
@@ -388,7 +391,10 @@ class AgentInstance(ShieldRunnerMixin):
         stop_reason = None
         async for chunk in self.inference_api.chat_completion(req):
             event = chunk.event
-            if event.event_type != ChatCompletionResponseEventType.progress:
+            if event.event_type == ChatCompletionResponseEventType.start:
+                continue
+            elif event.event_type == ChatCompletionResponseEventType.complete:
+                stop_reason = StopReason.end_of_turn
                 continue

             delta = event.delta
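The bug fixed here: the old condition skipped every event that was not `progress`, so the terminal `complete` event, and with it the stop reason, was silently dropped. Below is a minimal, self-contained sketch of the corrected loop shape, using hypothetical stand-in types rather than the real ChatCompletionResponseEventType API:

# Stand-in types for illustration only; the real code iterates an async
# generator of inference events with start/progress/complete phases.
from dataclasses import dataclass
from enum import Enum


class EventType(Enum):
    start = "start"
    progress = "progress"
    complete = "complete"


@dataclass
class Event:
    event_type: EventType
    delta: str = ""


def consume(events):
    text, stop_reason = [], None
    for event in events:
        if event.event_type == EventType.start:
            continue  # nothing to accumulate yet
        elif event.event_type == EventType.complete:
            stop_reason = "end_of_turn"  # the old code lost this
            continue
        text.append(event.delta)  # progress: accumulate the streamed delta
    return "".join(text), stop_reason


assert consume([Event(EventType.start), Event(EventType.progress, "Hi"),
                Event(EventType.complete)]) == ("Hi", "end_of_turn")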
@@ -439,7 +445,12 @@ class AgentInstance(ShieldRunnerMixin):
                     step_type=StepType.inference.value,
                     step_id=step_id,
                     step_details=InferenceStep(
-                        step_id=step_id, turn_id=turn_id, model_response=message
+                        # somewhere deep, we are re-assigning message or closing over some
+                        # variable which causes message to mutate later on. fix with a
+                        # `deepcopy` for now, but this is symptomatic of a deeper issue.
+                        step_id=step_id,
+                        turn_id=turn_id,
+                        model_response=copy.deepcopy(message),
                     ),
                 )
             )
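For anyone unfamiliar with the aliasing bug the comment above describes: storing a reference means any later mutation of the same object rewrites the recorded step. A small self-contained illustration (plain dicts standing in for the real message objects) of why the `deepcopy` defends against it:

import copy

# Illustration of the aliasing bug: a stored reference observes later
# mutation; deepcopy freezes a snapshot at the time the step is recorded.
message = {"role": "assistant", "content": "first draft"}

step_by_reference = {"model_response": message}
step_by_snapshot = {"model_response": copy.deepcopy(message)}

message["content"] = "mutated later, somewhere deep in the agent loop"

print(step_by_reference["model_response"]["content"])  # mutated
print(step_by_snapshot["model_response"]["content"])   # still "first draft"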
llama_toolchain/agentic_system/event_logger.py (new file, 166 lines)
@@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.llama3_1.api.datatypes import ToolResponseMessage
from llama_models.llama3_1.api.tool_utils import ToolUtils

from llama_toolchain.agentic_system.api import (
    AgenticSystemTurnResponseEventType,
    StepType,
)

from termcolor import cprint


class LogEvent:
    def __init__(
        self,
        role: Optional[str] = None,
        content: str = "",
        end: str = "\n",
        color="white",
    ):
        self.role = role
        self.content = content
        self.color = color
        self.end = "\n" if end is None else end

    def __str__(self):
        if self.role is not None:
            return f"{self.role}> {self.content}"
        else:
            return f"{self.content}"

    def print(self, flush=True):
        cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)


EventType = AgenticSystemTurnResponseEventType


class EventLogger:
    async def log(self, event_generator, stream=True):
        previous_event_type = None
        previous_step_type = None

        async for chunk in event_generator:
            if not hasattr(chunk, "event"):
                # Need to check for custom tool first since it does not
                # produce an event but instead a Message
                if isinstance(chunk, ToolResponseMessage):
                    yield chunk, LogEvent(
                        role="CustomTool", content=chunk.content, color="grey"
                    )
                continue

            event = chunk.event
            event_type = event.payload.event_type
            if event_type in {
                EventType.turn_start.value,
                EventType.turn_complete.value,
            }:
                # Currently not logging any turn-related info
                yield event, None
                continue

            step_type = event.payload.step_type
            # handle safety
            if (
                step_type == StepType.shield_call
                and event_type == EventType.step_complete.value
            ):
                response = event.payload.step_details.response
                if not response.is_violation:
                    yield event, LogEvent(
                        role=step_type, content="No Violation", color="magenta"
                    )
                else:
                    yield event, LogEvent(
                        role=step_type,
                        content=f"{response.violation_type} {response.violation_return_message}",
                        color="red",
                    )

            # handle inference
            if step_type == StepType.inference:
                if stream:
                    if event_type == EventType.step_start.value:
                        # TODO: Currently this event is never received
                        yield event, LogEvent(
                            role=step_type, content="", end="", color="yellow"
                        )
                    elif event_type == EventType.step_progress.value:
                        # HACK: if the previous event was not inference's
                        # step_progress, this is the first chunk of the model's
                        # response, i.e. the equivalent of step_start for
                        # inference. Hence, start with "Model>".
                        if (
                            previous_event_type != EventType.step_progress.value
                            and previous_step_type != StepType.inference
                        ):
                            yield event, LogEvent(
                                role=step_type, content="", end="", color="yellow"
                            )

                        if event.payload.tool_call_delta:
                            if isinstance(event.payload.tool_call_delta.content, str):
                                yield event, LogEvent(
                                    role=None,
                                    content=event.payload.tool_call_delta.content,
                                    end="",
                                    color="cyan",
                                )
                        else:
                            yield event, LogEvent(
                                role=None,
                                content=event.payload.model_response_text_delta,
                                end="",
                                color="yellow",
                            )
                    else:
                        # step_complete
                        yield event, LogEvent(role=None, content="")

                else:
                    # Not streaming: emit the full response at step_complete
                    if event_type == EventType.step_complete.value:
                        response = event.payload.step_details.model_response
                        if response.tool_calls:
                            content = ToolUtils.encode_tool_call(response.tool_calls[0])
                        else:
                            content = response.content
                        yield event, LogEvent(
                            role=step_type,
                            content=content,
                            color="yellow",
                        )

            # handle tool_execution
            if (
                step_type == StepType.tool_execution
                # Only print tool calls and responses at the step_complete event
                and event_type == EventType.step_complete.value
            ):
                details = event.payload.step_details
                for t in details.tool_calls:
                    yield event, LogEvent(
                        role=step_type,
                        content=f"Tool:{t.tool_name} Args:{t.arguments}",
                        color="green",
                    )
                for r in details.tool_responses:
                    yield event, LogEvent(
                        role=step_type,
                        content=f"Tool:{r.tool_name} Response:{r.content}",
                        color="green",
                    )

            previous_event_type = event_type
            previous_step_type = step_type
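For context, a sketch of how this logger is meant to be consumed; the `agent` object and its `run()` generator are assumed stand-ins for the AgenticSystemClientWrapper added in utils.py below:

# Hypothetical driver for EventLogger; `agent.run(messages)` is assumed to
# yield agentic-system stream chunks (or ToolResponseMessage objects for
# custom tools), as AgenticSystemClientWrapper.run below does.
from llama_toolchain.agentic_system.event_logger import EventLogger


async def log_turn(agent, messages):
    async for _chunk, log in EventLogger().log(agent.run(messages)):
        if log is not None:
            log.print()  # one colorized "role> content" line per event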
@@ -47,7 +47,6 @@ async def execute_with_custom_tools(
                 yield chunk
             else:
                 turn = chunk.event.payload.turn
-                break

     message = turn.output_message
     if len(message.tool_calls) == 0:
llama_toolchain/agentic_system/utils.py (new file, 115 lines)
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import Any, List, Optional

from llama_models.llama3_1.api.datatypes import BuiltinTool, Message, SamplingParams

from llama_toolchain.agentic_system.api import (
    AgenticSystemCreateRequest,
    AgenticSystemInstanceConfig,
    AgenticSystemSessionCreateRequest,
    AgenticSystemToolDefinition,
)
from llama_toolchain.agentic_system.client import AgenticSystemClient

from llama_toolchain.agentic_system.tools.execute import execute_with_custom_tools
from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition


class AgenticSystemClientWrapper:
    def __init__(self, api, system_id, custom_tools):
        self.api = api
        self.system_id = system_id
        self.custom_tools = custom_tools
        self.session_id = None

    async def create_session(self, name: str = None):
        if name is None:
            name = f"Session-{uuid.uuid4()}"

        response = await self.api.create_agentic_system_session(
            AgenticSystemSessionCreateRequest(
                system_id=self.system_id,
                session_name=name,
            )
        )
        self.session_id = response.session_id
        return self.session_id

    async def run(self, messages: List[Message], stream: bool = True):
        async for chunk in execute_with_custom_tools(
            self.api,
            self.system_id,
            self.session_id,
            messages,
            self.custom_tools,
            stream=stream,
        ):
            yield chunk


async def get_agent_system_instance(
    host: str,
    port: int,
    custom_tools: Optional[List[Any]] = None,
    disable_safety: bool = False,
    model: str = "Meta-Llama3.1-8B-Instruct",
) -> AgenticSystemClientWrapper:
    custom_tools = custom_tools or []

    api = AgenticSystemClient(base_url=f"http://{host}:{port}")

    tool_definitions = [
        AgenticSystemToolDefinition(
            tool_name=BuiltinTool.brave_search,
        ),
        AgenticSystemToolDefinition(
            tool_name=BuiltinTool.wolfram_alpha,
        ),
        AgenticSystemToolDefinition(
            tool_name=BuiltinTool.photogen,
        ),
        AgenticSystemToolDefinition(
            tool_name=BuiltinTool.code_interpreter,
        ),
    ] + [t.get_tool_definition() for t in custom_tools]

    if not disable_safety:
        for t in tool_definitions:
            t.input_shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
            t.output_shields = [
                ShieldDefinition(shield_type=BuiltinShield.llama_guard),
                ShieldDefinition(shield_type=BuiltinShield.injection_shield),
            ]

    create_request = AgenticSystemCreateRequest(
        model=model,
        instance_config=AgenticSystemInstanceConfig(
            instructions="You are a helpful assistant",
            available_tools=tool_definitions,
            input_shields=(
                []
                if disable_safety
                else [
                    ShieldDefinition(shield_type=BuiltinShield.llama_guard),
                    ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
                ]
            ),
            output_shields=(
                []
                if disable_safety
                else [
                    ShieldDefinition(shield_type=BuiltinShield.llama_guard),
                ]
            ),
            sampling_params=SamplingParams(),
        ),
    )
    create_response = await api.create_agentic_system(create_request)
    return AgenticSystemClientWrapper(api, create_response.system_id, custom_tools)
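A sketch of the intended call pattern for these helpers; the host, port, and prompt are placeholder values, and it assumes an agentic-system server is already listening there:

# Hypothetical usage of get_agent_system_instance; host/port/prompt are
# placeholders, and UserMessage is assumed to come from the same datatypes
# module the file above imports from.
import asyncio

from llama_models.llama3_1.api.datatypes import UserMessage

from llama_toolchain.agentic_system.utils import get_agent_system_instance


async def main():
    agent = await get_agent_system_instance(host="localhost", port=5000)
    await agent.create_session()
    async for chunk in agent.run([UserMessage(content="Hello!")], stream=True):
        print(chunk)


asyncio.run(main())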
@@ -233,9 +233,10 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
     def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
         visited.add(a.api_surface)

-        for surface in a.adapter_dependencies:
-            if surface not in visited:
-                dfs(by_id[surface], visited, stack)
+        if not isinstance(a, PassthroughApiAdapter):
+            for surface in a.adapter_dependencies:
+                if surface not in visited:
+                    dfs(by_id[surface], visited, stack)

         stack.append(a.api_surface)
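The fix above stops dependency traversal at passthrough adapters. As a self-contained illustration of the post-order DFS that this `dfs` helper implements (generic names and plain dicts, not the real Adapter/ApiSurface types):

# Generic post-order DFS topological sort, mirroring dfs() above; nodes in
# `passthrough` keep their dependencies out of the traversal, as in the fix.
from typing import Dict, List, Set


def topological_sort(deps: Dict[str, List[str]], passthrough: Set[str]) -> List[str]:
    visited: Set[str] = set()
    stack: List[str] = []

    def dfs(node: str) -> None:
        visited.add(node)
        if node not in passthrough:  # passthrough: don't recurse into deps
            for dep in deps[node]:
                if dep not in visited:
                    dfs(dep)
        stack.append(node)  # post-order: dependencies land first

    for node in deps:
        if node not in visited:
            dfs(node)
    return stack


print(topological_sort({"agents": ["inference", "safety"], "inference": [],
                        "safety": ["inference"]}, passthrough=set()))
# -> ['inference', 'safety', 'agents']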