forked from phoenix-oss/llama-stack-mirror
API Updates (#73)
* API Keys passed from Client instead of distro configuration * delete distribution registry * Rename the "package" word away * Introduce a "Router" layer for providers Some providers need to be factorized and considered as thin routing layers on top of other providers. Consider two examples: - The inference API should be a routing layer over inference providers, routed using the "model" key - The memory banks API is another instance where various memory bank types will be provided by independent providers (e.g., a vector store is served by Chroma while a keyvalue memory can be served by Redis or PGVector) This commit introduces a generalized routing layer for this purpose. * update `apis_to_serve` * llama_toolchain -> llama_stack * Codemod from llama_toolchain -> llama_stack - added providers/registry - cleaned up api/ subdirectories and moved impls away - restructured api/api.py - from llama_stack.apis.<api> import foo should work now - update imports to do llama_stack.apis.<api> - update many other imports - added __init__, fixed some registry imports - updated registry imports - create_agentic_system -> create_agent - AgenticSystem -> Agent * Moved some stuff out of common/; re-generated OpenAPI spec * llama-toolchain -> llama-stack (hyphens) * add control plane API * add redis adapter + sqlite provider * move core -> distribution * Some more toolchain -> stack changes * small naming shenanigans * Removing custom tool and agent utilities and moving them client side * Move control plane to distribution server for now * Remove control plane from API list * no codeshield dependency randomly plzzzzz * Add "fire" as a dependency * add back event loggers * stack configure fixes * use brave instead of bing in the example client * add init file so it gets packaged * add init files so it gets packaged * Update MANIFEST * bug fix --------- Co-authored-by: Hardik Shah <hjshah@fb.com> Co-authored-by: Xi Yan <xiyan@meta.com> Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
parent
f294eac5f5
commit
9487ad8294
213 changed files with 1725 additions and 1204 deletions
5
llama_stack/providers/utils/__init__.py
Normal file
5
llama_stack/providers/utils/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
5
llama_stack/providers/utils/inference/__init__.py
Normal file
5
llama_stack/providers/utils/inference/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
84
llama_stack/providers/utils/inference/prepare_messages.py
Normal file
84
llama_stack/providers/utils/inference/prepare_messages.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_models.llama3.prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
|
||||
|
||||
def prepare_messages(request: ChatCompletionRequest) -> List[Message]:
|
||||
|
||||
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
|
||||
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
if existing_messages[0].role == Role.system.value:
|
||||
existing_system_message = existing_messages.pop(0)
|
||||
|
||||
assert (
|
||||
existing_messages[0].role != Role.system.value
|
||||
), "Should only have 1 system message"
|
||||
|
||||
messages = []
|
||||
|
||||
default_gen = SystemDefaultGenerator()
|
||||
default_template = default_gen.gen()
|
||||
|
||||
sys_content = ""
|
||||
|
||||
tool_template = None
|
||||
if request.tools:
|
||||
tool_gen = BuiltinToolGenerator()
|
||||
tool_template = tool_gen.gen(request.tools)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
||||
sys_content += default_template.render()
|
||||
|
||||
if existing_system_message:
|
||||
# TODO: this fn is needed in many places
|
||||
def _process(c):
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
else:
|
||||
return "<media>"
|
||||
|
||||
sys_content += "\n"
|
||||
|
||||
if isinstance(existing_system_message.content, str):
|
||||
sys_content += _process(existing_system_message.content)
|
||||
elif isinstance(existing_system_message.content, list):
|
||||
sys_content += "\n".join(
|
||||
[_process(c) for c in existing_system_message.content]
|
||||
)
|
||||
|
||||
messages.append(SystemMessage(content=sys_content))
|
||||
|
||||
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
|
||||
if has_custom_tools:
|
||||
if request.tool_prompt_format == ToolPromptFormat.json:
|
||||
tool_gen = JsonCustomToolGenerator()
|
||||
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
tool_gen = FunctionTagCustomToolGenerator()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
|
||||
)
|
||||
|
||||
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
|
||||
custom_template = tool_gen.gen(custom_tools)
|
||||
messages.append(UserMessage(content=custom_template.render()))
|
||||
|
||||
# Add back existing messages from the request
|
||||
messages += existing_messages
|
||||
|
||||
return messages
|
5
llama_stack/providers/utils/memory/__init__.py
Normal file
5
llama_stack/providers/utils/memory/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
26
llama_stack/providers/utils/memory/file_utils.py
Normal file
26
llama_stack/providers/utils/memory/file_utils.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
|
||||
|
||||
def data_url_from_file(file_path: str) -> URL:
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, "rb") as file:
|
||||
file_content = file.read()
|
||||
|
||||
base64_content = base64.b64encode(file_content).decode("utf-8")
|
||||
mime_type, _ = mimetypes.guess_type(file_path)
|
||||
|
||||
data_url = f"data:{mime_type};base64,{base64_content}"
|
||||
|
||||
return URL(uri=data_url)
|
180
llama_stack/providers/utils/memory/vector_store.py
Normal file
180
llama_stack/providers/utils/memory/vector_store.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import httpx
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from pypdf import PdfReader
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
ALL_MINILM_L6_V2_DIMENSION = 384
|
||||
|
||||
EMBEDDING_MODEL = None
|
||||
|
||||
|
||||
def get_embedding_model() -> "SentenceTransformer":
|
||||
global EMBEDDING_MODEL
|
||||
|
||||
if EMBEDDING_MODEL is None:
|
||||
print("Loading sentence transformer")
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
|
||||
return EMBEDDING_MODEL
|
||||
|
||||
|
||||
def parse_data_url(data_url: str):
|
||||
data_url_pattern = re.compile(
|
||||
r"^"
|
||||
r"data:"
|
||||
r"(?P<mimetype>[\w/\-+.]+)"
|
||||
r"(?P<charset>;charset=(?P<encoding>[\w-]+))?"
|
||||
r"(?P<base64>;base64)?"
|
||||
r",(?P<data>.*)"
|
||||
r"$",
|
||||
re.DOTALL,
|
||||
)
|
||||
match = data_url_pattern.match(data_url)
|
||||
if not match:
|
||||
raise ValueError("Invalid Data URL format")
|
||||
|
||||
parts = match.groupdict()
|
||||
parts["is_base64"] = bool(parts["base64"])
|
||||
return parts
|
||||
|
||||
|
||||
def content_from_data(data_url: str) -> str:
|
||||
parts = parse_data_url(data_url)
|
||||
data = parts["data"]
|
||||
|
||||
if parts["is_base64"]:
|
||||
data = base64.b64decode(data)
|
||||
else:
|
||||
data = unquote(data)
|
||||
encoding = parts["encoding"] or "utf-8"
|
||||
data = data.encode(encoding)
|
||||
|
||||
encoding = parts["encoding"]
|
||||
if not encoding:
|
||||
detected = chardet.detect(data)
|
||||
encoding = detected["encoding"]
|
||||
|
||||
mime_type = parts["mimetype"]
|
||||
mime_category = mime_type.split("/")[0]
|
||||
if mime_category == "text":
|
||||
# For text-based files (including CSV, MD)
|
||||
return data.decode(encoding)
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string)
|
||||
pdf_bytes = io.BytesIO(data)
|
||||
pdf_reader = PdfReader(pdf_bytes)
|
||||
return "\n".join([page.extract_text() for page in pdf_reader.pages])
|
||||
|
||||
else:
|
||||
cprint("Could not extract content from data_url properly.", color="red")
|
||||
return ""
|
||||
|
||||
|
||||
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
return content_from_data(doc.content.uri)
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(doc.content.uri)
|
||||
return r.text
|
||||
|
||||
return interleaved_text_media_as_str(doc.content)
|
||||
|
||||
|
||||
def make_overlapped_chunks(
|
||||
document_id: str, text: str, window_len: int, overlap_len: int
|
||||
) -> List[Chunk]:
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
tokens = tokenizer.encode(text, bos=False, eos=False)
|
||||
|
||||
chunks = []
|
||||
for i in range(0, len(tokens), window_len - overlap_len):
|
||||
toks = tokens[i : i + window_len]
|
||||
chunk = tokenizer.decode(toks)
|
||||
chunks.append(
|
||||
Chunk(content=chunk, token_count=len(toks), document_id=document_id)
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class EmbeddingIndex(ABC):
|
||||
@abstractmethod
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class BankWithIndex:
|
||||
bank: MemoryBank
|
||||
index: EmbeddingIndex
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
model = get_embedding_model()
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks = make_overlapped_chunks(
|
||||
doc.document_id,
|
||||
content,
|
||||
self.bank.config.chunk_size_in_tokens,
|
||||
self.bank.config.overlap_size_in_tokens
|
||||
or (self.bank.config.chunk_size_in_tokens // 4),
|
||||
)
|
||||
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
|
||||
|
||||
await self.index.add_chunks(chunks, embeddings)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
if params is None:
|
||||
params = {}
|
||||
k = params.get("max_chunks", 3)
|
||||
|
||||
def _process(c) -> str:
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
else:
|
||||
return "<media>"
|
||||
|
||||
if isinstance(query, list):
|
||||
query_str = " ".join([_process(c) for c in query])
|
||||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
model = get_embedding_model()
|
||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||
return await self.index.query(query_vector, k)
|
5
llama_stack/providers/utils/telemetry/__init__.py
Normal file
5
llama_stack/providers/utils/telemetry/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
236
llama_stack/providers/utils/telemetry/tracing.py
Normal file
236
llama_stack/providers/utils/telemetry/tracing.py
Normal file
|
@ -0,0 +1,236 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
|
||||
def generate_short_uuid(len: int = 12):
|
||||
full_uuid = uuid.uuid4()
|
||||
uuid_bytes = full_uuid.bytes
|
||||
encoded = base64.urlsafe_b64encode(uuid_bytes)
|
||||
return encoded.rstrip(b"=").decode("ascii")[:len]
|
||||
|
||||
|
||||
CURRENT_TRACE_CONTEXT = None
|
||||
BACKGROUND_LOGGER = None
|
||||
|
||||
|
||||
class BackgroundLogger:
|
||||
def __init__(self, api: Telemetry, capacity: int = 1000):
|
||||
self.api = api
|
||||
self.log_queue = queue.Queue(maxsize=capacity)
|
||||
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
def log_event(self, event):
|
||||
try:
|
||||
self.log_queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
print("Log queue is full, dropping event")
|
||||
|
||||
def _process_logs(self):
|
||||
while True:
|
||||
try:
|
||||
event = self.log_queue.get()
|
||||
# figure out how to use a thread's native loop
|
||||
asyncio.run(self.api.log_event(event))
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print("Error processing log event")
|
||||
finally:
|
||||
self.log_queue.task_done()
|
||||
|
||||
def __del__(self):
|
||||
self.log_queue.join()
|
||||
|
||||
|
||||
class TraceContext:
|
||||
spans: List[Span] = []
|
||||
|
||||
def __init__(self, logger: BackgroundLogger, trace_id: str):
|
||||
self.logger = logger
|
||||
self.trace_id = trace_id
|
||||
|
||||
def push_span(self, name: str, attributes: Dict[str, Any] = None):
|
||||
current_span = self.get_current_span()
|
||||
span = Span(
|
||||
span_id=generate_short_uuid(),
|
||||
trace_id=self.trace_id,
|
||||
name=name,
|
||||
start_time=datetime.now(),
|
||||
parent_span_id=current_span.span_id if current_span else None,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
self.logger.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=span.start_time,
|
||||
attributes=span.attributes,
|
||||
payload=SpanStartPayload(
|
||||
name=span.name,
|
||||
parent_span_id=span.parent_span_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.spans.append(span)
|
||||
|
||||
def pop_span(self, status: SpanStatus = SpanStatus.OK):
|
||||
span = self.spans.pop()
|
||||
if span is not None:
|
||||
self.logger.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=span.start_time,
|
||||
attributes=span.attributes,
|
||||
payload=SpanEndPayload(
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def get_current_span(self):
|
||||
return self.spans[-1] if self.spans else None
|
||||
|
||||
|
||||
def setup_logger(api: Telemetry, level: int = logging.INFO):
|
||||
global BACKGROUND_LOGGER
|
||||
|
||||
BACKGROUND_LOGGER = BackgroundLogger(api)
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(level)
|
||||
logger.addHandler(TelemetryHandler())
|
||||
|
||||
|
||||
async def start_trace(name: str, attributes: Dict[str, Any] = None):
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
print("No Telemetry implementation set. Skipping trace initialization...")
|
||||
return
|
||||
|
||||
trace_id = generate_short_uuid()
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||
|
||||
CURRENT_TRACE_CONTEXT = context
|
||||
|
||||
|
||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context is None:
|
||||
return
|
||||
|
||||
context.pop_span(status)
|
||||
CURRENT_TRACE_CONTEXT = None
|
||||
|
||||
|
||||
def severity(levelname: str) -> LogSeverity:
|
||||
if levelname == "DEBUG":
|
||||
return LogSeverity.DEBUG
|
||||
elif levelname == "INFO":
|
||||
return LogSeverity.INFO
|
||||
elif levelname == "WARNING":
|
||||
return LogSeverity.WARNING
|
||||
elif levelname == "ERROR":
|
||||
return LogSeverity.ERROR
|
||||
elif levelname == "CRITICAL":
|
||||
return LogSeverity.CRITICAL
|
||||
else:
|
||||
raise ValueError(f"Unknown log level: {levelname}")
|
||||
|
||||
|
||||
# TODO: ideally, the actual emitting should be done inside a separate daemon
|
||||
# process completely isolated from the server
|
||||
class TelemetryHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord):
|
||||
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
|
||||
if record.module in ("asyncio", "selector_events"):
|
||||
return
|
||||
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
raise RuntimeError("Telemetry API not initialized")
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context is None:
|
||||
return
|
||||
|
||||
span = context.get_current_span()
|
||||
if span is None:
|
||||
return
|
||||
|
||||
BACKGROUND_LOGGER.log_event(
|
||||
UnstructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=datetime.now(),
|
||||
message=self.format(record),
|
||||
severity=severity(record.levelname),
|
||||
)
|
||||
)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
def span(name: str, attributes: Dict[str, Any] = None):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
try:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.push_span(name, attributes)
|
||||
result = func(*args, **kwargs)
|
||||
finally:
|
||||
context.pop_span()
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.push_span(name, attributes)
|
||||
result = await func(*args, **kwargs)
|
||||
finally:
|
||||
context.pop_span()
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper(*args, **kwargs)
|
||||
else:
|
||||
return sync_wrapper(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
Loading…
Add table
Add a link
Reference in a new issue