forked from phoenix-oss/llama-stack-mirror
feat(dataset api): (1.4/n) fix resolver signature mismatch (#1658)
# What does this PR do? - fix datasets api signature mis-match so that llama stack run can start [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` llama stack run ``` <img width="626" alt="image" src="https://github.com/user-attachments/assets/59072d1a-ccb6-453a-80e8-d87419896c41" /> [//]: # (## Documentation)
This commit is contained in:
parent
72ccdc19a8
commit
2c9d624910
10 changed files with 105 additions and 80 deletions
|
@ -12,7 +12,8 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||
from llama_stack.apis.eval import (
|
||||
BenchmarkConfig,
|
||||
Eval,
|
||||
|
@ -160,7 +161,11 @@ class InferenceRouter(Inference):
|
|||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricEvent]:
|
||||
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||
|
||||
|
@ -298,7 +303,12 @@ class InferenceRouter(Inference):
|
|||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
completion_tokens = await self._count_tokens(
|
||||
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
|
||||
[
|
||||
CompletionMessage(
|
||||
content=completion_text,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
@ -471,21 +481,36 @@ class DatasetIORouter(DatasetIO):
|
|||
logger.debug("DatasetIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
dataset_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
|
||||
)
|
||||
await self.routing_table.register_dataset(
|
||||
purpose=purpose,
|
||||
source=source,
|
||||
metadata=metadata,
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
logger.debug(
|
||||
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
|
||||
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
||||
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=rows_in_page,
|
||||
page_token=page_token,
|
||||
filter_condition=filter_condition,
|
||||
start_index=start_index,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue