# What does this PR do?

Change the Telemetry API to support different use cases, such as returning traces for the UI and exporting them for Evals.

Other changes:
* Add a new `trace_protocol` decorator to decorate all our API methods so that any call to them is automatically traced across all impls.
* There is an issue with the decorator pattern for span creation when using async generators, where there are multiple yields within the same context. Using the explicit `with` context-manager pattern is much clearer, so I moved the span creation in the agent instance to use `with` (a minimal sketch of this pattern follows the test plan).
* Inject the session id at the turn level, which should quickly give us all traces across turns for a given session.

Addresses #509

## Test Plan

```
llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml
PYTHONPATH=. python -m examples.agents.rag_with_memory_bank localhost 5000

curl -X POST 'http://localhost:5000/alpha/telemetry/query-traces' \
  -H 'Content-Type: application/json' \
  -d '{
    "attribute_filters": [
      {
        "key": "session_id",
        "op": "eq",
        "value": "dd667b87-ca4b-4d30-9265-5a0de318fc65"
      }],
    "limit": 100,
    "offset": 0,
    "order_by": ["start_time"]
  }' | jq .
[
  {
    "trace_id": "6902f54b83b4b48be18a6f422b13e16f",
    "root_span_id": "5f37b85543afc15a",
    "start_time": "2024-12-04T08:08:30.501587",
    "end_time": "2024-12-04T08:08:36.026463"
  },
  {
    "trace_id": "92227dac84c0615ed741be393813fb5f",
    "root_span_id": "af7c5bb46665c2c8",
    "start_time": "2024-12-04T08:08:36.031170",
    "end_time": "2024-12-04T08:08:41.693301"
  },
  {
    "trace_id": "7d578a6edac62f204ab479fba82f77b6",
    "root_span_id": "1d935e3362676896",
    "start_time": "2024-12-04T08:08:41.695204",
    "end_time": "2024-12-04T08:08:47.228016"
  },
  {
    "trace_id": "dbd767d76991bc816f9f078907dc9ff2",
    "root_span_id": "f5a7ee76683b9602",
    "start_time": "2024-12-04T08:08:47.234578",
    "end_time": "2024-12-04T08:08:53.189412"
  }
]

curl -X POST 'http://localhost:5000/alpha/telemetry/get-span-tree' \
  -H 'Content-Type: application/json' \
  -d '{
    "span_id" : "6cceb4b48a156913",
    "max_depth": 2,
    "attributes_to_return": ["input"]
  }' | jq .
{
  "span_id": "6cceb4b48a156913",
  "trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
  "parent_span_id": "892a66d726c7f990",
  "name": "retrieve_rag_context",
  "start_time": "2024-12-04T09:28:21.781995",
  "end_time": "2024-12-04T09:28:21.913352",
  "attributes": {
    "input": [
      "{\"role\":\"system\",\"content\":\"You are a helpful assistant\"}",
      "{\"role\":\"user\",\"content\":\"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.\",\"context\":null}"
    ]
  },
  "children": [
    {
      "span_id": "1a2df181854064a8",
      "trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
      "parent_span_id": "6cceb4b48a156913",
      "name": "MemoryRouter.query_documents",
      "start_time": "2024-12-04T09:28:21.787620",
      "end_time": "2024-12-04T09:28:21.906512",
      "attributes": {
        "input": null
      },
      "children": [],
      "status": "ok"
    }
  ],
  "status": "ok"
}
```

<img width="1677" alt="Screenshot 2024-12-04 at 9 42 56 AM" src="https://github.com/user-attachments/assets/4d3cea93-05ce-415a-93d9-4b1628631bf8">
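For reference, here is a minimal, self-contained sketch of the explicit context-manager pattern described above. The `span` context manager, the `execute_turn` generator, and the attribute names are hypothetical stand-ins for illustration; this is not the actual llama-stack tracing API. The point is that each nested operation gets its own clearly scoped span, rather than a single decorator-created span having to cover every yield of the generator.

```python
import asyncio
import time
from contextlib import contextmanager
from typing import AsyncIterator


@contextmanager
def span(name: str, **attributes):
    # hypothetical stand-in for a tracing span: records duration and attributes
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print(f"span={name} attributes={attributes} duration={elapsed:.3f}s")


async def execute_turn(session_id: str) -> AsyncIterator[str]:
    # the whole turn is wrapped in an explicit `with` block, and the session id is
    # attached as a span attribute so traces can later be filtered per session
    with span("execute_turn", session_id=session_id):
        with span("inference"):
            yield "first chunk"
        with span("tool_execution"):
            yield "second chunk"


async def main() -> None:
    async for chunk in execute_turn("dd667b87-ca4b-4d30-9265-5a0de318fc65"):
        print(chunk)


asyncio.run(main())
```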
174 lines · 5.7 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional

import pandas
from llama_models.llama3.api.datatypes import *  # noqa: F403

from llama_stack.apis.datasetio import *  # noqa: F403
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import urlparse

from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

from .config import LocalFSDatasetIOConfig


class BaseDataset(ABC):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    @abstractmethod
    def __len__(self) -> int:
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, idx):
        raise NotImplementedError()

    @abstractmethod
    def load(self):
        raise NotImplementedError()


@dataclass
class DatasetInfo:
    dataset_def: Dataset
    dataset_impl: BaseDataset


class PandasDataframeDataset(BaseDataset):
    def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dataset_def = dataset_def
        self.df = None

    def __len__(self) -> int:
        assert self.df is not None, "Dataset not loaded. Please call .load() first"
        return len(self.df)

    def __getitem__(self, idx):
        assert self.df is not None, "Dataset not loaded. Please call .load() first"
        if isinstance(idx, slice):
            return self.df.iloc[idx].to_dict(orient="records")
        else:
            return self.df.iloc[idx].to_dict()

    def _validate_dataset_schema(self, df) -> pandas.DataFrame:
        # note that we will drop any columns in dataset that are not in the schema
        df = df[self.dataset_def.dataset_schema.keys()]
        # check all columns in dataset schema are present
        assert len(df.columns) == len(self.dataset_def.dataset_schema)
        # TODO: type checking against column types in dataset schema
        return df

    def load(self) -> None:
        if self.df is not None:
            return

        df = get_dataframe_from_url(self.dataset_def.url)
        if df is None:
            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")

        self.df = self._validate_dataset_schema(df)


class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
    def __init__(self, config: LocalFSDatasetIOConfig) -> None:
        self.config = config
        # local registry for keeping track of datasets within the provider
        self.dataset_infos = {}

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def register_dataset(
        self,
        dataset: Dataset,
    ) -> None:
        dataset_impl = PandasDataframeDataset(dataset)
        self.dataset_infos[dataset.identifier] = DatasetInfo(
            dataset_def=dataset,
            dataset_impl=dataset_impl,
        )

    async def unregister_dataset(self, dataset_id: str) -> None:
        del self.dataset_infos[dataset_id]

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        dataset_info = self.dataset_infos.get(dataset_id)
        if dataset_info is None:
            raise ValueError(f"Dataset with id {dataset_id} not found")
        dataset_info.dataset_impl.load()

        if page_token and not page_token.isnumeric():
            raise ValueError("Invalid page_token")

        if page_token is None or len(page_token) == 0:
            next_page_token = 0
        else:
            next_page_token = int(page_token)

        start = next_page_token
        if rows_in_page == -1:
            end = len(dataset_info.dataset_impl)
        else:
            end = min(start + rows_in_page, len(dataset_info.dataset_impl))

        rows = dataset_info.dataset_impl[start:end]

        return PaginatedRowsResult(
            rows=rows,
            total_count=len(rows),
            next_page_token=str(end),
        )

    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        dataset_info = self.dataset_infos.get(dataset_id)
        if dataset_info is None:
            raise ValueError(f"Dataset with id {dataset_id} not found")

        dataset_impl = dataset_info.dataset_impl
        dataset_impl.load()

        new_rows_df = pandas.DataFrame(rows)
        new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
        dataset_impl.df = pandas.concat(
            [dataset_impl.df, new_rows_df], ignore_index=True
        )

        url = str(dataset_info.dataset_def.url)
        parsed_url = urlparse(url)

        if parsed_url.scheme == "file" or not parsed_url.scheme:
            file_path = parsed_url.path
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            dataset_impl.df.to_csv(file_path, index=False)
        elif parsed_url.scheme == "data":
            # For data URLs, we need to update the base64-encoded content
            if not parsed_url.path.startswith("text/csv;base64,"):
                raise ValueError("Data URL must be a base64-encoded CSV")

            csv_buffer = dataset_impl.df.to_csv(index=False)
            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
                "utf-8"
            )
            dataset_info.dataset_def.url = URL(
                uri=f"data:text/csv;base64,{base64_content}"
            )
        else:
            raise ValueError(
                f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
            )
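The pagination contract in `get_rows_paginated` is easy to miss: `page_token` is a numeric string offset into the dataframe, and the returned `next_page_token` is the end index of the page just served. Below is a minimal sketch of a client-side paging loop under those semantics; `datasetio` is assumed to be an already-initialized `LocalFSDatasetIOImpl` with a dataset registered under the hypothetical identifier `"my_dataset"` (that setup is not shown here).

```python
async def read_all_rows(datasetio, dataset_id: str = "my_dataset", page_size: int = 100):
    """Walk a dataset page by page using the numeric-offset token scheme."""
    rows = []
    page_token = None
    while True:
        result = await datasetio.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=page_size,
            page_token=page_token,
        )
        if not result.rows:
            # an empty page means the offset has walked past the end of the dataframe
            break
        rows.extend(result.rows)
        # next_page_token is the end index of this page; feed it back as the next offset
        page_token = result.next_page_token
    return rows
```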