Merge remote-tracking branch 'origin/main' into if_eval

Botao Chen 2025-03-19 12:58:14 -07:00
commit 9068416bc4
18 changed files with 183 additions and 135 deletions

.github/TRIAGERS.md (new file, 2 additions)

@@ -0,0 +1,2 @@
+# This file documents Triage members in the Llama Stack community
+@franciscojavierarceo @leseb

@@ -4159,70 +4159,80 @@
             ]
           },
           "arguments": {
-            "type": "object",
-            "additionalProperties": {
-              "oneOf": [
-                {
-                  "type": "string"
-                },
-                {
-                  "type": "integer"
-                },
-                {
-                  "type": "number"
-                },
-                {
-                  "type": "boolean"
-                },
-                {
-                  "type": "null"
-                },
-                {
-                  "type": "array",
-                  "items": {
-                    "oneOf": [
-                      {
-                        "type": "string"
-                      },
-                      {
-                        "type": "integer"
-                      },
-                      {
-                        "type": "number"
-                      },
-                      {
-                        "type": "boolean"
-                      },
-                      {
-                        "type": "null"
-                      }
-                    ]
-                  }
-                },
-                {
-                  "type": "object",
-                  "additionalProperties": {
-                    "oneOf": [
-                      {
-                        "type": "string"
-                      },
-                      {
-                        "type": "integer"
-                      },
-                      {
-                        "type": "number"
-                      },
-                      {
-                        "type": "boolean"
-                      },
-                      {
-                        "type": "null"
-                      }
-                    ]
-                  }
-                }
-              ]
-            }
+            "oneOf": [
+              {
+                "type": "string"
+              },
+              {
+                "type": "object",
+                "additionalProperties": {
+                  "oneOf": [
+                    {
+                      "type": "string"
+                    },
+                    {
+                      "type": "integer"
+                    },
+                    {
+                      "type": "number"
+                    },
+                    {
+                      "type": "boolean"
+                    },
+                    {
+                      "type": "null"
+                    },
+                    {
+                      "type": "array",
+                      "items": {
+                        "oneOf": [
+                          {
+                            "type": "string"
+                          },
+                          {
+                            "type": "integer"
+                          },
+                          {
+                            "type": "number"
+                          },
+                          {
+                            "type": "boolean"
+                          },
+                          {
+                            "type": "null"
+                          }
+                        ]
+                      }
+                    },
+                    {
+                      "type": "object",
+                      "additionalProperties": {
+                        "oneOf": [
+                          {
+                            "type": "string"
+                          },
+                          {
+                            "type": "integer"
+                          },
+                          {
+                            "type": "number"
+                          },
+                          {
+                            "type": "boolean"
+                          },
+                          {
+                            "type": "null"
+                          }
+                        ]
+                      }
+                    }
+                  ]
+                }
+              }
+            ]
+          },
+          "arguments_json": {
+            "type": "string"
           }
         },
         "additionalProperties": false,
@@ -7788,7 +7798,8 @@
       "type": "object",
       "properties": {
         "document_id": {
-          "type": "string"
+          "type": "string",
+          "description": "The unique identifier for the document."
         },
         "content": {
           "oneOf": [
@@ -7807,10 +7818,12 @@
             {
               "$ref": "#/components/schemas/URL"
             }
-          ]
+          ],
+          "description": "The content of the document."
         },
         "mime_type": {
-          "type": "string"
+          "type": "string",
+          "description": "The MIME type of the document."
         },
         "metadata": {
           "type": "object",
@@ -7835,7 +7848,8 @@
               "type": "object"
             }
           ]
-        }
+        },
+        "description": "Additional metadata for the document."
       }
     },
     "additionalProperties": false,
@@ -7844,7 +7858,8 @@
       "content",
       "metadata"
     ],
-    "title": "RAGDocument"
+    "title": "RAGDocument",
+    "description": "A document to be used for document ingestion in the RAG Tool."
   },
   "InsertRequest": {
     "type": "object",

@@ -2864,30 +2864,34 @@ components:
             title: BuiltinTool
           - type: string
         arguments:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: string
-              - type: integer
-              - type: number
-              - type: boolean
-              - type: 'null'
-              - type: array
-                items:
-                  oneOf:
-                    - type: string
-                    - type: integer
-                    - type: number
-                    - type: boolean
-                    - type: 'null'
-              - type: object
-                additionalProperties:
-                  oneOf:
-                    - type: string
-                    - type: integer
-                    - type: number
-                    - type: boolean
-                    - type: 'null'
+          oneOf:
+            - type: string
+            - type: object
+              additionalProperties:
+                oneOf:
+                  - type: string
+                  - type: integer
+                  - type: number
+                  - type: boolean
+                  - type: 'null'
+                  - type: array
+                    items:
+                      oneOf:
+                        - type: string
+                        - type: integer
+                        - type: number
+                        - type: boolean
+                        - type: 'null'
+                  - type: object
+                    additionalProperties:
+                      oneOf:
+                        - type: string
+                        - type: integer
+                        - type: number
+                        - type: boolean
+                        - type: 'null'
+        arguments_json:
+          type: string
       additionalProperties: false
       required:
         - call_id
@@ -5376,6 +5380,7 @@ components:
       properties:
         document_id:
           type: string
+          description: The unique identifier for the document.
         content:
           oneOf:
             - type: string
@@ -5384,8 +5389,10 @@ components:
               items:
                 $ref: '#/components/schemas/InterleavedContentItem'
             - $ref: '#/components/schemas/URL'
+          description: The content of the document.
         mime_type:
           type: string
+          description: The MIME type of the document.
         metadata:
           type: object
           additionalProperties:
@@ -5396,12 +5403,15 @@ components:
             - type: string
             - type: array
             - type: object
+          description: Additional metadata for the document.
       additionalProperties: false
       required:
         - document_id
         - content
         - metadata
       title: RAGDocument
+      description: >-
+        A document to be used for document ingestion in the RAG Tool.
     InsertRequest:
       type: object
       properties:

@@ -121,8 +121,6 @@ class Dataset(CommonDatasetFields, Resource):
 class DatasetInput(CommonDatasetFields, BaseModel):
     dataset_id: str
-    provider_id: Optional[str] = None
-    provider_dataset_id: Optional[str] = None


 class ListDatasetsResponse(BaseModel):

@@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

 @json_schema_type
 class RAGDocument(BaseModel):
+    """
+    A document to be used for document ingestion in the RAG Tool.
+
+    :param document_id: The unique identifier for the document.
+    :param content: The content of the document.
+    :param mime_type: The MIME type of the document.
+    :param metadata: Additional metadata for the document.
+    """
+
     document_id: str
     content: InterleavedContent | URL
     mime_type: str | None = None
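
For reference, a minimal construction following the documented fields (the URL and metadata values here are illustrative, not from the commit):

    doc = RAGDocument(
        document_id="doc-001",                            # unique identifier
        content=URL(uri="https://example.com/guide.md"),  # str, content items, or URL
        mime_type="text/markdown",
        metadata={"source": "docs"},                      # arbitrary additional fields
    )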

@@ -20,6 +20,8 @@ from llama_stack.apis.datasets import (
     DatasetType,
     DataSource,
     ListDatasetsResponse,
+    RowsDataSource,
+    URIDataSource,
 )
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
 from llama_stack.apis.resource import ResourceType
@@ -377,6 +379,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         metadata: Optional[Dict[str, Any]] = None,
         dataset_id: Optional[str] = None,
     ) -> Dataset:
+        if isinstance(source, dict):
+            if source["type"] == "uri":
+                source = URIDataSource.parse_obj(source)
+            elif source["type"] == "rows":
+                source = RowsDataSource.parse_obj(source)
+
         if not dataset_id:
             dataset_id = f"dataset-{str(uuid.uuid4())}"

@@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
 class ToolCall(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
-    arguments: Dict[str, RecursiveType]
+    # Plan is to deprecate the Dict in favor of a JSON string
+    # that is parsed on the client side instead of trying to manage
+    # the recursive type here.
+    # Making this a union so that the client side can start prepping for this change.
+    # Eventually, we will remove both the Dict and the arguments_json field,
+    # and arguments will just be a str.
+    arguments: Union[str, Dict[str, RecursiveType]]
+    arguments_json: Optional[str] = None

     @field_validator("tool_name", mode="before")
     @classmethod
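
Until the migration completes, consumers of ToolCall should tolerate both representations. A hedged sketch of client-side normalization (the helper name is ours, not part of this commit; `tool_call` is a ToolCall as defined above):

    import json
    from typing import Any, Dict

    def tool_args_as_dict(tool_call) -> Dict[str, Any]:
        # Prefer the raw JSON string when the provider supplied it.
        if tool_call.arguments_json is not None:
            return json.loads(tool_call.arguments_json)
        # During the transition, arguments may itself be a JSON string.
        if isinstance(tool_call.arguments, str):
            return json.loads(tool_call.arguments)
        return tool_call.arguments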

@@ -12,6 +12,7 @@
 # the top-level of this source tree.

 import io
+import json
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -203,9 +204,10 @@ class ChatFormat:
             # This code tries to handle that case
             if tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
-                tool_arguments = {
-                    "query": list(tool_arguments.values())[0],
-                }
+                if isinstance(tool_arguments, dict):
+                    tool_arguments = {
+                        "query": list(tool_arguments.values())[0],
+                    }
         else:
             builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
             if builtin_tool_info is not None:
@@ -229,6 +231,7 @@ class ChatFormat:
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
+                    arguments_json=json.dumps(tool_arguments),
                 )
             )
             content = ""

@@ -11,11 +11,8 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.

-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-)
+
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall

 from .prompt_templates import (
     BuiltinToolGenerator,

@@ -35,12 +35,12 @@ class PandasDataframeDataset:
         else:
             return self.df.iloc[idx].to_dict()

-    def load(self) -> None:
+    async def load(self) -> None:
         if self.df is not None:
             return

         if self.dataset_def.source.type == "uri":
-            self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
+            self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
         elif self.dataset_def.source.type == "rows":
             self.df = pandas.DataFrame(self.dataset_def.source.rows)
         else:
@@ -95,7 +95,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
     ) -> IterrowsResponse:
         dataset_def = self.dataset_infos[dataset_id]
         dataset_impl = PandasDataframeDataset(dataset_def)
-        dataset_impl.load()
+        await dataset_impl.load()

         start_index = start_index or 0
@@ -114,7 +114,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
         dataset_impl = PandasDataframeDataset(dataset_def)
-        dataset_impl.load()
+        await dataset_impl.load()

         new_rows_df = pandas.DataFrame(rows)
         dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)

@@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                     tool_name=t.function.name,
                     # vLLM function args come back as a string. Llama Stack expects JSON.
                     arguments=json.loads(t.function.arguments),
+                    arguments_json=t.function.arguments,
                 )
                 for t in vllm_message.tool_calls
             ],

@@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_stream_response,
 )
@@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         if not tool_calls:
             return []

-        for call in tool_calls:
-            call_function_arguments = json.loads(call.function.arguments)
-
         compitable_tool_calls = [
             ToolCall(
                 call_id=call.id,
                 tool_name=call.function.name,
-                arguments=call_function_arguments,
+                arguments=json.loads(call.function.arguments),
+                arguments_json=call.function.arguments,
             )
             for call in tool_calls
         ]

@@ -90,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
     if not tool_calls:
         return []

-    call_function_arguments = None
-    for call in tool_calls:
-        call_function_arguments = json.loads(call.function.arguments)
-
     return [
         ToolCall(
             call_id=call.id,
             tool_name=call.function.name,
-            arguments=call_function_arguments,
+            arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -183,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response(
                         call_id=tool_call_buf.call_id,
                         tool_name=tool_call_buf.tool_name,
                         arguments=args,
+                        arguments_json=args_str,
                     ),
                     parse_status=ToolCallParseStatus.succeeded,
                 ),

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import base64
 import io
 from urllib.parse import unquote
@@ -13,12 +14,15 @@ import pandas
 from llama_stack.providers.utils.memory.vector_store import parse_data_url


-def get_dataframe_from_uri(uri: str):
+async def get_dataframe_from_uri(uri: str):
     df = None
     if uri.endswith(".csv"):
-        df = pandas.read_csv(uri)
+        # Move the read to its own thread to avoid blocking the event loop.
+        # This isn't ideal, as it moves more than just the I/O to a new thread,
+        # but it's as close as we can easily get.
+        df = await asyncio.to_thread(pandas.read_csv, uri)
     elif uri.endswith(".xlsx"):
-        df = pandas.read_excel(uri)
+        df = await asyncio.to_thread(pandas.read_excel, uri)
     elif uri.startswith("data:"):
         parts = parse_data_url(uri)
         data = parts["data"]

@@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
 ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
     async def impl(
         content_: InterleavedContent,
-    ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
+    ) -> Union[
+        str,
+        OpenAIChatCompletionContentPartParam,
+        List[OpenAIChatCompletionContentPartParam],
+    ]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content_, str):
             return content_
@@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
             OpenAIChatCompletionMessageToolCall(
                 id=tool.call_id,
                 function=OpenAIFunction(
-                    name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
+                    name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                     arguments=json.dumps(tool.arguments),
                 ),
                 type="function",
@@ -609,6 +613,7 @@ def convert_tool_call(
             call_id=tool_call.id,
             tool_name=tool_call.function.name,
             arguments=json.loads(tool_call.function.arguments),
+            arguments_json=tool_call.function.arguments,
         )
     except Exception:
         return UnparseableToolCall(
@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
             call_id=call.id,
             tool_name=call.function.name,
             arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
             # ChatCompletionResponseEvent only supports one per stream
             if len(choice.delta.tool_calls) > 1:
                 warnings.warn(
-                    "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
+                    "multiple tool calls found in a single delta, using the first, ignoring the rest",
+                    stacklevel=2,
                 )

             if not enable_incremental_tool_calls:
@@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
                     call_id=buffer["call_id"],
                     tool_name=buffer["name"],
                     arguments=arguments,
+                    arguments_json=buffer["arguments"],
                 )
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
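
Each conversion site above now keeps the provider's raw `function.arguments` string alongside the parsed dict. The recurring pattern, sketched standalone (assuming an OpenAI-style tool-call object; ToolCall as defined earlier in this commit):

    import json

    def to_tool_call(call):
        return ToolCall(
            call_id=call.id,
            tool_name=call.function.name,
            arguments=json.loads(call.function.arguments),  # parsed dict for existing callers
            arguments_json=call.function.arguments,         # raw string for the new path
        )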

@@ -170,7 +170,6 @@ def get_distribution_template() -> DistributionTemplate:
     default_datasets = [
         DatasetInput(
             dataset_id="simpleqa",
-            provider_id="huggingface",
             purpose=DatasetPurpose.eval_messages_answer,
             source=URIDataSource(
                 uri="huggingface://datasets/llamastack/simpleqa?split=train",
@@ -178,7 +177,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
         DatasetInput(
             dataset_id="mmlu_cot",
-            provider_id="huggingface",
             purpose=DatasetPurpose.eval_messages_answer,
             source=URIDataSource(
                 uri="huggingface://datasets/llamastack/mmlu_cot?split=test&name=all",
@@ -186,7 +184,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
         DatasetInput(
             dataset_id="gpqa_cot",
-            provider_id="huggingface",
             purpose=DatasetPurpose.eval_messages_answer,
             source=URIDataSource(
                 uri="huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main",
@@ -194,7 +191,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
         DatasetInput(
             dataset_id="math_500",
-            provider_id="huggingface",
             purpose=DatasetPurpose.eval_messages_answer,
             source=URIDataSource(
                 uri="huggingface://datasets/llamastack/math_500?split=test",
@@ -202,7 +198,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
         DatasetInput(
             dataset_id="bfcl",
-            provider_id="huggingface",
             purpose=DatasetPurpose.eval_messages_answer,
             source=URIDataSource(
                 uri="huggingface://datasets/llamastack/bfcl_v3?split=train",

@@ -164,42 +164,36 @@ datasets:
       uri: huggingface://datasets/llamastack/simpleqa?split=train
     metadata: {}
     dataset_id: simpleqa
-    provider_id: huggingface
   - purpose: eval/messages-answer
     source:
       type: uri
       uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all
     metadata: {}
     dataset_id: mmlu_cot
-    provider_id: huggingface
   - purpose: eval/messages-answer
     source:
      type: uri
      uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main
     metadata: {}
     dataset_id: gpqa_cot
-    provider_id: huggingface
   - purpose: eval/messages-answer
     source:
       type: uri
       uri: huggingface://datasets/llamastack/math_500?split=test
     metadata: {}
     dataset_id: math_500
-    provider_id: huggingface
   - purpose: eval/messages-answer
     source:
       type: uri
       uri: huggingface://datasets/llamastack/bfcl_v3?split=train
     metadata: {}
     dataset_id: bfcl
-    provider_id: huggingface
   - purpose: eval/messages-answer
     source:
       type: uri
       uri: huggingface://datasets/llamastack/IfEval?split=train
     metadata: {}
     dataset_id: IfEval
-    provider_id: huggingface
 scoring_fns: []
 benchmarks:
 - dataset_id: simpleqa

@@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         request.model = MODEL
         request.tool_config.tool_prompt_format = ToolPromptFormat.json
         prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+        self.assertIn(
+            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
+            prompt,
+        )

     async def test_user_provided_system_message(self):
         content = "Hello !"