add tracing back to the lib cli (#595)

Adds back all the tracing logic removed from library client. also adds
back the logging to agent_instance.
This commit is contained in:
Dinesh Yeduguru 2024-12-11 08:44:20 -08:00 committed by GitHub
parent 1c03ba239e
commit e128f2547a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 76 additions and 117 deletions

View file

@ -185,9 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
# log.info(
# f"{chunk.role.capitalize()}: {chunk.content}",
# )
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
)
output_message = chunk
continue
@ -280,7 +280,6 @@ class ChatAgent(ShieldRunnerMixin):
touchpoint: str,
) -> AsyncGenerator:
with tracing.span("run_shields") as span:
span.set_attribute("turn_id", turn_id)
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
@ -405,11 +404,6 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0
while True:
msg = input_messages[-1]
# if len(str(msg)) > 1000:
# msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
# else:
# msg_str = str(msg)
# log.info(f"{msg_str}")
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -514,12 +508,12 @@ class ChatAgent(ShieldRunnerMixin):
)
if n_iter >= self.agent_config.max_infer_iters:
# log.info("Done with MAX iterations, exiting.")
log.info("Done with MAX iterations, exiting.")
yield message
break
if stop_reason == StopReason.out_of_tokens:
# log.info("Out of token budget, exiting.")
log.info("Out of token budget, exiting.")
yield message
break
@ -533,10 +527,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + attachments
yield message
else:
# log.info(f"Partial message: {str(message)}")
log.info(f"Partial message: {str(message)}")
input_messages = input_messages + [message]
else:
# log.info(f"{str(message)}")
log.info(f"{str(message)}")
try:
tool_call = message.tool_calls[0]
@ -800,7 +794,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
# log.info(f"Downloading {url} -> {filepath}")
log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)