mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 11:08:20 +00:00
precommit
This commit is contained in:
parent
9da092ff2d
commit
cf225c9710
6 changed files with 56 additions and 160 deletions
|
@ -51,6 +51,4 @@ class DatasetIO(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
|
||||
async def append_rows(
|
||||
self, dataset_id: str, rows: List[Dict[str, Any]]
|
||||
) -> None: ...
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
|
||||
|
|
|
@ -8,9 +8,9 @@ import time
|
|||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||
|
@ -94,9 +94,7 @@ class VectorIORouter(VectorIO):
|
|||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
|
||||
)
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
|
@ -114,9 +112,7 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
|
||||
vector_db_id, chunks, ttl_seconds
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
|
@ -125,9 +121,7 @@ class VectorIORouter(VectorIO):
|
|||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
|
||||
vector_db_id, query, params
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
|
@ -164,9 +158,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||
)
|
||||
await self.routing_table.register_model(
|
||||
model_id, provider_model_id, provider_id, metadata, model_type
|
||||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self,
|
||||
|
@ -220,16 +212,11 @@ class InferenceRouter(Inference):
|
|||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricInResponse]:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens, completion_tokens, total_tokens, model
|
||||
)
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
return [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value)
|
||||
for metric in metrics
|
||||
]
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
|
@ -254,9 +241,7 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
|
@ -266,19 +251,12 @@ class InferenceRouter(Inference):
|
|||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
if tool_config:
|
||||
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||
if (
|
||||
tool_prompt_format
|
||||
and tool_prompt_format != tool_config.tool_prompt_format
|
||||
):
|
||||
raise ValueError(
|
||||
"tool_prompt_format and tool_config.tool_prompt_format must match"
|
||||
)
|
||||
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
|
||||
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
|
||||
else:
|
||||
params = {}
|
||||
if tool_choice:
|
||||
|
@ -296,14 +274,9 @@ class InferenceRouter(Inference):
|
|||
pass
|
||||
else:
|
||||
# verify tool_choice is one of the tools
|
||||
tool_names = [
|
||||
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
|
||||
for t in tools
|
||||
]
|
||||
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
|
||||
if tool_config.tool_choice not in tool_names:
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
|
||||
)
|
||||
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
|
||||
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
|
@ -318,25 +291,17 @@ class InferenceRouter(Inference):
|
|||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(
|
||||
messages, tool_config.tool_prompt_format
|
||||
)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.chat_completion(**params):
|
||||
if (
|
||||
chunk.event.event_type
|
||||
== ChatCompletionResponseEventType.progress
|
||||
):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if (
|
||||
chunk.event.event_type
|
||||
== ChatCompletionResponseEventType.complete
|
||||
):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
|
@ -353,11 +318,7 @@ class InferenceRouter(Inference):
|
|||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = (
|
||||
metrics
|
||||
if chunk.metrics is None
|
||||
else chunk.metrics + metrics
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
|
@ -374,9 +335,7 @@ class InferenceRouter(Inference):
|
|||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = (
|
||||
metrics if response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def completion(
|
||||
|
@ -397,9 +356,7 @@ class InferenceRouter(Inference):
|
|||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
|
@ -419,11 +376,7 @@ class InferenceRouter(Inference):
|
|||
async for chunk in await provider.completion(**params):
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if (
|
||||
hasattr(chunk, "stop_reason")
|
||||
and chunk.stop_reason
|
||||
and self.telemetry
|
||||
):
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
|
@ -432,11 +385,7 @@ class InferenceRouter(Inference):
|
|||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = (
|
||||
metrics
|
||||
if chunk.metrics is None
|
||||
else chunk.metrics + metrics
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
|
@ -450,9 +399,7 @@ class InferenceRouter(Inference):
|
|||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = (
|
||||
metrics if response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def embeddings(
|
||||
|
@ -468,9 +415,7 @@ class InferenceRouter(Inference):
|
|||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
||||
)
|
||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
|
@ -504,9 +449,7 @@ class SafetyRouter(Safety):
|
|||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield:
|
||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(
|
||||
shield_id, provider_shield_id, provider_id, params
|
||||
)
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
@ -603,9 +546,7 @@ class ScoringRouter(Scoring):
|
|||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(
|
||||
fn_identifier
|
||||
).score_batch(
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
@ -623,15 +564,11 @@ class ScoringRouter(Scoring):
|
|||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
logger.debug(
|
||||
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
|
||||
)
|
||||
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(
|
||||
fn_identifier
|
||||
).score(
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
|
||||
input_rows=input_rows,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
@ -674,9 +611,7 @@ class EvalRouter(Eval):
|
|||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(
|
||||
f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
|
||||
)
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
|
@ -690,9 +625,7 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(
|
||||
benchmark_id, job_id
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
self,
|
||||
|
@ -746,9 +679,9 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(
|
||||
"insert_into_memory"
|
||||
).insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -781,6 +714,4 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
|
||||
tool_group_id, mcp_endpoint
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
|
|
@ -105,9 +105,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
self.dist_registry = dist_registry
|
||||
|
||||
async def initialize(self) -> None:
|
||||
async def add_objects(
|
||||
objs: List[RoutableObjectWithProvider], provider_id: str, cls
|
||||
) -> None:
|
||||
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||
for obj in objs:
|
||||
if cls is None:
|
||||
obj.provider_id = provider_id
|
||||
|
@ -142,9 +140,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(
|
||||
self, routing_key: str, provider_id: Optional[str] = None
|
||||
) -> Any:
|
||||
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
|
@ -182,9 +178,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
async def get_object_by_identifier(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
# Get from disk registry
|
||||
obj = await self.dist_registry.get(type, identifier)
|
||||
if not obj:
|
||||
|
@ -194,13 +188,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||
await unregister_object_from_provider(
|
||||
obj, self.impls_by_provider_id[obj.provider_id]
|
||||
)
|
||||
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
||||
|
||||
async def register_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
|
||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
@ -255,9 +245,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
if model_type is None:
|
||||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError(
|
||||
"Embedding model must have an embedding dimension in its metadata"
|
||||
)
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = Model(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
|
@ -277,9 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
return ListShieldsResponse(
|
||||
data=await self.get_all_with_type(ResourceType.shield.value)
|
||||
)
|
||||
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]:
|
||||
return await self.get_object_by_identifier("shield", identifier)
|
||||
|
@ -338,18 +324,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider available. Please configure a vector_io provider."
|
||||
)
|
||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await self.get_object_by_identifier("model", embedding_model)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {embedding_model} not found")
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(
|
||||
f"Model {embedding_model} does not have an embedding dimension"
|
||||
)
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
vector_db_data = {
|
||||
"identifier": vector_db_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
|
@ -371,9 +353,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
return ListDatasetsResponse(
|
||||
data=await self.get_all_with_type(ResourceType.dataset.value)
|
||||
)
|
||||
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||
|
@ -426,9 +406,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
return ListScoringFunctionsResponse(
|
||||
data=await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||
)
|
||||
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
|
||||
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
|
@ -525,12 +503,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
args: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
|
||||
toolgroup_id, mcp_endpoint
|
||||
)
|
||||
tool_host = (
|
||||
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||
)
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
||||
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import datasets as hf_datasets
|
||||
|
@ -12,13 +11,11 @@ import datasets as hf_datasets
|
|||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import HuggingfaceDatasetIOConfig
|
||||
|
||||
DATASETS_PREFIX = "datasets:"
|
||||
from rich.pretty import pprint
|
||||
|
||||
|
||||
def parse_hf_params(dataset_def: Dataset):
|
||||
|
@ -102,13 +99,9 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
new_dataset = hf_datasets.Dataset.from_list(rows)
|
||||
|
||||
# Concatenate the new rows with existing dataset
|
||||
updated_dataset = hf_datasets.concatenate_datasets(
|
||||
[loaded_dataset, new_dataset]
|
||||
)
|
||||
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
|
||||
|
||||
if dataset_def.metadata.get("path", None):
|
||||
updated_dataset.push_to_hub(dataset_def.metadata["path"])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Uploading to URL-based datasets is not supported yet"
|
||||
)
|
||||
raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
|
||||
|
|
|
@ -4,10 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -36,8 +32,6 @@ def test_register_dataset(llama_stack_client, purpose, source, provider_id):
|
|||
)
|
||||
assert dataset.identifier is not None
|
||||
assert dataset.provider_id == provider_id
|
||||
iterrow_response = llama_stack_client.datasets.iterrows(
|
||||
dataset.identifier, limit=10
|
||||
)
|
||||
iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=10)
|
||||
assert len(iterrow_response.data) == 10
|
||||
assert iterrow_response.next_index is not None
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from rich.pretty import pprint
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue