diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 817085968..168ce5e22 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -202,6 +202,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) context = trace.set_span_in_context(parent_span) + else: + event.attributes["__root_span__"] = "true" span = tracer.start_span( name=event.payload.name, diff --git a/tests/integration/telemetry/test_telemetry.py b/tests/integration/telemetry/test_telemetry.py new file mode 100644 index 000000000..10cbb2eeb --- /dev/null +++ b/tests/integration/telemetry/test_telemetry.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +from uuid import uuid4 + +import pytest +from llama_stack_client import Agent + +from llama_stack.distribution.library_client import LlamaStackAsLibraryClient + + +def test_agent_query_spans(llama_stack_client, text_model_id): + if isinstance(llama_stack_client, LlamaStackAsLibraryClient): + pytest.mark.xfail(reason="Need to fix LlamaStackAsLibraryClient to log spans") + + agent = Agent(llama_stack_client, model=text_model_id, instructions="You are a helpful assistant") + session_id = agent.create_session(f"test-session-{uuid4()}") + agent.create_turn( + messages=[ + { + "role": "user", + "content": "Give me a sentence that contains the word: hello", + } + ], + session_id=session_id, + stream=False, + ) + + # Wait for the span to be logged + time.sleep(2) + + agent_logs = [] + + for span in llama_stack_client.telemetry.query_spans( + attribute_filters=[ + {"key": "session_id", "op": "eq", "value": session_id}, + ], + attributes_to_return=["input", "output"], + ): + print(span.attributes) + if span.attributes["output"] != "no shields": + agent_logs.append(span.attributes) + + assert len(agent_logs) == 1 + assert "Give me a sentence that contains the word: hello" in agent_logs[0]["input"] + assert "hello" in agent_logs[0]["output"].lower()