mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
agent w/ search
This commit is contained in:
parent
8fbbea8c43
commit
436afa68be
1 changed file with 38 additions and 5 deletions
|
@ -130,9 +130,44 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
return Job(job_id=job_id)
|
||||
|
||||
async def _run_agent_generation(
    self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
) -> List[Dict[str, Any]]:
    """Generate one answer per input row by running the candidate agent.

    A single agent is created from ``task_config.eval_candidate.config``.
    Each input row is then sent through its own freshly created session as
    one streamed turn, and the content of the turn's output message (read
    from the last streamed chunk) is recorded as the generated answer.

    Args:
        input_rows: Dataset rows. Every row must contain the
            ``ColumnName.chat_completion_input`` column; its stringified
            value is evaluated into a list of keyword dicts that are passed
            to ``UserMessage``.
        task_config: Eval task configuration carrying the agent candidate.

    Returns:
        One dict per input row, in input order, mapping
        ``ColumnName.generated_answer`` to the turn's output message content.

    Raises:
        AssertionError: If a row lacks the chat-completion input column.
        RuntimeError: If the agent turn stream yields no chunks at all.
    """
    candidate = task_config.eval_candidate
    create_response = await self.agent_api.create_agent(candidate.config)
    agent_id = create_response.agent_id

    # Column names do not change between rows; resolve them once.
    input_column = ColumnName.chat_completion_input.value
    answer_column = ColumnName.generated_answer.value

    generations = []
    # Wrap the list itself (not the enumerate iterator) so tqdm can read its
    # length and show a total on the progress bar.
    for i, row in enumerate(tqdm(input_rows)):
        assert input_column in row, "Invalid input row"
        # NOTE(review): eval() executes arbitrary expressions found in the
        # dataset row. If rows can come from an untrusted source this is a
        # code-execution risk; consider ast.literal_eval or json.loads once
        # the stored format is confirmed.
        raw_messages = eval(str(row[input_column]))
        input_messages = [UserMessage(**message) for message in raw_messages]

        # NOTE: only single-turn agent generation is supported. Create a new session for each input row
        session_create_response = await self.agent_api.create_agent_session(
            agent_id, f"session-{i}"
        )
        session_id = session_create_response.session_id

        turn_request = dict(
            agent_id=agent_id,
            session_id=session_id,
            messages=input_messages,
            stream=True,
        )

        # Only the final chunk is needed, so keep a reference to the latest
        # one instead of buffering the whole stream in a list.
        last_chunk = None
        async for chunk in await self.agent_api.create_agent_turn(**turn_request):
            last_chunk = chunk
        if last_chunk is None:
            raise RuntimeError(
                f"Agent turn for input row {i} produced no response chunks"
            )

        final_event = last_chunk.event.payload
        generations.append(
            {answer_column: final_event.turn.output_message.content}
        )

    return generations
|
||||
|
||||
async def _run_model_generation(
|
||||
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
|
||||
|
@ -190,9 +225,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
) -> EvaluateResponse:
|
||||
candidate = task_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
raise NotImplementedError(
|
||||
"Evaluation with generation has not been implemented for agents"
|
||||
)
|
||||
generations = await self._run_agent_generation(input_rows, task_config)
|
||||
|
||||
if candidate.type == "model":
|
||||
generations = await self._run_model_generation(input_rows, task_config)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue