Xi Yan 2025-03-15 14:48:26 -07:00
parent 9b38ae9323
commit f262bfd061
10 changed files with 107 additions and 319 deletions

View file

@@ -8,9 +8,9 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from llama_stack.apis.common.content_types import (
-    URL,
     InterleavedContent,
     InterleavedContentItem,
+    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
@@ -94,9 +94,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(
-            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
-        )
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -114,9 +112,7 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
-            vector_db_id, chunks, ttl_seconds
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

     async def query_chunks(
         self,
@@ -125,9 +121,7 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
-            vector_db_id, query, params
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)


 class InferenceRouter(Inference):
@@ -164,9 +158,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
-        )
+        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

     def _construct_metrics(
         self,
@@ -220,16 +212,11 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(
-            prompt_tokens, completion_tokens, total_tokens, model
-        )
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [
-            MetricInResponse(metric=metric.metric, value=metric.value)
-            for metric in metrics
-        ]
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def _count_tokens(
         self,
@@ -254,9 +241,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -266,19 +251,12 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if (
-                tool_prompt_format
-                and tool_prompt_format != tool_config.tool_prompt_format
-            ):
-                raise ValueError(
-                    "tool_prompt_format and tool_config.tool_prompt_format must match"
-                )
+            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
         else:
             params = {}
             if tool_choice:
@@ -296,14 +274,9 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [
-                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
-                    for t in tools
-                ]
+                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(
-                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
-                    )
+                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")

         params = dict(
             model_id=model_id,
@@ -318,25 +291,17 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(
-            messages, tool_config.tool_prompt_format
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)

         if stream:

             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.progress
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.complete
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -353,11 +318,7 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
@@ -374,9 +335,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response

     async def completion(
@@ -397,9 +356,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -419,11 +376,7 @@ class InferenceRouter(Inference):
             async for chunk in await provider.completion(**params):
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if (
-                    hasattr(chunk, "stop_reason")
-                    and chunk.stop_reason
-                    and self.telemetry
-                ):
+                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                     completion_tokens = await self._count_tokens(completion_text)
                     total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                     metrics = await self._compute_and_log_token_usage(
@@ -432,11 +385,7 @@ class InferenceRouter(Inference):
                         total_tokens,
                         model,
                     )
-                    chunk.metrics = (
-                        metrics
-                        if chunk.metrics is None
-                        else chunk.metrics + metrics
-                    )
+                    chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                 yield chunk

             return stream_generator()
@@ -450,9 +399,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response

     async def embeddings(
@@ -468,9 +415,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(
-                f"Model '{model_id}' is an LLM model and does not support embeddings"
-            )
+            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -504,9 +449,7 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(
-            shield_id, provider_shield_id, provider_id, params
-        )
+        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def run_shield(
         self,
@@ -603,9 +546,7 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score_batch(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -623,15 +564,11 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(
-            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
-        )
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -674,9 +611,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(
-            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
-        )
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -690,9 +625,7 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Optional[JobStatus]:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
-            benchmark_id, job_id
-        )
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

     async def job_cancel(
         self,
@@ -746,9 +679,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
-            return await self.routing_table.get_provider_impl(
-                "insert_into_memory"
-            ).insert(documents, vector_db_id, chunk_size_in_tokens)
+            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
+                documents, vector_db_id, chunk_size_in_tokens
+            )

     def __init__(
         self,
@@ -781,6 +714,4 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
-            tool_group_id, mcp_endpoint
-        )
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@@ -105,9 +105,7 @@ class CommonRoutingTableImpl(RoutingTable):
         self.dist_registry = dist_registry

     async def initialize(self) -> None:
-        async def add_objects(
-            objs: List[RoutableObjectWithProvider], provider_id: str, cls
-        ) -> None:
+        async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
             for obj in objs:
                 if cls is None:
                     obj.provider_id = provider_id
@@ -142,9 +140,7 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()

-    def get_provider_impl(
-        self, routing_key: str, provider_id: Optional[str] = None
-    ) -> Any:
+    def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
@@ -182,9 +178,7 @@ class CommonRoutingTableImpl(RoutingTable):
         raise ValueError(f"Provider not found for `{routing_key}`")

-    async def get_object_by_identifier(
-        self, type: str, identifier: str
-    ) -> Optional[RoutableObjectWithProvider]:
+    async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
         # Get from disk registry
         obj = await self.dist_registry.get(type, identifier)
         if not obj:
@@ -194,13 +188,9 @@ class CommonRoutingTableImpl(RoutingTable):
     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
         await self.dist_registry.delete(obj.type, obj.identifier)
-        await unregister_object_from_provider(
-            obj, self.impls_by_provider_id[obj.provider_id]
-        )
+        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

-    async def register_object(
-        self, obj: RoutableObjectWithProvider
-    ) -> RoutableObjectWithProvider:
+    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -255,9 +245,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if model_type is None:
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
-            raise ValueError(
-                "Embedding model must have an embedding dimension in its metadata"
-            )
+            raise ValueError("Embedding model must have an embedding dimension in its metadata")
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
@@ -277,9 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):

 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
     async def list_shields(self) -> ListShieldsResponse:
-        return ListShieldsResponse(
-            data=await self.get_all_with_type(ResourceType.shield.value)
-        )
+        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

     async def get_shield(self, identifier: str) -> Optional[Shield]:
         return await self.get_object_by_identifier("shield", identifier)
@@ -338,18 +324,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                         f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                     )
             else:
-                raise ValueError(
-                    "No provider available. Please configure a vector_io provider."
-                )
+                raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
-            raise ValueError(
-                f"Model {embedding_model} does not have an embedding dimension"
-            )
+            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
         vector_db_data = {
             "identifier": vector_db_id,
             "type": ResourceType.vector_db.value,
@@ -371,9 +353,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
     async def list_datasets(self) -> ListDatasetsResponse:
-        return ListDatasetsResponse(
-            data=await self.get_all_with_type(ResourceType.dataset.value)
-        )
+        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

     async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
         return await self.get_object_by_identifier("dataset", dataset_id)
@@ -426,9 +406,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
     async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
-        return ListScoringFunctionsResponse(
-            data=await self.get_all_with_type(ResourceType.scoring_function.value)
-        )
+        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

     async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
         return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@@ -525,12 +503,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
-            toolgroup_id, mcp_endpoint
-        )
-        tool_host = (
-            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
-        )
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
+        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

         for tool_def in tool_defs:
             tools.append(

View file

@@ -230,9 +230,7 @@ def run_evaluation_3():
                output_res[scoring_fn] = []
            output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])

-        progress_text_container.write(
-            f"Expand to see current processed result ({i + 1} / {len(rows)})"
-        )
+        progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
        results_container.json(eval_res, expanded=2)

    progress_bar.progress(1.0, text="Evaluation complete!")

View file

@@ -161,9 +161,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         new_rows_df = pandas.DataFrame(rows)
         new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
-        dataset_impl.df = pandas.concat(
-            [dataset_impl.df, new_rows_df], ignore_index=True
-        )
+        dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)

         url = str(dataset_info.dataset_def.url.uri)
         parsed_url = urlparse(url)
@@ -178,12 +176,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
                 raise ValueError("Data URL must be a base64-encoded CSV")

             csv_buffer = dataset_impl.df.to_csv(index=False)
-            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
-                "utf-8"
-            )
-            dataset_info.dataset_def.url = URL(
-                uri=f"data:text/csv;base64,{base64_content}"
-            )
+            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
+            dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
         else:
             raise ValueError(
                 f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."

View file

@@ -89,16 +89,10 @@ class MetaReferenceEvalImpl(
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
-            rows_in_page=(
-                -1
-                if benchmark_config.num_examples is None
-                else benchmark_config.num_examples
-            ),
+            rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
         )
         res = await self.evaluate_rows(
             benchmark_id=benchmark_id,
@@ -124,14 +118,10 @@ class MetaReferenceEvalImpl(
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [
-                UserMessage(**x) for x in input_messages if x["role"] == "user"
-            ]
+            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]

             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(
-                agent_id, f"session-{i}"
-            )
+            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
             session_id = session_create_response.session_id

             turn_request = dict(
@@ -140,12 +130,7 @@ class MetaReferenceEvalImpl(
                 messages=input_messages,
                 stream=True,
             )
-            turn_response = [
-                chunk
-                async for chunk in await self.agents_api.create_agent_turn(
-                    **turn_request
-                )
-            ]
+            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
             final_event = turn_response[-1].event.payload

             # check if there's a memory retrieval step and extract the context
@@ -154,14 +139,10 @@ class MetaReferenceEvalImpl(
                 if step.step_type == StepType.tool_execution.value:
                     for tool_response in step.tool_responses:
                         if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(
-                                x.text for x in tool_response.content
-                            )
+                            memory_rag_context = " ".join(x.text for x in tool_response.content)

             agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = (
-                final_event.turn.output_message.content
-            )
+            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
             if memory_rag_context:
                 agent_generation[ColumnName.context.value] = memory_rag_context
@@ -173,9 +154,7 @@ class MetaReferenceEvalImpl(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> List[Dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
-        assert (
-            candidate.sampling_params.max_tokens is not None
-        ), "SamplingParams.max_tokens must be provided"
+        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"

         generations = []
         for x in tqdm(input_rows):
@@ -186,39 +165,21 @@ class MetaReferenceEvalImpl(
                     content=input_content,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_json = json.loads(
-                    x[ColumnName.chat_completion_input.value]
-                )
-                input_messages = [
-                    UserMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "user"
-                ]
+                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
+                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
-                messages += [
-                    SystemMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "system"
-                ]
+                messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
                     messages=messages,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             else:
                 raise ValueError("Invalid input row")
@@ -241,8 +202,7 @@ class MetaReferenceEvalImpl(
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r
-            for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
         ]

         if benchmark_config.scoring_params is not None:
@@ -251,9 +211,7 @@ class MetaReferenceEvalImpl(
                 for scoring_fn_id in scoring_functions
             }
         else:
-            scoring_functions_dict = {
-                scoring_fn_id: None for scoring_fn_id in scoring_functions
-            }
+            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}

         score_response = await self.scoring_api.score(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict

View file

@@ -17,7 +17,8 @@ import torch
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
+from torchtune import modules, training
+from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -88,9 +89,7 @@ class LoraFinetuningSingleDevice:
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
-            raise ValueError(
-                "You need to speicifc LoraFinetuningConfig for LoRA finetuning"
-            )
+            raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device()
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@@ -99,10 +98,7 @@ class LoraFinetuningSingleDevice:
         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))
-            paths = [
-                Path(checkpoint_dir / f"consolidated.{ext}")
-                for ext in ["pth", "00.pth"]
-            ]
+            paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
             if not any(p.exists() for p in paths):
                 checkpoint_dir = checkpoint_dir / "original"
@@ -117,9 +113,7 @@ class LoraFinetuningSingleDevice:
         else:
             model = resolve_model(self.model_id)
             if model is None:
-                raise ValueError(
-                    f"{self.model_id} not found. Your model id should be in the llama models SKU list"
-                )
+                raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
             self.checkpoint_dir = model_checkpoint_dir(model)

         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@@ -191,9 +185,7 @@ class LoraFinetuningSingleDevice:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")

-        self._optimizer = await self._setup_optimizer(
-            optimizer_config=self.training_config.optimizer_config
-        )
+        self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")

         self._loss_fn = CEWithChunkedOutputLoss()
@@ -221,13 +213,8 @@ class LoraFinetuningSingleDevice:
         # by the dataloader and the max_steps_per_epoch param set by the user and is used
         # for logging and tracking training state. This should be computed after the dataloader
         # has been setup
-        self._steps_per_epoch = (
-            len(self._training_dataloader) // self._gradient_accumulation_steps
-        )
-        if (
-            self.max_steps_per_epoch is not None
-            and self.max_steps_per_epoch < self._steps_per_epoch
-        ):
+        self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
+        if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
             self._steps_per_epoch = self.max_steps_per_epoch

         self.global_step = self.epochs_run * self._steps_per_epoch
@@ -241,9 +228,7 @@ class LoraFinetuningSingleDevice:
         log.info("Learning rate scheduler is initialized.")

         # Used to ignore labels for loss computation
-        self.ignore_labels_cache = torch.full(
-            (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
-        )
+        self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)

     def _log_memory_stats(self):
         # torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
@@ -284,13 +269,9 @@ class LoraFinetuningSingleDevice:
         set_trainable_params(model, self.adapter_params)

         if enable_activation_checkpointing:
-            training.set_activation_checkpointing(
-                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
-            )
+            training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})

-        base_missing, base_unexpected = model.load_state_dict(
-            base_model_state_dict, strict=False
-        )
+        base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)

         # This is for any adapters that need to be initialized after base weights
         # have been loaded (e.g. DoRA).
@@ -299,9 +280,7 @@ class LoraFinetuningSingleDevice:
             if hasattr(m, "initialize_dora_magnitude"):
                 m.initialize_dora_magnitude()
         if lora_weights_state_dict:
-            lora_missing, lora_unexpected = model.load_state_dict(
-                lora_weights_state_dict, strict=False
-            )
+            lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
         else:
             lora_missing, lora_unexpected = None, None
         validate_missing_and_unexpected_for_lora(
@@ -315,14 +294,10 @@ class LoraFinetuningSingleDevice:
         )

         # Validate model adapter params were loaded in with the expected dtype
-        training.validate_expected_param_dtype(
-            self.adapter_params.items(), dtype=self._dtype
-        )
+        training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)

         # activation offloading
-        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
-            model, enable_activation_offloading
-        )
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)

         self._log_memory_stats()
@@ -458,9 +433,7 @@ class LoraFinetuningSingleDevice:
         # Shift labels to compute loss
         # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
         # But this way we dont need to slice the logits. We just add an ignore index to labels.
-        labels = torch.hstack(
-            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
-        )
+        labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))

         if not isinstance(logits, list):
             labels = labels.reshape(-1)
             logits = logits.reshape(-1, logits.size(-1))
@@ -489,9 +462,7 @@ class LoraFinetuningSingleDevice:
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
-            metric_logger = DiskLogger(
-                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
-            )
+            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
             self._training_sampler.set_epoch(curr_epoch)

             loss_to_log = 0.0
@@ -499,8 +470,7 @@ class LoraFinetuningSingleDevice:
             for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
+                    and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
                 ):
                     break
@@ -508,9 +478,7 @@ class LoraFinetuningSingleDevice:

                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step
-                current_num_tokens = (
-                    batch["labels"] != self._loss_fn.ignore_index
-                ).sum()
+                current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
                 num_tokens += current_num_tokens

                 # Loss is normalized by default so we multiply by the number of tokens
@@ -535,9 +503,7 @@ class LoraFinetuningSingleDevice:
                     loss_to_log = running_loss.item() / num_tokens

                     pbar.update(1)
-                    pbar.set_description(
-                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
-                    )
+                    pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")

                     time_per_step = time.perf_counter() - t0
                     log_dict = {

View file

@@ -64,15 +64,11 @@ class BasicScoringImpl(
     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [
-            fn_def
-            for impl in self.scoring_fn_id_impls.values()
-            for fn_def in impl.get_supported_scoring_fn_defs()
+            fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
         ]

         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith(
-                "basic"
-            ), "All basic scoring fn must have identifier prefixed with 'basic'! "
+            assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "

         return scoring_fn_defs_list
@@ -86,9 +82,7 @@ class BasicScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
@@ -118,12 +112,8 @@ class BasicScoringImpl(
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(
-                input_rows, scoring_fn_id, scoring_fn_params
-            )
-            agg_results = await scoring_fn.aggregate(
-                score_results, scoring_fn_id, scoring_fn_params
-            )
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,

View file

@@ -122,12 +122,10 @@ class BraintrustScoringImpl(
         self.datasets_api = datasets_api
         self.braintrust_evaluators = {
-            entry.identifier: entry.evaluator
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
         self.supported_fn_defs_registry = {
-            entry.identifier: entry.fn_def
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }

     async def initialize(self) -> None: ...
@@ -137,16 +135,14 @@ class BraintrustScoringImpl(
     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith(
-                "braintrust"
-            ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            assert f.identifier.startswith("braintrust"), (
+                "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            )

         return scoring_fn_defs_list

     async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
-        raise NotImplementedError(
-            "Registering scoring function not allowed for braintrust provider"
-        )
+        raise NotImplementedError("Registering scoring function not allowed for braintrust provider")

     async def set_api_key(self) -> None:
         # api key is in the request headers
@@ -169,17 +165,13 @@ class BraintrustScoringImpl(
         await self.set_api_key()
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             rows_in_page=-1,
         )
-        res = await self.score(
-            input_rows=all_rows.rows, scoring_functions=scoring_functions
-        )
+        res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)

         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
@@ -220,13 +212,8 @@ class BraintrustScoringImpl(
             if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            score_results = [
-                await self.score_row(input_row, scoring_fn_id)
-                for input_row in input_rows
-            ]
-            aggregation_functions = self.supported_fn_defs_registry[
-                scoring_fn_id
-            ].params.aggregation_functions
+            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
+            aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions

             # override scoring_fn params if provided
             if scoring_functions[scoring_fn_id] is not None:

View file

@@ -54,9 +54,9 @@ class LlmAsJudgeScoringImpl(
         scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
         for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
-            assert f.identifier.startswith(
-                "llm-as-judge"
-            ), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
+            assert f.identifier.startswith("llm-as-judge"), (
+                "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
+            )

         return scoring_fn_defs_list
@@ -70,9 +70,7 @@ class LlmAsJudgeScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
@@ -100,12 +98,8 @@ class LlmAsJudgeScoringImpl(
         for scoring_fn_id in scoring_functions.keys():
             scoring_fn = self.llm_as_judge_fn
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(
-                input_rows, scoring_fn_id, scoring_fn_params
-            )
-            agg_results = await scoring_fn.aggregate(
-                score_results, scoring_fn_id, scoring_fn_params
-            )
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,

View file

@@ -104,13 +104,9 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         new_dataset = hf_datasets.Dataset.from_list(rows)

         # Concatenate the new rows with existing dataset
-        updated_dataset = hf_datasets.concatenate_datasets(
-            [loaded_dataset, new_dataset]
-        )
+        updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])

         if dataset_def.metadata.get("path", None):
             updated_dataset.push_to_hub(dataset_def.metadata["path"])
         else:
-            raise NotImplementedError(
-                "Uploading to URL-based datasets is not supported yet"
-            )
+            raise NotImplementedError("Uploading to URL-based datasets is not supported yet")