agent w/ search

This commit is contained in:
Xi Yan 2024-11-17 20:42:23 -08:00
parent 8fbbea8c43
commit 436afa68be

View file

@ -130,9 +130,44 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
return Job(job_id=job_id)
async def _run_agent_generation(
self, task_config: EvalTaskConfig
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
) -> List[Dict[str, Any]]:
pass
candidate = task_config.eval_candidate
create_response = await self.agent_api.create_agent(candidate.config)
agent_id = create_response.agent_id
generations = []
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
input_messages = [UserMessage(**x) for x in input_messages]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agent_api.create_agent_session(
agent_id, f"session-{i}"
)
session_id = session_create_response.session_id
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=input_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await self.agent_api.create_agent_turn(
**turn_request
)
]
final_event = turn_response[-1].event.payload
generations.append(
{
ColumnName.generated_answer.value: final_event.turn.output_message.content
}
)
return generations
async def _run_model_generation(
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
@ -190,9 +225,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
) -> EvaluateResponse:
candidate = task_config.eval_candidate
if candidate.type == "agent":
raise NotImplementedError(
"Evaluation with generation has not been implemented for agents"
)
generations = await self._run_agent_generation(input_rows, task_config)
if candidate.type == "model":
generations = await self._run_model_generation(input_rows, task_config)