mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? This PR fixes some of the issues with our telemetry setup to enable logs to be delivered to opentelemetry and jaeger. Main fixes 1) Updates the open telemetry provider to use the latest oltp exports instead of deprected ones. 2) Adds a tracing middleware, which injects traces into each HTTP request that the server recieves and this is going to be the root trace. Previously, we did this in the create_dynamic_route method, which is actually not the actual exectuion flow, but more of a config and this causes the traces to end prematurely. Through middleware, we plugin the trace start and end at the right location. 3) We manage our own methods to create traces and spans and this does not fit well with Opentelemetry SDK since it does not support provide a way to take in traces and spans that are already created. it expects us to use the SDK to create them. For now, I have a hacky approach of just maintaining a map from our internal telemetry objects to the open telemetry specfic ones. This is not the ideal solution. I will explore other ways to get around this issue. for now, to have something that works, i am going to keep this as is. Addresses: #509
193 lines
6.1 KiB
Python
193 lines
6.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from typing import AsyncGenerator
|
|
|
|
from llama_stack.apis.inference import Inference
|
|
from llama_stack.apis.memory import Memory
|
|
from llama_stack.apis.memory_banks import MemoryBanks
|
|
from llama_stack.apis.safety import Safety
|
|
from llama_stack.apis.agents import * # noqa: F403
|
|
|
|
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
|
|
|
from .agent_instance import ChatAgent
|
|
from .config import MetaReferenceAgentsImplConfig
|
|
|
|
logger = logging.getLogger()
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
class MetaReferenceAgentsImpl(Agents):
|
|
def __init__(
|
|
self,
|
|
config: MetaReferenceAgentsImplConfig,
|
|
inference_api: Inference,
|
|
memory_api: Memory,
|
|
safety_api: Safety,
|
|
memory_banks_api: MemoryBanks,
|
|
):
|
|
self.config = config
|
|
self.inference_api = inference_api
|
|
self.memory_api = memory_api
|
|
self.safety_api = safety_api
|
|
self.memory_banks_api = memory_banks_api
|
|
|
|
self.in_memory_store = InmemoryKVStoreImpl()
|
|
|
|
async def initialize(self) -> None:
|
|
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
|
|
|
async def create_agent(
|
|
self,
|
|
agent_config: AgentConfig,
|
|
) -> AgentCreateResponse:
|
|
agent_id = str(uuid.uuid4())
|
|
|
|
await self.persistence_store.set(
|
|
key=f"agent:{agent_id}",
|
|
value=agent_config.model_dump_json(),
|
|
)
|
|
return AgentCreateResponse(
|
|
agent_id=agent_id,
|
|
)
|
|
|
|
async def get_agent(self, agent_id: str) -> ChatAgent:
|
|
agent_config = await self.persistence_store.get(
|
|
key=f"agent:{agent_id}",
|
|
)
|
|
if not agent_config:
|
|
raise ValueError(f"Could not find agent config for {agent_id}")
|
|
|
|
try:
|
|
agent_config = json.loads(agent_config)
|
|
except json.JSONDecodeError as e:
|
|
raise ValueError(
|
|
f"Could not JSON decode agent config for {agent_id}"
|
|
) from e
|
|
|
|
try:
|
|
agent_config = AgentConfig(**agent_config)
|
|
except Exception as e:
|
|
raise ValueError(
|
|
f"Could not validate(?) agent config for {agent_id}"
|
|
) from e
|
|
|
|
return ChatAgent(
|
|
agent_id=agent_id,
|
|
agent_config=agent_config,
|
|
inference_api=self.inference_api,
|
|
safety_api=self.safety_api,
|
|
memory_api=self.memory_api,
|
|
memory_banks_api=self.memory_banks_api,
|
|
persistence_store=(
|
|
self.persistence_store
|
|
if agent_config.enable_session_persistence
|
|
else self.in_memory_store
|
|
),
|
|
)
|
|
|
|
async def create_agent_session(
|
|
self,
|
|
agent_id: str,
|
|
session_name: str,
|
|
) -> AgentSessionCreateResponse:
|
|
agent = await self.get_agent(agent_id)
|
|
|
|
session_id = await agent.create_session(session_name)
|
|
return AgentSessionCreateResponse(
|
|
session_id=session_id,
|
|
)
|
|
|
|
async def create_agent_turn(
|
|
self,
|
|
agent_id: str,
|
|
session_id: str,
|
|
messages: List[
|
|
Union[
|
|
UserMessage,
|
|
ToolResponseMessage,
|
|
]
|
|
],
|
|
attachments: Optional[List[Attachment]] = None,
|
|
stream: Optional[bool] = False,
|
|
) -> AsyncGenerator:
|
|
request = AgentTurnCreateRequest(
|
|
agent_id=agent_id,
|
|
session_id=session_id,
|
|
messages=messages,
|
|
attachments=attachments,
|
|
stream=True,
|
|
)
|
|
if stream:
|
|
return self._create_agent_turn_streaming(request)
|
|
else:
|
|
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
|
|
|
async def _create_agent_turn_streaming(
|
|
self,
|
|
request: AgentTurnCreateRequest,
|
|
) -> AsyncGenerator:
|
|
agent = await self.get_agent(request.agent_id)
|
|
async for event in agent.create_and_execute_turn(request):
|
|
yield event
|
|
|
|
async def get_agents_turn(
|
|
self, agent_id: str, session_id: str, turn_id: str
|
|
) -> Turn:
|
|
turn = await self.persistence_store.get(
|
|
f"session:{agent_id}:{session_id}:{turn_id}"
|
|
)
|
|
turn = json.loads(turn)
|
|
turn = Turn(**turn)
|
|
return turn
|
|
|
|
async def get_agents_step(
|
|
self, agent_id: str, session_id: str, turn_id: str, step_id: str
|
|
) -> AgentStepResponse:
|
|
turn = await self.persistence_store.get(
|
|
f"session:{agent_id}:{session_id}:{turn_id}"
|
|
)
|
|
turn = json.loads(turn)
|
|
turn = Turn(**turn)
|
|
steps = turn.steps
|
|
for step in steps:
|
|
if step.step_id == step_id:
|
|
return AgentStepResponse(step=step)
|
|
raise ValueError(f"Provided step_id {step_id} could not be found")
|
|
|
|
async def get_agents_session(
|
|
self,
|
|
agent_id: str,
|
|
session_id: str,
|
|
turn_ids: Optional[List[str]] = None,
|
|
) -> Session:
|
|
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
|
|
session = Session(**json.loads(session), turns=[])
|
|
turns = []
|
|
if turn_ids:
|
|
for turn_id in turn_ids:
|
|
turn = await self.persistence_store.get(
|
|
f"session:{agent_id}:{session_id}:{turn_id}"
|
|
)
|
|
turn = json.loads(turn)
|
|
turn = Turn(**turn)
|
|
turns.append(turn)
|
|
return Session(
|
|
session_name=session.session_name,
|
|
session_id=session_id,
|
|
turns=turns if turns else [],
|
|
started_at=session.started_at,
|
|
)
|
|
|
|
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
|
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
|
|
|
|
async def delete_agents(self, agent_id: str) -> None:
|
|
await self.persistence_store.delete(f"agent:{agent_id}")
|