mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
agent w/ search
This commit is contained in:
parent
8fbbea8c43
commit
436afa68be
1 changed file with 38 additions and 5 deletions
|
@@ -130,9 +130,44 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
||||||
return Job(job_id=job_id)
|
return Job(job_id=job_id)
|
||||||
|
|
||||||
async def _run_agent_generation(
|
async def _run_agent_generation(
|
||||||
self, task_config: EvalTaskConfig
|
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
pass
|
candidate = task_config.eval_candidate
|
||||||
|
create_response = await self.agent_api.create_agent(candidate.config)
|
||||||
|
agent_id = create_response.agent_id
|
||||||
|
|
||||||
|
generations = []
|
||||||
|
for i, x in tqdm(enumerate(input_rows)):
|
||||||
|
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||||
|
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
|
||||||
|
input_messages = [UserMessage(**x) for x in input_messages]
|
||||||
|
|
||||||
|
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||||
|
session_create_response = await self.agent_api.create_agent_session(
|
||||||
|
agent_id, f"session-{i}"
|
||||||
|
)
|
||||||
|
session_id = session_create_response.session_id
|
||||||
|
|
||||||
|
turn_request = dict(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
messages=input_messages,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
turn_response = [
|
||||||
|
chunk
|
||||||
|
async for chunk in await self.agent_api.create_agent_turn(
|
||||||
|
**turn_request
|
||||||
|
)
|
||||||
|
]
|
||||||
|
final_event = turn_response[-1].event.payload
|
||||||
|
generations.append(
|
||||||
|
{
|
||||||
|
ColumnName.generated_answer.value: final_event.turn.output_message.content
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return generations
|
||||||
|
|
||||||
async def _run_model_generation(
|
async def _run_model_generation(
|
||||||
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
|
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
|
||||||
|
@@ -190,9 +225,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
||||||
) -> EvaluateResponse:
|
) -> EvaluateResponse:
|
||||||
candidate = task_config.eval_candidate
|
candidate = task_config.eval_candidate
|
||||||
if candidate.type == "agent":
|
if candidate.type == "agent":
|
||||||
raise NotImplementedError(
|
generations = await self._run_agent_generation(input_rows, task_config)
|
||||||
"Evaluation with generation has not been implemented for agents"
|
|
||||||
)
|
|
||||||
|
|
||||||
if candidate.type == "model":
|
if candidate.type == "model":
|
||||||
generations = await self._run_model_generation(input_rows, task_config)
|
generations = await self._run_model_generation(input_rows, task_config)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue