Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)
fix endpoint, only sdk change
This commit is contained in:
parent 13c7c5b6a1
commit 9e6d99f7b1
8 changed files with 161 additions and 72 deletions
docs/_static/llama-stack-spec.html (vendored, 12 lines changed)

@@ -40,7 +40,7 @@
 }
 ],
 "paths": {
-"/v1/datasets/{dataset_id}/append-rows": {
+"/v1/datasetio/append-rows/{dataset_id}": {
 "post": {
 "responses": {
 "200": {
@@ -60,7 +60,7 @@
 }
 },
 "tags": [
-"Datasets"
+"DatasetIO"
 ],
 "description": "",
 "parameters": [
@@ -2177,7 +2177,7 @@
 }
 }
 },
-"/v1/datasets/{dataset_id}/iterrows": {
+"/v1/datasetio/iterrows/{dataset_id}": {
 "get": {
 "responses": {
 "200": {
@@ -2204,7 +2204,7 @@
 }
 },
 "tags": [
-"Datasets"
+"DatasetIO"
 ],
 "description": "Get a paginated list of rows from a dataset. Uses cursor-based pagination.",
 "parameters": [
@@ -10274,7 +10274,7 @@
 "name": "Benchmarks"
 },
 {
-"name": "Datasets"
+"name": "DatasetIO"
 },
 {
 "name": "Datasets"
@@ -10342,7 +10342,7 @@
 "Agents",
 "BatchInference (Coming Soon)",
 "Benchmarks",
-"Datasets",
+"DatasetIO",
 "Datasets",
 "Eval",
 "Files",

docs/_static/llama-stack-spec.yaml (vendored, 12 lines changed)

@@ -10,7 +10,7 @@ info:
 servers:
 - url: http://any-hosted-llama-stack.com
 paths:
-/v1/datasets/{dataset_id}/append-rows:
+/v1/datasetio/append-rows/{dataset_id}:
 post:
 responses:
 '200':
@@ -26,7 +26,7 @@ paths:
 default:
 $ref: '#/components/responses/DefaultError'
 tags:
-- Datasets
+- DatasetIO
 description: ''
 parameters:
 - name: dataset_id
@@ -1457,7 +1457,7 @@ paths:
 schema:
 $ref: '#/components/schemas/InvokeToolRequest'
 required: true
-/v1/datasets/{dataset_id}/iterrows:
+/v1/datasetio/iterrows/{dataset_id}:
 get:
 responses:
 '200':
@@ -1477,7 +1477,7 @@ paths:
 default:
 $ref: '#/components/responses/DefaultError'
 tags:
-- Datasets
+- DatasetIO
 description: >-
 Get a paginated list of rows from a dataset. Uses cursor-based pagination.
 parameters:
@@ -6931,7 +6931,7 @@ tags:
 Agents API for creating and interacting with agentic systems.
 - name: BatchInference (Coming Soon)
 - name: Benchmarks
-- name: Datasets
+- name: DatasetIO
 - name: Datasets
 - name: Eval
 x-displayName: >-
@@ -6971,7 +6971,7 @@ x-tagGroups:
 - Agents
 - BatchInference (Coming Soon)
 - Benchmarks
-- Datasets
+- DatasetIO
 - Datasets
 - Eval
 - Files

@@ -552,8 +552,8 @@ class Generator:
         print(op.defining_class.__name__)

         # TODO (xiyan): temporary fix for datasetio inner impl + datasets api
-        if op.defining_class.__name__ in ["DatasetIO"]:
-            op.defining_class.__name__ = "Datasets"
+        # if op.defining_class.__name__ in ["DatasetIO"]:
+        # op.defining_class.__name__ = "Datasets"

         doc_string = parse_type(op.func_ref)
         doc_params = dict(

@@ -34,7 +34,8 @@ class DatasetIO(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore

-    @webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
+    # TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
+    @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
     async def iterrows(
         self,
         dataset_id: str,
@@ -49,5 +50,7 @@ class DatasetIO(Protocol):
         """
         ...

-    @webmethod(route="/datasets/{dataset_id}/append-rows", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
+    async def append_rows(
+        self, dataset_id: str, rows: List[Dict[str, Any]]
+    ) -> None: ...

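For context, a minimal sketch of how a plain HTTP client would hit the renamed routes once this change lands. The base URL, the example dataset id, and the exact request/response shapes below are assumptions for illustration, not part of this diff:

# Hypothetical sketch: exercising the renamed datasetio routes over raw HTTP.
# Assumes a stack running on localhost:8321 and a dataset registered as "my-dataset";
# query-parameter and body shapes are inferred, not confirmed by this change.
import requests

BASE = "http://localhost:8321"
dataset_id = "my-dataset"  # hypothetical identifier

# GET /v1/datasetio/iterrows/{dataset_id} — paginated read via start_index/limit
resp = requests.get(
    f"{BASE}/v1/datasetio/iterrows/{dataset_id}",
    params={"start_index": 0, "limit": 10},
)
print(resp.json())

# POST /v1/datasetio/append-rows/{dataset_id} — append rows to the dataset
resp = requests.post(
    f"{BASE}/v1/datasetio/append-rows/{dataset_id}",
    json={"rows": [{"question": "1+1?", "answer": "2"}]},
)
resp.raise_for_status()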
@@ -8,9 +8,9 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from llama_stack.apis.common.content_types import (
-    URL,
     InterleavedContent,
     InterleavedContentItem,
+    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
@@ -94,7 +94,9 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+        logger.debug(
+            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
+        )
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -112,7 +114,9 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
+            vector_db_id, chunks, ttl_seconds
+        )

     async def query_chunks(
         self,
@@ -121,7 +125,9 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
+            vector_db_id, query, params
+        )


 class InferenceRouter(Inference):
@@ -158,7 +164,9 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+        await self.routing_table.register_model(
+            model_id, provider_model_id, provider_id, metadata, model_type
+        )

     def _construct_metrics(
         self,
@@ -212,11 +220,16 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model
+        )
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]

     async def _count_tokens(
         self,
@@ -241,7 +254,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -251,12 +266,19 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
-                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+            if (
+                tool_prompt_format
+                and tool_prompt_format != tool_config.tool_prompt_format
+            ):
+                raise ValueError(
+                    "tool_prompt_format and tool_config.tool_prompt_format must match"
+                )
         else:
             params = {}
             if tool_choice:
@@ -274,9 +296,14 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+                tool_names = [
+                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
+                    for t in tools
+                ]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+                    raise ValueError(
+                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
+                    )

         params = dict(
             model_id=model_id,
@@ -291,17 +318,25 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+        prompt_tokens = await self._count_tokens(
+            messages, tool_config.tool_prompt_format
+        )

         if stream:

             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                    if (
+                        chunk.event.event_type
+                        == ChatCompletionResponseEventType.progress
+                    ):
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                    if (
+                        chunk.event.event_type
+                        == ChatCompletionResponseEventType.complete
+                    ):
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -318,7 +353,11 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                        chunk.metrics = (
+                            metrics
+                            if chunk.metrics is None
+                            else chunk.metrics + metrics
+                        )
                     yield chunk

             return stream_generator()
@@ -335,7 +374,9 @@ class InferenceRouter(Inference):
                 total_tokens,
                 model,
             )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
+            response.metrics = (
+                metrics if response.metrics is None else response.metrics + metrics
+            )
             return response

     async def completion(
@@ -356,7 +397,9 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -376,7 +419,11 @@ class InferenceRouter(Inference):
             async for chunk in await provider.completion(**params):
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                if (
+                    hasattr(chunk, "stop_reason")
+                    and chunk.stop_reason
+                    and self.telemetry
+                ):
                     completion_tokens = await self._count_tokens(completion_text)
                     total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                     metrics = await self._compute_and_log_token_usage(
@@ -385,7 +432,11 @@ class InferenceRouter(Inference):
                         total_tokens,
                         model,
                     )
-                    chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                    chunk.metrics = (
+                        metrics
+                        if chunk.metrics is None
+                        else chunk.metrics + metrics
+                    )
                 yield chunk

         return stream_generator()
@@ -399,7 +450,9 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = metrics if response.metrics is None else response.metrics + metrics
+        response.metrics = (
+            metrics if response.metrics is None else response.metrics + metrics
+        )
         return response

     async def embeddings(
@@ -415,7 +468,9 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
+            raise ValueError(
+                f"Model '{model_id}' is an LLM model and does not support embeddings"
+            )
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -449,7 +504,9 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+        return await self.routing_table.register_shield(
+            shield_id, provider_shield_id, provider_id, params
+        )

     async def run_shield(
         self,
@@ -546,7 +603,9 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+            score_response = await self.routing_table.get_provider_impl(
+                fn_identifier
+            ).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -564,11 +623,15 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        logger.debug(
+            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
+        )
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+            score_response = await self.routing_table.get_provider_impl(
+                fn_identifier
+            ).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -611,7 +674,9 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+        logger.debug(
+            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
+        )
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -625,7 +690,9 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Optional[JobStatus]:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
+            benchmark_id, job_id
+        )

     async def job_cancel(
         self,
@@ -679,9 +746,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
-            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
-                documents, vector_db_id, chunk_size_in_tokens
-            )
+            return await self.routing_table.get_provider_impl(
+                "insert_into_memory"
+            ).insert(documents, vector_db_id, chunk_size_in_tokens)

     def __init__(
         self,
@@ -714,4 +781,6 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
+            tool_group_id, mcp_endpoint
+        )

@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional

+from urllib.parse import parse_qs, urlparse
+
 import datasets as hf_datasets

 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
@@ -16,24 +18,17 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
 from .config import HuggingfaceDatasetIOConfig

 DATASETS_PREFIX = "datasets:"
+from rich.pretty import pprint


-def load_hf_dataset(dataset_def: Dataset):
-    if dataset_def.metadata.get("path", None):
-        dataset = hf_datasets.load_dataset(**dataset_def.metadata)
-    else:
-        df = get_dataframe_from_url(dataset_def.url)
-
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {dataset_def.url}")
-
-        dataset = hf_datasets.Dataset.from_pandas(df)
-
-    # drop columns not specified by schema
-    if dataset_def.dataset_schema:
-        dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
-
-    return dataset
+def parse_hf_params(dataset_def: Dataset):
+    uri = dataset_def.source.uri
+    parsed_uri = urlparse(uri)
+    params = parse_qs(parsed_uri.query)
+    params = {k: v[0] for k, v in params.items()}
+    path = parsed_uri.path.lstrip("/")
+
+    return path, params


 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
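The new helper only splits a dataset URI into a Hub path plus load_dataset keyword arguments. A minimal standalone sketch of the same parsing logic, using the URI from the updated tests below, shows what it produces:

# Standalone sketch of the URI parsing that parse_hf_params performs;
# the example URI mirrors the one used in the updated tests.
from urllib.parse import parse_qs, urlparse

uri = "huggingface://datasets/llamastack/simpleqa?split=train"
parsed = urlparse(uri)
params = {k: v[0] for k, v in parse_qs(parsed.query).items()}
path = parsed.path.lstrip("/")

print(path)    # "llamastack/simpleqa" — the leading "datasets" segment lands in the netloc
print(params)  # {"split": "train"}

# The provider then loads the dataset roughly as:
#   hf_datasets.load_dataset(path, **params)

Because urlparse treats the first segment after "//" as the host, the repo path only survives when a throwaway segment precedes it, which is presumably why the test URIs gained the extra "datasets/" segment.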
@@ -60,6 +55,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset_def: Dataset,
     ) -> None:
+        print("register_dataset")
         # Store in kvstore
         key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
@@ -80,7 +76,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)

         start_index = start_index or 0

@@ -98,15 +95,20 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)

         # Convert rows to HF Dataset format
         new_dataset = hf_datasets.Dataset.from_list(rows)

         # Concatenate the new rows with existing dataset
-        updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
+        updated_dataset = hf_datasets.concatenate_datasets(
+            [loaded_dataset, new_dataset]
+        )

         if dataset_def.metadata.get("path", None):
             updated_dataset.push_to_hub(dataset_def.metadata["path"])
         else:
-            raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
+            raise NotImplementedError(
+                "Uploading to URL-based datasets is not supported yet"
+            )

@@ -19,6 +19,15 @@ import pytest
 def test_register_dataset(llama_stack_client):
     dataset = llama_stack_client.datasets.register(
         purpose="eval/messages-answer",
-        source={"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"},
+        source={
+            "type": "uri",
+            "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+        },
     )
-    print(dataset)
+    assert dataset.identifier is not None
+    assert dataset.provider_id == "huggingface"
+    iterrow_response = llama_stack_client.datasets.iterrows(
+        dataset.identifier, limit=10
+    )
+    assert len(iterrow_response.data) == 10
+    assert iterrow_response.next_index is not None

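The iterrows call above is described in the spec as cursor-based pagination; a hedged sketch of walking a whole dataset with it, assuming the client method also accepts a start_index keyword mirroring the provider-side signature (not something this diff verifies):

# Hypothetical pagination loop over a registered dataset; assumes the client's
# iterrows() accepts start_index/limit keywords like the provider signature.
def iterate_all_rows(client, dataset_id, page_size=10):
    start_index = 0
    while True:
        page = client.datasets.iterrows(
            dataset_id=dataset_id, start_index=start_index, limit=page_size
        )
        for row in page.data:
            yield row
        if page.next_index is None:
            break
        start_index = page.next_index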
@@ -6,9 +6,15 @@ def test_register_dataset():
     client = LlamaStackClient(base_url="http://localhost:8321")
     dataset = client.datasets.register(
         purpose="eval/messages-answer",
-        source={"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"},
+        source={
+            "type": "uri",
+            "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+        },
     )
+    dataset_id = dataset.identifier
     pprint(dataset)
+    rows = client.datasets.iterrows(dataset_id=dataset_id, limit=10)
+    pprint(rows)


 if __name__ == "__main__":