Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Commit: 32fbe366d7 (parent 2dfbb9744d)
opentelemetry upload to dataset
8 changed files with 92 additions and 11 deletions
@@ -37,3 +37,8 @@ class DatasetIO(Protocol):
         page_token: Optional[str] = None,
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult: ...
+
+    @webmethod(route="/datasetio/upload", method="POST")
+    async def upload_rows(
+        self, dataset_id: str, rows: List[Dict[str, Any]]
+    ) -> None: ...
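The hunk above adds an upload_rows endpoint to the DatasetIO protocol. A minimal caller-side sketch, assuming a concrete DatasetIO implementation is in hand (the variable name and dataset id below are hypothetical):

# Hypothetical usage sketch of the new protocol method.
rows = [
    {"input": "What is 2 + 2?", "output": "4"},
    {"input": "Capital of France?", "output": "Paris"},
]
await datasetio_impl.upload_rows(dataset_id="my-eval-dataset", rows=rows)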
@@ -139,6 +139,7 @@ Event = Annotated[

 @json_schema_type
 class EvalTrace(BaseModel):
+    session_id: str
     step: str
     input: str
     output: str
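EvalTrace now records which session a trace belongs to. A quick construction sketch with made-up values:

# Sketch: the new session_id field must be supplied when building an EvalTrace.
trace = EvalTrace(
    session_id="session-123",
    step="inference",
    input="What is 2 + 2?",
    output="4",
)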
@@ -222,6 +222,12 @@ class DatasetIORouter(DatasetIO):
             filter_condition=filter_condition,
         )

+    async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        return await self.routing_table.get_provider_impl(dataset_id).upload_rows(
+            dataset_id=dataset_id,
+            rows=rows,
+        )
+

 class ScoringRouter(Scoring):
     def __init__(
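The router method mirrors the existing DatasetIORouter calls: resolve the provider that registered the dataset via the routing table, then delegate unchanged. A sketch of that dispatch with hypothetical names:

# Sketch of the dispatch pattern (routing_table and the ids are assumptions).
provider = routing_table.get_provider_impl("my-eval-dataset")
await provider.upload_rows(dataset_id="my-eval-dataset", rows=rows)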
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Optional
+from typing import Any, Dict, List, Optional

 import pandas
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -11,6 +11,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from urllib.parse import urlparse

 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
@@ -128,3 +129,29 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
             total_count=len(rows),
             next_page_token=str(end),
         )
+
+    async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        dataset_info = self.dataset_infos.get(dataset_id)
+        if dataset_info is None:
+            raise ValueError(f"Dataset with id {dataset_id} not found")
+
+        dataset_impl = dataset_info.dataset_impl
+        dataset_impl.load()
+
+        new_rows_df = pandas.DataFrame(rows)
+
+        new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
+
+        dataset_impl.df = pandas.concat(
+            [dataset_impl.df, new_rows_df], ignore_index=True
+        )
+
+        url = str(dataset_info.dataset_def.url)
+        parsed_url = urlparse(url)
+        if parsed_url.scheme == "file" or not parsed_url.scheme:
+            file_path = parsed_url.path
+            dataset_impl.df.to_csv(file_path, index=False)
+        else:
+            raise ValueError(
+                f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// URLs are supported for writing."
+            )
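The local filesystem provider validates incoming rows against the dataset schema, appends them to the in-memory DataFrame, and rewrites the backing file, which only works for file:// (or scheme-less) URLs. The persistence step reduces to a standard pandas concat-and-write, sketched here with toy data and a hypothetical path:

import pandas

# Append new rows to an existing frame and persist, as upload_rows does.
existing = pandas.DataFrame([{"input": "a", "output": "b"}])
new_rows = pandas.DataFrame([{"input": "c", "output": "d"}])
combined = pandas.concat([existing, new_rows], ignore_index=True)
combined.to_csv("/tmp/eval_dataset.csv", index=False)  # hypothetical file path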
@@ -27,8 +27,10 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
             ),
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
            api=Api.telemetry,
+            provider_type="remote::opentelemetry-jaeger",
+            config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig",
            adapter=AdapterSpec(
                adapter_type="opentelemetry-jaeger",
                pip_packages=[
@@ -40,5 +42,8 @@ def available_providers() -> List[ProviderSpec]:
                module="llama_stack.providers.remote.telemetry.opentelemetry",
                config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig",
            ),
+            api_dependencies=[
+                Api.datasetio,
+            ],
        ),
    ]
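Declaring api_dependencies=[Api.datasetio] is what makes a datasetio implementation available to the telemetry adapter at construction time: the stack resolves each declared API first and hands the implementations to the adapter factory as a mapping. Roughly, with the resolution machinery elided:

# Sketch: deps maps Api enum members to resolved implementations.
deps = {Api.datasetio: datasetio_impl}  # datasetio_impl resolved by the stack
impl = await get_adapter_impl(config, deps)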
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Optional
+from typing import Any, Dict, List, Optional

 from llama_stack.apis.datasetio import *  # noqa: F403

@@ -95,3 +95,22 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
             total_count=len(rows),
             next_page_token=str(end),
         )
+
+    async def upload_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        dataset_def = self.dataset_infos[dataset_id]
+        loaded_dataset = load_hf_dataset(dataset_def)
+
+        # Convert rows to HF Dataset format
+        new_dataset = hf_datasets.Dataset.from_list(rows)
+
+        # Concatenate the new rows with existing dataset
+        updated_dataset = hf_datasets.concatenate_datasets(
+            [loaded_dataset, new_dataset]
+        )
+
+        if dataset_def.metadata.get("path", None):
+            updated_dataset.push_to_hub(dataset_def.metadata["path"])
+        else:
+            raise NotImplementedError(
+                "Uploading to URL-based datasets is not supported yet"
+            )
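The Hugging Face provider takes a different persistence path: rows become an in-memory Dataset, get concatenated onto the loaded one, and are pushed back to the hub when the dataset definition carries a repo path in its metadata. The datasets-library calls involved, sketched with toy data (the repo id is hypothetical, and push_to_hub needs credentials):

import datasets as hf_datasets

# Build a Dataset from plain dicts and append it to an existing one.
existing = hf_datasets.Dataset.from_list([{"input": "a", "output": "b"}])
new = hf_datasets.Dataset.from_list([{"input": "c", "output": "d"}])
updated = hf_datasets.concatenate_datasets([existing, new])
# updated.push_to_hub("my-org/my-eval-dataset")  # requires HF authentication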
@@ -7,9 +7,10 @@
 from .config import OpenTelemetryConfig


-async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
+async def get_adapter_impl(config: OpenTelemetryConfig, deps):
     from .opentelemetry import OpenTelemetryAdapter

-    impl = OpenTelemetryAdapter(config)
+    print(f"deps: {deps}")
+    impl = OpenTelemetryAdapter(config, deps)
     await impl.initialize()
     return impl
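The factory now threads the resolved dependencies through to the adapter (the print is a debug statement left in by the commit). A test-style sketch with a stub datasetio dependency; the config construction is an assumption, since OpenTelemetryConfig's fields are not shown in this diff:

# Stub standing in for a real DatasetIO implementation.
class StubDatasetIO:
    async def upload_rows(self, dataset_id, rows):
        print(f"would upload {len(rows)} rows to {dataset_id}")

config = OpenTelemetryConfig()  # assumes defaults exist; fields not shown here
deps = {Api.datasetio: StubDatasetIO()}
impl = await get_adapter_impl(config, deps)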
@@ -19,6 +19,8 @@ from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from opentelemetry.semconv.resource import ResourceAttributes

+from llama_stack.distribution.datatypes import Api
+

 from llama_stack.apis.telemetry import *  # noqa: F403

@@ -49,8 +51,9 @@ def is_tracing_enabled(tracer):


 class OpenTelemetryAdapter(Telemetry):
-    def __init__(self, config: OpenTelemetryConfig):
+    def __init__(self, config: OpenTelemetryConfig, deps) -> None:
         self.config = config
+        self.datasetio = deps[Api.datasetio]

         resource = Resource.create(
             {
@@ -230,18 +233,29 @@ class OpenTelemetryAdapter(Telemetry):
                 traces_data = await response.json()
                 seen_trace_ids = set()

-                # For each trace ID, get the detailed trace information
                 for trace_data in traces_data.get("data", []):
                     trace_id = trace_data.get("traceID")
                     if trace_id and trace_id not in seen_trace_ids:
                         seen_trace_ids.add(trace_id)
-                        trace_details = await self.get_trace_for_eval(trace_id)
-                        if trace_details:
-                            traces.append(trace_details)
+                        trace_details = await self.get_trace_for_eval(
+                            trace_id, session_id
+                        )
+                        traces.extend(trace_details)

         except Exception as e:
             raise Exception(f"Error querying Jaeger traces: {str(e)}") from e

+        if dataset_id:
+            traces_dict = [
+                {
+                    "step": trace.step,
+                    "input": trace.input,
+                    "output": trace.output,
+                    "session_id": trace.session_id,
+                }
+                for trace in traces
+            ]
+            await self.datasetio.upload_rows(dataset_id, traces_dict)
         return traces

     async def get_trace(self, trace_id: str) -> Dict[str, Any]:
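With get_trace_for_eval now returning a list of EvalTrace models per trace, the query path flattens them with extend and, when a dataset_id is supplied, converts the Pydantic models into the plain-dict rows DatasetIO expects before uploading. The conversion in isolation, with toy values:

# Sketch: EvalTrace models become plain row dicts for upload_rows.
traces = [EvalTrace(session_id="s1", step="inference", input="q", output="a")]
rows = [
    {
        "step": t.step,
        "input": t.input,
        "output": t.output,
        "session_id": t.session_id,
    }
    for t in traces
]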
@@ -311,7 +325,9 @@ class OpenTelemetryAdapter(Telemetry):
         except Exception as e:
             raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e

-    async def get_trace_for_eval(self, trace_id: str) -> List[EvalTrace]:
+    async def get_trace_for_eval(
+        self, trace_id: str, session_id: str
+    ) -> List[EvalTrace]:
         """
         Get simplified trace information focusing on first-level children of create_and_execute_turn operations.
         Returns a list of spans with name, input, and output information, sorted by start time.
@@ -332,6 +348,7 @@ class OpenTelemetryAdapter(Telemetry):
                             step=child["name"],
                             input=child["tags"].get("input", ""),
                             output=child["tags"].get("output", ""),
+                            session_id=session_id,
                         )
                     )
             # Recursively search in children
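Putting the pieces together: a caller can now ask the Jaeger-backed telemetry adapter for a session's eval traces and have them mirrored into a dataset in the same call. The enclosing method's name is not visible in this diff, so query_session_traces below is a hypothetical stand-in; only the session_id and dataset_id parameters are evidenced by the hunks above:

# Hypothetical end-to-end sketch (method name is a stand-in).
traces = await telemetry_adapter.query_session_traces(
    session_id="session-123",
    dataset_id="traces-dataset",  # triggers upload_rows on the datasetio provider
)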