Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- from llama_stack.apis.<api> import foo should work now
- update imports to do llama_stack.apis.<api>
- update many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent
This commit is contained in:
Ashwin Bharambe 2024-09-16 17:34:07 -07:00
parent 2cf731faea
commit 76b354a081
128 changed files with 381 additions and 376 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,184 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
class LogEvent:
def __init__(
self,
role: Optional[str] = None,
content: str = "",
end: str = "\n",
color="white",
):
self.role = role
self.content = content
self.color = color
self.end = "\n" if end is None else end
def __str__(self):
if self.role is not None:
return f"{self.role}> {self.content}"
else:
return f"{self.content}"
def print(self, flush=True):
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
EventType = AgentTurnResponseEventType
class EventLogger:
async def log(
self,
event_generator,
stream=True,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
):
previous_event_type = None
previous_step_type = None
async for chunk in event_generator:
if not hasattr(chunk, "event"):
# Need to check for custom tool first
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield chunk, LogEvent(
role="CustomTool", content=chunk.content, color="grey"
)
continue
event = chunk.event
event_type = event.payload.event_type
if event_type in {
EventType.turn_start.value,
EventType.turn_complete.value,
}:
# Currently not logging any turn realted info
yield event, None
continue
step_type = event.payload.step_type
# handle safety
if (
step_type == StepType.shield_call
and event_type == EventType.step_complete.value
):
response = event.payload.step_details.response
if not response.is_violation:
yield event, LogEvent(
role=step_type, content="No Violation", color="magenta"
)
else:
yield event, LogEvent(
role=step_type,
content=f"{response.violation_type} {response.violation_return_message}",
color="red",
)
# handle inference
if step_type == StepType.inference:
if stream:
if event_type == EventType.step_start.value:
# TODO: Currently this event is never received
yield event, LogEvent(
role=step_type, content="", end="", color="yellow"
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
# this is the first time we are getting model inference response
# aka equivalent to step_start for inference. Hence,
# start with "Model>".
if (
previous_event_type != EventType.step_progress.value
and previous_step_type != StepType.inference
):
yield event, LogEvent(
role=step_type, content="", end="", color="yellow"
)
if event.payload.tool_call_delta:
if isinstance(event.payload.tool_call_delta.content, str):
yield event, LogEvent(
role=None,
content=event.payload.tool_call_delta.content,
end="",
color="cyan",
)
else:
yield event, LogEvent(
role=None,
content=event.payload.model_response_text_delta,
end="",
color="yellow",
)
else:
# step_complete
yield event, LogEvent(role=None, content="")
else:
# Not streaming
if event_type == EventType.step_complete.value:
response = event.payload.step_details.model_response
if response.tool_calls:
content = ToolUtils.encode_tool_call(
response.tool_calls[0], tool_prompt_format
)
else:
content = response.content
yield event, LogEvent(
role=step_type,
content=content,
color="yellow",
)
# handle tool_execution
if (
step_type == StepType.tool_execution
and
# Only print tool calls and responses at the step_complete event
event_type == EventType.step_complete.value
):
details = event.payload.step_details
for t in details.tool_calls:
yield event, LogEvent(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
)
for r in details.tool_responses:
yield event, LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
)
if (
step_type == StepType.memory_retrieval
and event_type == EventType.step_complete.value
):
details = event.payload.step_details
content = interleaved_text_media_as_str(details.inserted_context)
content = content[:200] + "..." if len(content) > 200 else content
yield event, LogEvent(
role=step_type,
content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
color="cyan",
)
preivous_event_type = event_type
previous_step_type = step_type

View file

@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.agents import AgentTurnResponseEventType as EventType
from llama_stack.tools.custom.datatypes import CustomTool
class AgentWithCustomToolExecutor:
def __init__(
self,
api: Agents,
agent_id: str,
session_id: str,
agent_config: AgentConfig,
custom_tools: List[CustomTool],
):
self.api = api
self.agent_id = agent_id
self.session_id = session_id
self.agent_config = agent_config
self.custom_tools = custom_tools
async def execute_turn(
self,
messages: List[Message],
attachments: Optional[List[Attachment]] = None,
max_iters: int = 5,
stream: bool = True,
) -> AsyncGenerator:
tools_dict = {t.get_name(): t for t in self.custom_tools}
current_messages = messages.copy()
n_iter = 0
while n_iter < max_iters:
n_iter += 1
request = AgentTurnCreateRequest(
agent_id=self.agent_id,
session_id=self.session_id,
messages=current_messages,
attachments=attachments,
stream=stream,
)
turn = None
async for chunk in self.api.create_agent_turn(request):
if chunk.event.payload.event_type != EventType.turn_complete.value:
yield chunk
else:
turn = chunk.event.payload.turn
message = turn.output_message
if len(message.tool_calls) == 0:
yield chunk
return
if message.stop_reason == StopReason.out_of_tokens:
yield chunk
return
tool_call = message.tool_calls[0]
if tool_call.tool_name not in tools_dict:
m = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
)
next_message = m
else:
tool = tools_dict[tool_call.tool_name]
result_messages = await execute_custom_tool(tool, message)
next_message = result_messages[0]
yield next_message
current_messages = [next_message]
async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]:
result_messages = await tool.run([message])
assert (
len(result_messages) == 1
), f"Expected single message, got {len(result_messages)}"
return result_messages

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)
from termcolor import cprint
class LogEvent:
def __init__(
self,
content: str = "",
end: str = "\n",
color="white",
):
self.content = content
self.color = color
self.end = "\n" if end is None else end
def print(self, flush=True):
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
class EventLogger:
async def log(self, event_generator):
async for chunk in event_generator:
if isinstance(chunk, ChatCompletionResponseStreamChunk):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == ChatCompletionResponseEventType.progress:
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == ChatCompletionResponseEventType.complete:
yield LogEvent("")
else:
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk.completion_message.content, color="yellow")

View file

@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_models.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
def prepare_messages(request: ChatCompletionRequest) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
if request.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from llama_models.llama3.api.datatypes import URL
def data_url_from_file(file_path: str) -> URL:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return URL(uri=data_url)

View file

@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import unquote
import chardet
import httpx
import numpy as np
from numpy.typing import NDArray
from pypdf import PdfReader
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.memory import * # noqa: F403
ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODEL = None
def get_embedding_model() -> "SentenceTransformer":
global EMBEDDING_MODEL
if EMBEDDING_MODEL is None:
print("Loading sentence transformer")
from sentence_transformers import SentenceTransformer
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
return EMBEDDING_MODEL
def parse_data_url(data_url: str):
data_url_pattern = re.compile(
r"^"
r"data:"
r"(?P<mimetype>[\w/\-+.]+)"
r"(?P<charset>;charset=(?P<encoding>[\w-]+))?"
r"(?P<base64>;base64)?"
r",(?P<data>.*)"
r"$",
re.DOTALL,
)
match = data_url_pattern.match(data_url)
if not match:
raise ValueError("Invalid Data URL format")
parts = match.groupdict()
parts["is_base64"] = bool(parts["base64"])
return parts
def content_from_data(data_url: str) -> str:
parts = parse_data_url(data_url)
data = parts["data"]
if parts["is_base64"]:
data = base64.b64decode(data)
else:
data = unquote(data)
encoding = parts["encoding"] or "utf-8"
data = data.encode(encoding)
encoding = parts["encoding"]
if not encoding:
detected = chardet.detect(data)
encoding = detected["encoding"]
mime_type = parts["mimetype"]
mime_category = mime_type.split("/")[0]
if mime_category == "text":
# For text-based files (including CSV, MD)
return data.decode(encoding)
elif mime_type == "application/pdf":
# For PDF and DOC/DOCX files, we can't reliably convert to string)
pdf_bytes = io.BytesIO(data)
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])
else:
cprint("Could not extract content from data_url properly.", color="red")
return ""
async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
return r.text
return interleaved_text_media_as_str(doc.content)
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int
) -> List[Chunk]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunks.append(
Chunk(content=chunk, token_count=len(toks), document_id=document_id)
)
return chunks
class EmbeddingIndex(ABC):
@abstractmethod
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
raise NotImplementedError()
@abstractmethod
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
raise NotImplementedError()
@dataclass
class BankWithIndex:
bank: MemoryBank
index: EmbeddingIndex
async def insert_documents(
self,
documents: List[MemoryBankDocument],
) -> None:
model = get_embedding_model()
for doc in documents:
content = await content_from_doc(doc)
chunks = make_overlapped_chunks(
doc.document_id,
content,
self.bank.config.chunk_size_in_tokens,
self.bank.config.overlap_size_in_tokens
or (self.bank.config.chunk_size_in_tokens // 4),
)
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
await self.index.add_chunks(chunks, embeddings)
async def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
if params is None:
params = {}
k = params.get("max_chunks", 3)
def _process(c) -> str:
if isinstance(c, str):
return c
else:
return "<media>"
if isinstance(query, list):
query_str = " ".join([_process(c) for c in query])
else:
query_str = _process(query)
model = get_embedding_model()
query_vector = model.encode([query_str])[0].astype(np.float32)
return await self.index.query(query_vector, k)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import logging
import queue
import threading
import uuid
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List
from llama_stack.apis.telemetry import * # noqa: F403
def generate_short_uuid(len: int = 12):
full_uuid = uuid.uuid4()
uuid_bytes = full_uuid.bytes
encoded = base64.urlsafe_b64encode(uuid_bytes)
return encoded.rstrip(b"=").decode("ascii")[:len]
CURRENT_TRACE_CONTEXT = None
BACKGROUND_LOGGER = None
class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 1000):
self.api = api
self.log_queue = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
self.worker_thread.start()
def log_event(self, event):
try:
self.log_queue.put_nowait(event)
except queue.Full:
print("Log queue is full, dropping event")
def _process_logs(self):
while True:
try:
event = self.log_queue.get()
# figure out how to use a thread's native loop
asyncio.run(self.api.log_event(event))
except Exception:
import traceback
traceback.print_exc()
print("Error processing log event")
finally:
self.log_queue.task_done()
def __del__(self):
self.log_queue.join()
class TraceContext:
spans: List[Span] = []
def __init__(self, logger: BackgroundLogger, trace_id: str):
self.logger = logger
self.trace_id = trace_id
def push_span(self, name: str, attributes: Dict[str, Any] = None):
current_span = self.get_current_span()
span = Span(
span_id=generate_short_uuid(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(),
parent_span_id=current_span.span_id if current_span else None,
attributes=attributes,
)
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanStartPayload(
name=span.name,
parent_span_id=span.parent_span_id,
),
)
)
self.spans.append(span)
def pop_span(self, status: SpanStatus = SpanStatus.OK):
span = self.spans.pop()
if span is not None:
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanEndPayload(
status=status,
),
)
)
def get_current_span(self):
return self.spans[-1] if self.spans else None
def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER
BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger()
logger.setLevel(level)
logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None):
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
print("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context
async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context is None:
return
context.pop_span(status)
CURRENT_TRACE_CONTEXT = None
def severity(levelname: str) -> LogSeverity:
if levelname == "DEBUG":
return LogSeverity.DEBUG
elif levelname == "INFO":
return LogSeverity.INFO
elif levelname == "WARNING":
return LogSeverity.WARNING
elif levelname == "ERROR":
return LogSeverity.ERROR
elif levelname == "CRITICAL":
return LogSeverity.CRITICAL
else:
raise ValueError(f"Unknown log level: {levelname}")
# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
def emit(self, record: logging.LogRecord):
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
if record.module in ("asyncio", "selector_events"):
return
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
context = CURRENT_TRACE_CONTEXT
if context is None:
return
span = context.get_current_span()
if span is None:
return
BACKGROUND_LOGGER.log_event(
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=datetime.now(),
message=self.format(record),
severity=severity(record.levelname),
)
)
def close(self):
pass
def span(name: str, attributes: Dict[str, Any] = None):
def decorator(func):
@wraps(func)
def sync_wrapper(*args, **kwargs):
try:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(name, attributes)
result = func(*args, **kwargs)
finally:
context.pop_span()
return result
@wraps(func)
async def async_wrapper(*args, **kwargs):
try:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(name, attributes)
result = await func(*args, **kwargs)
finally:
context.pop_span()
return result
@wraps(func)
def wrapper(*args, **kwargs):
if asyncio.iscoroutinefunction(func):
return async_wrapper(*args, **kwargs)
else:
return sync_wrapper(*args, **kwargs)
return wrapper
return decorator