opentelemetry upload to dataset

This commit is contained in:
Dinesh Yeduguru 2024-11-27 14:09:24 -08:00
parent 2dfbb9744d
commit 32fbe366d7
8 changed files with 92 additions and 11 deletions

View file

@@ -37,3 +37,8 @@ class DatasetIO(Protocol):
page_token: Optional[str] = None, page_token: Optional[str] = None,
filter_condition: Optional[str] = None, filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ... ) -> PaginatedRowsResult: ...
@webmethod(route="/datasetio/upload", method="POST")
async def upload_rows(
    self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None:
    """Append the given rows to an existing dataset.

    :param dataset_id: ID of the dataset the rows are uploaded to.
    :param rows: list of row dicts mapping column name to value; rows are
        expected to match the dataset's schema (providers validate this).
    """
    ...

View file

@@ -139,6 +139,7 @@ Event = Annotated[
@json_schema_type @json_schema_type
class EvalTrace(BaseModel): class EvalTrace(BaseModel):
session_id: str
step: str step: str
input: str input: str
output: str output: str

View file

@@ -222,6 +222,12 @@ class DatasetIORouter(DatasetIO):
filter_condition=filter_condition, filter_condition=filter_condition,
) )
async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
    """Route an upload_rows call to the provider that owns the dataset."""
    # Resolve the backing provider for this dataset id, then delegate to it.
    provider = self.routing_table.get_provider_impl(dataset_id)
    return await provider.upload_rows(dataset_id=dataset_id, rows=rows)
class ScoringRouter(Scoring): class ScoringRouter(Scoring):
def __init__( def __init__(

View file

@@ -3,7 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Optional from typing import Any, Dict, List, Optional
import pandas import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
@ -11,6 +11,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from urllib.parse import urlparse
from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
@ -128,3 +129,29 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
total_count=len(rows), total_count=len(rows),
next_page_token=str(end), next_page_token=str(end),
) )
async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
    """Append *rows* to a local dataset and write the combined data back out.

    :param dataset_id: ID of a registered dataset.
    :param rows: list of row dicts to append, validated against the schema.
    :raises ValueError: if the dataset is unknown, or its URL is not a
        local file (only file:// / scheme-less paths can be written).
    """
    if dataset_id not in self.dataset_infos:
        raise ValueError(f"Dataset with id {dataset_id} not found")
    info = self.dataset_infos[dataset_id]
    impl = info.dataset_impl

    # Make sure the existing rows are loaded into memory before appending.
    impl.load()

    # Validate the incoming rows against the dataset schema, then append.
    incoming = impl._validate_dataset_schema(pandas.DataFrame(rows))
    impl.df = pandas.concat([impl.df, incoming], ignore_index=True)

    # Persist: only local file URLs are writable.
    parsed = urlparse(str(info.dataset_def.url))
    if parsed.scheme not in ("file", ""):
        raise ValueError(
            f"Unsupported URL scheme: {parsed.scheme}. Only file:// URLs are supported for writing."
        )
    impl.df.to_csv(parsed.path, index=False)

View file

@@ -27,8 +27,10 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig", config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
), ),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.telemetry, api=Api.telemetry,
provider_type="remote::opentelemetry-jaeger",
config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig",
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="opentelemetry-jaeger", adapter_type="opentelemetry-jaeger",
pip_packages=[ pip_packages=[
@ -40,5 +42,8 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.telemetry.opentelemetry", module="llama_stack.providers.remote.telemetry.opentelemetry",
config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig", config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig",
), ),
api_dependencies=[
Api.datasetio,
],
), ),
] ]

View file

@@ -3,7 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Optional from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403
@ -95,3 +95,22 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
total_count=len(rows), total_count=len(rows),
next_page_token=str(end), next_page_token=str(end),
) )
async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
    """Append *rows* to a Hugging Face dataset and push the result to the Hub.

    :param dataset_id: ID of a registered dataset.
    :param rows: list of row dicts to append.
    :raises NotImplementedError: if the dataset has no hub "path" in its
        metadata (URL-based datasets cannot be written back yet).
    """
    dataset_def = self.dataset_infos[dataset_id]

    # Merge the incoming rows (as an HF Dataset) with the existing data.
    existing = load_hf_dataset(dataset_def)
    combined = hf_datasets.concatenate_datasets(
        [existing, hf_datasets.Dataset.from_list(rows)]
    )

    # Only datasets identified by a hub "path" can be pushed back.
    repo_path = dataset_def.metadata.get("path", None)
    if not repo_path:
        raise NotImplementedError(
            "Uploading to URL-based datasets is not supported yet"
        )
    combined.push_to_hub(repo_path)

View file

@@ -7,9 +7,10 @@
from .config import OpenTelemetryConfig from .config import OpenTelemetryConfig
async def get_adapter_impl(config: OpenTelemetryConfig, _deps): async def get_adapter_impl(config: OpenTelemetryConfig, deps):
from .opentelemetry import OpenTelemetryAdapter from .opentelemetry import OpenTelemetryAdapter
impl = OpenTelemetryAdapter(config) print(f"deps: {deps}")
impl = OpenTelemetryAdapter(config, deps)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@@ -19,6 +19,8 @@ from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.distribution.datatypes import Api
from llama_stack.apis.telemetry import * # noqa: F403 from llama_stack.apis.telemetry import * # noqa: F403
@ -49,8 +51,9 @@ def is_tracing_enabled(tracer):
class OpenTelemetryAdapter(Telemetry): class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig): def __init__(self, config: OpenTelemetryConfig, deps) -> None:
self.config = config self.config = config
self.datasetio = deps[Api.datasetio]
resource = Resource.create( resource = Resource.create(
{ {
@ -230,18 +233,29 @@ class OpenTelemetryAdapter(Telemetry):
traces_data = await response.json() traces_data = await response.json()
seen_trace_ids = set() seen_trace_ids = set()
# For each trace ID, get the detailed trace information
for trace_data in traces_data.get("data", []): for trace_data in traces_data.get("data", []):
trace_id = trace_data.get("traceID") trace_id = trace_data.get("traceID")
if trace_id and trace_id not in seen_trace_ids: if trace_id and trace_id not in seen_trace_ids:
seen_trace_ids.add(trace_id) seen_trace_ids.add(trace_id)
trace_details = await self.get_trace_for_eval(trace_id) trace_details = await self.get_trace_for_eval(
if trace_details: trace_id, session_id
traces.append(trace_details) )
traces.extend(trace_details)
except Exception as e: except Exception as e:
raise Exception(f"Error querying Jaeger traces: {str(e)}") from e raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
if dataset_id:
traces_dict = [
{
"step": trace.step,
"input": trace.input,
"output": trace.output,
"session_id": trace.session_id,
}
for trace in traces
]
await self.datasetio.upload_rows(dataset_id, traces_dict)
return traces return traces
async def get_trace(self, trace_id: str) -> Dict[str, Any]: async def get_trace(self, trace_id: str) -> Dict[str, Any]:
@ -311,7 +325,9 @@ class OpenTelemetryAdapter(Telemetry):
except Exception as e: except Exception as e:
raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e
async def get_trace_for_eval(self, trace_id: str) -> List[EvalTrace]: async def get_trace_for_eval(
self, trace_id: str, session_id: str
) -> List[EvalTrace]:
""" """
Get simplified trace information focusing on first-level children of create_and_execute_turn operations. Get simplified trace information focusing on first-level children of create_and_execute_turn operations.
Returns a list of spans with name, input, and output information, sorted by start time. Returns a list of spans with name, input, and output information, sorted by start time.
@ -332,6 +348,7 @@ class OpenTelemetryAdapter(Telemetry):
step=child["name"], step=child["name"],
input=child["tags"].get("input", ""), input=child["tags"].get("input", ""),
output=child["tags"].get("output", ""), output=child["tags"].get("output", ""),
session_id=session_id,
) )
) )
# Recursively search in children # Recursively search in children