precommit

Xi Yan 2025-03-15 16:20:58 -07:00
parent 9da092ff2d
commit cf225c9710
6 changed files with 56 additions and 160 deletions

View file

@@ -51,6 +51,4 @@ class DatasetIO(Protocol):
     ...

     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(
-        self, dataset_id: str, rows: List[Dict[str, Any]]
-    ) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
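
Note: append_rows is part of the DatasetIO protocol reformatted above. For reference, a minimal in-memory sketch of a provider satisfying this signature could look like the following (the class name and dict-backed storage are hypothetical, not part of this commit):

from typing import Any, Dict, List


class InMemoryDatasetIO:
    # Hypothetical provider illustrating the DatasetIO.append_rows signature.
    def __init__(self) -> None:
        self._rows: Dict[str, List[Dict[str, Any]]] = {}

    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        # Create the dataset bucket on first use, then extend it with the new rows.
        self._rows.setdefault(dataset_id, []).extend(rows)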

View file

@@ -8,9 +8,9 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from llama_stack.apis.common.content_types import (
-    URL,
     InterleavedContent,
     InterleavedContentItem,
+    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
@@ -94,9 +94,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(
-            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
-        )
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -114,9 +112,7 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
-            vector_db_id, chunks, ttl_seconds
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

     async def query_chunks(
         self,
@@ -125,9 +121,7 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
-            vector_db_id, query, params
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)


 class InferenceRouter(Inference):
@@ -164,9 +158,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
-        )
+        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

     def _construct_metrics(
         self,
@@ -220,16 +212,11 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(
-            prompt_tokens, completion_tokens, total_tokens, model
-        )
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [
-            MetricInResponse(metric=metric.metric, value=metric.value)
-            for metric in metrics
-        ]
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def _count_tokens(
         self,
@@ -254,9 +241,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -266,19 +251,12 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if (
-                tool_prompt_format
-                and tool_prompt_format != tool_config.tool_prompt_format
-            ):
-                raise ValueError(
-                    "tool_prompt_format and tool_config.tool_prompt_format must match"
-                )
+            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
         else:
             params = {}
             if tool_choice:
@@ -296,14 +274,9 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [
-                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
-                    for t in tools
-                ]
+                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(
-                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
-                    )
+                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")

         params = dict(
             model_id=model_id,
@@ -318,25 +291,17 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(
-            messages, tool_config.tool_prompt_format
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)

         if stream:

             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.progress
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.complete
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -353,11 +318,7 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
@@ -374,9 +335,7 @@ class InferenceRouter(Inference):
                 total_tokens,
                 model,
             )
-            response.metrics = (
-                metrics if response.metrics is None else response.metrics + metrics
-            )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response

     async def completion(
@@ -397,9 +356,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -419,11 +376,7 @@ class InferenceRouter(Inference):
             async for chunk in await provider.completion(**params):
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if (
-                    hasattr(chunk, "stop_reason")
-                    and chunk.stop_reason
-                    and self.telemetry
-                ):
+                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                     completion_tokens = await self._count_tokens(completion_text)
                     total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                     metrics = await self._compute_and_log_token_usage(
@@ -432,11 +385,7 @@ class InferenceRouter(Inference):
                         total_tokens,
                         model,
                     )
-                    chunk.metrics = (
-                        metrics
-                        if chunk.metrics is None
-                        else chunk.metrics + metrics
-                    )
+                    chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                 yield chunk

             return stream_generator()
@@ -450,9 +399,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response

     async def embeddings(
@@ -468,9 +415,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(
-                f"Model '{model_id}' is an LLM model and does not support embeddings"
-            )
+            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -504,9 +449,7 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(
-            shield_id, provider_shield_id, provider_id, params
-        )
+        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def run_shield(
         self,
@@ -603,9 +546,7 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score_batch(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -623,15 +564,11 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(
-            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
-        )
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -674,9 +611,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(
-            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
-        )
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -690,9 +625,7 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Optional[JobStatus]:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
-            benchmark_id, job_id
-        )
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

     async def job_cancel(
         self,
@@ -746,9 +679,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
-            return await self.routing_table.get_provider_impl(
-                "insert_into_memory"
-            ).insert(documents, vector_db_id, chunk_size_in_tokens)
+            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
+                documents, vector_db_id, chunk_size_in_tokens
+            )

     def __init__(
         self,
@@ -781,6 +714,4 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
-            tool_group_id, mcp_endpoint
-        )
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@@ -105,9 +105,7 @@ class CommonRoutingTableImpl(RoutingTable):
         self.dist_registry = dist_registry

     async def initialize(self) -> None:
-        async def add_objects(
-            objs: List[RoutableObjectWithProvider], provider_id: str, cls
-        ) -> None:
+        async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
             for obj in objs:
                 if cls is None:
                     obj.provider_id = provider_id
@@ -142,9 +140,7 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()

-    def get_provider_impl(
-        self, routing_key: str, provider_id: Optional[str] = None
-    ) -> Any:
+    def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
@@ -182,9 +178,7 @@ class CommonRoutingTableImpl(RoutingTable):
         raise ValueError(f"Provider not found for `{routing_key}`")

-    async def get_object_by_identifier(
-        self, type: str, identifier: str
-    ) -> Optional[RoutableObjectWithProvider]:
+    async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
         # Get from disk registry
         obj = await self.dist_registry.get(type, identifier)
         if not obj:
@@ -194,13 +188,9 @@ class CommonRoutingTableImpl(RoutingTable):
     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
         await self.dist_registry.delete(obj.type, obj.identifier)
-        await unregister_object_from_provider(
-            obj, self.impls_by_provider_id[obj.provider_id]
-        )
+        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

-    async def register_object(
-        self, obj: RoutableObjectWithProvider
-    ) -> RoutableObjectWithProvider:
+    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -255,9 +245,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if model_type is None:
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
-            raise ValueError(
-                "Embedding model must have an embedding dimension in its metadata"
-            )
+            raise ValueError("Embedding model must have an embedding dimension in its metadata")
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
@@ -277,9 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):

 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
     async def list_shields(self) -> ListShieldsResponse:
-        return ListShieldsResponse(
-            data=await self.get_all_with_type(ResourceType.shield.value)
-        )
+        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

     async def get_shield(self, identifier: str) -> Optional[Shield]:
         return await self.get_object_by_identifier("shield", identifier)
@@ -338,18 +324,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                 f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
             )
         else:
-            raise ValueError(
-                "No provider available. Please configure a vector_io provider."
-            )
+            raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
-            raise ValueError(
-                f"Model {embedding_model} does not have an embedding dimension"
-            )
+            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
         vector_db_data = {
             "identifier": vector_db_id,
             "type": ResourceType.vector_db.value,
@@ -371,9 +353,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
     async def list_datasets(self) -> ListDatasetsResponse:
-        return ListDatasetsResponse(
-            data=await self.get_all_with_type(ResourceType.dataset.value)
-        )
+        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

     async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
         return await self.get_object_by_identifier("dataset", dataset_id)
@@ -426,9 +406,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
     async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
-        return ListScoringFunctionsResponse(
-            data=await self.get_all_with_type(ResourceType.scoring_function.value)
-        )
+        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

     async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
         return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@@ -525,12 +503,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
-            toolgroup_id, mcp_endpoint
-        )
-        tool_host = (
-            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
-        )
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
+        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

         for tool_def in tool_defs:
             tools.append(
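
The register_object hunk above keeps the existing fallback: when a routable object arrives without a provider_id, the first registered provider is chosen arbitrarily. Condensed into a standalone sketch (names simplified from the code shown in the diff):

from typing import Any, Dict, Optional


def resolve_provider_id(provider_id: Optional[str], impls_by_provider_id: Dict[str, Any]) -> str:
    # Mirror of the fallback in CommonRoutingTableImpl.register_object:
    # if provider_id is not specified, pick an arbitrary one from existing entries.
    if not provider_id and impls_by_provider_id:
        return list(impls_by_provider_id.keys())[0]
    if provider_id not in impls_by_provider_id:
        raise ValueError(f"Provider `{provider_id}` not found")
    return provider_id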

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional
-
 from urllib.parse import parse_qs, urlparse

 import datasets as hf_datasets
@@ -12,13 +11,11 @@ import datasets as hf_datasets
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl

 from .config import HuggingfaceDatasetIOConfig

 DATASETS_PREFIX = "datasets:"
-from rich.pretty import pprint


 def parse_hf_params(dataset_def: Dataset):
@@ -102,13 +99,9 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         new_dataset = hf_datasets.Dataset.from_list(rows)

         # Concatenate the new rows with existing dataset
-        updated_dataset = hf_datasets.concatenate_datasets(
-            [loaded_dataset, new_dataset]
-        )
+        updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])

         if dataset_def.metadata.get("path", None):
             updated_dataset.push_to_hub(dataset_def.metadata["path"])
         else:
-            raise NotImplementedError(
-                "Uploading to URL-based datasets is not supported yet"
-            )
+            raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
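
The append path above leans on the Hugging Face datasets library directly: wrap the new rows in a Dataset, concatenate with the loaded one, and push the result back to the hub. Standalone, the sequence looks roughly like this (the repo path and row schema are placeholders):

import datasets as hf_datasets

# Load the existing data, wrap the new rows, and concatenate the two.
loaded_dataset = hf_datasets.load_dataset("org/my-dataset", split="train")  # placeholder path
new_dataset = hf_datasets.Dataset.from_list([{"input": "hello", "output": "world"}])
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])

# Push the combined dataset back to the hub (requires write access to the repo).
updated_dataset.push_to_hub("org/my-dataset")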

View file

@@ -4,10 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
-import mimetypes
-import os
-from pathlib import Path

 import pytest
@@ -36,8 +32,6 @@ def test_register_dataset(llama_stack_client, purpose, source, provider_id):
     )
     assert dataset.identifier is not None
     assert dataset.provider_id == provider_id
-    iterrow_response = llama_stack_client.datasets.iterrows(
-        dataset.identifier, limit=10
-    )
+    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=10)
     assert len(iterrow_response.data) == 10
     assert iterrow_response.next_index is not None
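
The test reads a single page; iterating a whole dataset would loop on next_index. A sketch under the assumption that iterrows accepts a start_index parameter and that next_index comes back None once the rows are exhausted:

all_rows = []
next_index = None
while True:
    response = llama_stack_client.datasets.iterrows(
        dataset.identifier,
        start_index=next_index,  # assumed parameter, mirroring next_index in the response
        limit=10,
    )
    all_rows.extend(response.data)
    next_index = response.next_index
    if next_index is None or not response.data:  # stop once a page comes back empty
        break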

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 from llama_stack_client import LlamaStackClient
 from rich.pretty import pprint