diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index c5052877a..2340ab377 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -37,3 +37,8 @@ class DatasetIO(Protocol): page_token: Optional[str] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: ... + + @webmethod(route="/datasetio/upload", method="POST") + async def upload_rows( + self, dataset_id: str, rows: List[Dict[str, Any]] + ) -> None: ... diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index b0550f848..864380e9f 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -139,6 +139,7 @@ Event = Annotated[ @json_schema_type class EvalTrace(BaseModel): + session_id: str step: str input: str output: str diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5a62b6d64..4e5d83763 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -222,6 +222,12 @@ class DatasetIORouter(DatasetIO): filter_condition=filter_condition, ) + async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + return await self.routing_table.get_provider_impl(dataset_id).upload_rows( + dataset_id=dataset_id, + rows=rows, + ) + class ScoringRouter(Scoring): def __init__( diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 4de1850ae..bbda6d8bf 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Dict, List, Optional import pandas from llama_models.llama3.api.datatypes import * # noqa: F403 @@ -11,6 +11,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from abc import ABC, abstractmethod from dataclasses import dataclass +from urllib.parse import urlparse from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -128,3 +129,29 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): total_count=len(rows), next_page_token=str(end), ) + + async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + dataset_info = self.dataset_infos.get(dataset_id) + if dataset_info is None: + raise ValueError(f"Dataset with id {dataset_id} not found") + + dataset_impl = dataset_info.dataset_impl + dataset_impl.load() + + new_rows_df = pandas.DataFrame(rows) + + new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) + + dataset_impl.df = pandas.concat( + [dataset_impl.df, new_rows_df], ignore_index=True + ) + + url = str(dataset_info.dataset_def.url) + parsed_url = urlparse(url) + if parsed_url.scheme == "file" or not parsed_url.scheme: + file_path = parsed_url.path + dataset_impl.df.to_csv(file_path, index=False) + else: + raise ValueError( + f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// URLs are supported for writing." + ) diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index ac537e076..94b63ab1c 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -27,8 +27,10 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig", ), ), - remote_provider_spec( + RemoteProviderSpec( api=Api.telemetry, + provider_type="remote::opentelemetry-jaeger", + config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig", adapter=AdapterSpec( adapter_type="opentelemetry-jaeger", pip_packages=[ @@ -40,5 +42,8 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.telemetry.opentelemetry", config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig", ), + api_dependencies=[ + Api.datasetio, + ], ), ] diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index c2e4506bf..f43a1991a 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Dict, List, Optional from llama_stack.apis.datasetio import * # noqa: F403 @@ -95,3 +95,22 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): total_count=len(rows), next_page_token=str(end), ) + + async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + dataset_def = self.dataset_infos[dataset_id] + loaded_dataset = load_hf_dataset(dataset_def) + + # Convert rows to HF Dataset format + new_dataset = hf_datasets.Dataset.from_list(rows) + + # Concatenate the new rows with existing dataset + updated_dataset = hf_datasets.concatenate_datasets( + [loaded_dataset, new_dataset] + ) + + if dataset_def.metadata.get("path", None): + updated_dataset.push_to_hub(dataset_def.metadata["path"]) + else: + raise NotImplementedError( + "Uploading to URL-based datasets is not supported yet" + ) diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py index 0842afe2d..694d3cef8 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py +++ b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py @@ -7,9 +7,10 @@ from .config import OpenTelemetryConfig -async def get_adapter_impl(config: OpenTelemetryConfig, _deps): +async def get_adapter_impl(config: OpenTelemetryConfig, deps): from .opentelemetry import OpenTelemetryAdapter - impl = OpenTelemetryAdapter(config) + print(f"deps: {deps}") + impl = OpenTelemetryAdapter(config, deps) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py index 89d669ab1..d944e0b4f 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py +++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py @@ -19,6 +19,8 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes +from llama_stack.distribution.datatypes import Api + from llama_stack.apis.telemetry import * # noqa: F403 @@ -49,8 +51,9 @@ def is_tracing_enabled(tracer): class OpenTelemetryAdapter(Telemetry): - def __init__(self, config: OpenTelemetryConfig): + def __init__(self, config: OpenTelemetryConfig, deps) -> None: self.config = config + self.datasetio = deps[Api.datasetio] resource = Resource.create( { @@ -230,18 +233,29 @@ class OpenTelemetryAdapter(Telemetry): traces_data = await response.json() seen_trace_ids = set() - # For each trace ID, get the detailed trace information for trace_data in traces_data.get("data", []): trace_id = trace_data.get("traceID") if trace_id and trace_id not in seen_trace_ids: seen_trace_ids.add(trace_id) - trace_details = await self.get_trace_for_eval(trace_id) - if trace_details: - traces.append(trace_details) + trace_details = await self.get_trace_for_eval( + trace_id, session_id + ) + traces.extend(trace_details) except Exception as e: raise Exception(f"Error querying Jaeger traces: {str(e)}") from e + if dataset_id: + traces_dict = [ + { + "step": trace.step, + "input": trace.input, + "output": trace.output, + "session_id": trace.session_id, + } + for trace in traces + ] + await self.datasetio.upload_rows(dataset_id, traces_dict) return traces async def get_trace(self, trace_id: str) -> Dict[str, Any]: @@ -311,7 +325,9 @@ class OpenTelemetryAdapter(Telemetry): except Exception as e: raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e - async def get_trace_for_eval(self, trace_id: str) -> List[EvalTrace]: + async def get_trace_for_eval( + self, trace_id: str, session_id: str + ) -> List[EvalTrace]: """ Get simplified trace information focusing on first-level children of create_and_execute_turn operations. Returns a list of spans with name, input, and output information, sorted by start time. @@ -332,6 +348,7 @@ class OpenTelemetryAdapter(Telemetry): step=child["name"], input=child["tags"].get("input", ""), output=child["tags"].get("output", ""), + session_id=session_id, ) ) # Recursively search in children