Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 10:42:39 +00:00)

Commit ce6fc9f851 ("precommit"), parent 2f140c7ccf
2 changed files with 47 additions and 156 deletions
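The hunks below are formatting-only: wrapped calls, comprehensions, and multi-line conditions are collapsed onto single lines, the kind of change a pre-commit formatting hook produces when run with `pre-commit run --all-files` (likely a formatter such as ruff format with a longer configured line length; that is an assumption, the commit message only says "precommit"). A quick way to convince yourself that such a change preserves behavior is to compare the ASTs of the old and new code. The sketch below is illustrative and not part of the commit; the example sources are taken from the first hunk of the diff.

```python
# Minimal sketch (not part of the commit): two sources that differ only in
# line wrapping parse to identical ASTs, so a formatting-only pass like this
# one cannot change runtime behavior.
import ast


def same_ast(old_src: str, new_src: str) -> bool:
    """Return True if the two sources are syntactically equivalent."""
    return ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src))


old = """
logger.debug(
    f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
)
"""
new = 'logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")\n'

print(same_ast(old, new))  # True: only whitespace and wrapping changed
```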
First changed file:

@@ -8,9 +8,9 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from llama_stack.apis.common.content_types import (
+    URL,
     InterleavedContent,
     InterleavedContentItem,
-    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
@@ -88,9 +88,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(
-            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
-        )
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -108,9 +106,7 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
-            vector_db_id, chunks, ttl_seconds
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

     async def query_chunks(
         self,
@@ -119,9 +115,7 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
-            vector_db_id, query, params
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)


 class InferenceRouter(Inference):
@@ -158,9 +152,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
-        )
+        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

     def _construct_metrics(
         self,
@@ -214,16 +206,11 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(
-            prompt_tokens, completion_tokens, total_tokens, model
-        )
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [
-            MetricInResponse(metric=metric.metric, value=metric.value)
-            for metric in metrics
-        ]
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def _count_tokens(
         self,
@@ -248,9 +235,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -260,19 +245,12 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if (
-                tool_prompt_format
-                and tool_prompt_format != tool_config.tool_prompt_format
-            ):
-                raise ValueError(
-                    "tool_prompt_format and tool_config.tool_prompt_format must match"
-                )
+            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
         else:
             params = {}
             if tool_choice:
@@ -290,14 +268,9 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [
-                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
-                    for t in tools
-                ]
+                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(
-                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
-                    )
+                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")

         params = dict(
             model_id=model_id,
@@ -312,25 +285,17 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(
-            messages, tool_config.tool_prompt_format
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)

         if stream:

             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.progress
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.complete
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -347,11 +312,7 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
@@ -368,9 +329,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response

     async def completion(
@@ -391,9 +350,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -413,11 +370,7 @@ class InferenceRouter(Inference):
                 async for chunk in await provider.completion(**params):
                     if hasattr(chunk, "delta"):
                         completion_text += chunk.delta
-                    if (
-                        hasattr(chunk, "stop_reason")
-                        and chunk.stop_reason
-                        and self.telemetry
-                    ):
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                         completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         metrics = await self._compute_and_log_token_usage(
@@ -426,11 +379,7 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
@@ -444,9 +393,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response

     async def embeddings(
@@ -462,9 +409,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(
-                f"Model '{model_id}' is an LLM model and does not support embeddings"
-            )
+            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -498,9 +443,7 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(
-            shield_id, provider_shield_id, provider_id, params
-        )
+        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def run_shield(
         self,
@@ -597,9 +540,7 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score_batch(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -617,15 +558,11 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(
-            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
-        )
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -668,9 +605,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(
-            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
-        )
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -684,9 +619,7 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Job:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
-            benchmark_id, job_id
-        )
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

     async def job_cancel(
         self,
@@ -740,9 +673,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
-            return await self.routing_table.get_provider_impl(
-                "insert_into_memory"
-            ).insert(documents, vector_db_id, chunk_size_in_tokens)
+            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
+                documents, vector_db_id, chunk_size_in_tokens
+            )

     def __init__(
         self,
@@ -775,6 +708,4 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
-            tool_group_id, mcp_endpoint
-        )
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
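For context on the code reformatted above: every router method follows the same shape, looking up the provider implementation registered for a resource id and forwarding the call. The sketch below is a simplified, assumed illustration of that delegation pattern (the names `RoutingTable`, `VectorIORouterSketch`, and the fake provider are hypothetical stand-ins, not the llama-stack implementation).

```python
# Illustrative sketch only: a router delegating to a provider impl keyed by resource id.
import asyncio
from typing import Any, Dict, Optional


class _FakeVectorIOProvider:
    """Hypothetical provider, present only to make the example runnable."""

    async def query_chunks(self, vector_db_id: str, query: str, params: Optional[Dict[str, Any]] = None) -> str:
        return f"results for {query!r} from {vector_db_id}"


class RoutingTable:
    """Minimal stand-in: maps a resource id (model, vector db, shield, ...) to a provider impl."""

    def __init__(self) -> None:
        self._providers: Dict[str, Any] = {}

    def register(self, resource_id: str, provider_impl: Any) -> None:
        self._providers[resource_id] = provider_impl

    def get_provider_impl(self, resource_id: str) -> Any:
        if resource_id not in self._providers:
            raise ValueError(f"Resource '{resource_id}' not found")
        return self._providers[resource_id]


class VectorIORouterSketch:
    def __init__(self, routing_table: RoutingTable) -> None:
        self.routing_table = routing_table

    async def query_chunks(self, vector_db_id: str, query: str, params: Optional[Dict[str, Any]] = None) -> str:
        # Same shape as the reformatted VectorIORouter.query_chunks above: look up, then forward.
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)


table = RoutingTable()
table.register("my-db", _FakeVectorIOProvider())
print(asyncio.run(VectorIORouterSketch(table).query_chunks("my-db", "hello")))
```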
Second changed file:

@@ -89,11 +89,7 @@ class MetaReferenceEvalImpl(

         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
-            limit=(
-                -1
-                if benchmark_config.num_examples is None
-                else benchmark_config.num_examples
-            ),
+            limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
         )
         res = await self.evaluate_rows(
             benchmark_id=benchmark_id,
@@ -119,14 +115,10 @@ class MetaReferenceEvalImpl(
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [
-                UserMessage(**x) for x in input_messages if x["role"] == "user"
-            ]
+            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]

             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(
-                agent_id, f"session-{i}"
-            )
+            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
             session_id = session_create_response.session_id

             turn_request = dict(
@@ -135,12 +127,7 @@ class MetaReferenceEvalImpl(
                 messages=input_messages,
                 stream=True,
             )
-            turn_response = [
-                chunk
-                async for chunk in await self.agents_api.create_agent_turn(
-                    **turn_request
-                )
-            ]
+            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
             final_event = turn_response[-1].event.payload

             # check if there's a memory retrieval step and extract the context
@@ -149,14 +136,10 @@ class MetaReferenceEvalImpl(
                 if step.step_type == StepType.tool_execution.value:
                     for tool_response in step.tool_responses:
                         if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(
-                                x.text for x in tool_response.content
-                            )
+                            memory_rag_context = " ".join(x.text for x in tool_response.content)

             agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = (
-                final_event.turn.output_message.content
-            )
+            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
             if memory_rag_context:
                 agent_generation[ColumnName.context.value] = memory_rag_context

@@ -168,9 +151,7 @@ class MetaReferenceEvalImpl(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> List[Dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
-        assert (
-            candidate.sampling_params.max_tokens is not None
-        ), "SamplingParams.max_tokens must be provided"
+        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"

         generations = []
         for x in tqdm(input_rows):
@@ -181,39 +162,21 @@ class MetaReferenceEvalImpl(
                     content=input_content,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_json = json.loads(
-                    x[ColumnName.chat_completion_input.value]
-                )
-                input_messages = [
-                    UserMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "user"
-                ]
+                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
+                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
-                messages += [
-                    SystemMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "system"
-                ]
+                messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
                     messages=messages,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             else:
                 raise ValueError("Invalid input row")

@@ -236,8 +199,7 @@ class MetaReferenceEvalImpl(

         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r
-            for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
         ]

         if benchmark_config.scoring_params is not None:
@@ -246,9 +208,7 @@ class MetaReferenceEvalImpl(
                 for scoring_fn_id in scoring_functions
             }
         else:
-            scoring_functions_dict = {
-                scoring_fn_id: None for scoring_fn_id in scoring_functions
-            }
+            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}

         score_response = await self.scoring_api.score(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict
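The comprehensions joined into single lines in the eval hunks above split a serialized chat_completion_input column into user and system messages before generation. A standalone, simplified sketch of that step follows; the dataclasses are hypothetical stand-ins for the real UserMessage/SystemMessage types, and the sample row is invented for illustration.

```python
# Hedged sketch: splitting a serialized chat_completion_input row into typed messages,
# mirroring the one-line comprehensions produced by the reformat above.
import json
from dataclasses import dataclass


@dataclass
class UserMessage:
    role: str
    content: str


@dataclass
class SystemMessage:
    role: str
    content: str


row = {"chat_completion_input": json.dumps([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
])}

chat_completion_input_json = json.loads(row["chat_completion_input"])
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
system_messages = [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
print(system_messages + input_messages)
```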