Merge remote-tracking branch 'origin/main' into if_eval

Botao Chen 2025-03-19 12:58:14 -07:00
commit 9068416bc4
18 changed files with 183 additions and 135 deletions

.github/TRIAGERS.md vendored Normal file
View file

@@ -0,0 +1,2 @@
# This file documents Triage members in the Llama Stack community
@franciscojavierarceo @leseb

View file

@@ -4159,70 +4159,80 @@
]
},
"arguments": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
"oneOf": [
{
"type": "string"
},
{
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
}
]
}
]
}
},
{
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
},
{
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
}
]
}
]
}
}
]
}
]
}
}
]
},
"arguments_json": {
"type": "string"
}
},
"additionalProperties": false,
@@ -7788,7 +7798,8 @@
"type": "object",
"properties": {
"document_id": {
"type": "string"
"type": "string",
"description": "The unique identifier for the document."
},
"content": {
"oneOf": [
@@ -7807,10 +7818,12 @@
{
"$ref": "#/components/schemas/URL"
}
]
],
"description": "The content of the document."
},
"mime_type": {
"type": "string"
"type": "string",
"description": "The MIME type of the document."
},
"metadata": {
"type": "object",
@@ -7835,7 +7848,8 @@
"type": "object"
}
]
}
},
"description": "Additional metadata for the document."
}
},
"additionalProperties": false,
@@ -7844,7 +7858,8 @@
"content",
"metadata"
],
"title": "RAGDocument"
"title": "RAGDocument",
"description": "A document to be used for document ingestion in the RAG Tool."
},
"InsertRequest": {
"type": "object",

View file

@@ -2864,30 +2864,34 @@ components:
title: BuiltinTool
- type: string
arguments:
type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
- type: array
items:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
- type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
oneOf:
- type: string
- type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
- type: array
items:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
- type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
arguments_json:
type: string
additionalProperties: false
required:
- call_id
@@ -5376,6 +5380,7 @@ components:
properties:
document_id:
type: string
description: The unique identifier for the document.
content:
oneOf:
- type: string
@@ -5384,8 +5389,10 @@ components:
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type:
type: string
description: The MIME type of the document.
metadata:
type: object
additionalProperties:
@@ -5396,12 +5403,15 @@ components:
- type: string
- type: array
- type: object
description: Additional metadata for the document.
additionalProperties: false
required:
- document_id
- content
- metadata
title: RAGDocument
description: >-
A document to be used for document ingestion in the RAG Tool.
InsertRequest:
type: object
properties:

View file

@@ -121,8 +121,6 @@ class Dataset(CommonDatasetFields, Resource):
class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str
provider_id: Optional[str] = None
provider_dataset_id: Optional[str] = None
class ListDatasetsResponse(BaseModel):

View file

@@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
:param document_id: The unique identifier for the document.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
:param metadata: Additional metadata for the document.
"""
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
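A minimal construction sketch, assuming the RAGDocument model above (the metadata field appears in the docstring but is cut off in this excerpt):

doc = RAGDocument(
    document_id="doc-1",
    content="Llama Stack is a set of composable APIs.",  # plain-text content
    mime_type="text/plain",
    metadata={"source": "example"},  # hypothetical metadata
)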

View file

@@ -20,6 +20,8 @@ from llama_stack.apis.datasets import (
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
@@ -377,6 +379,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> Dataset:
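# A source may arrive as a plain dict (e.g. deserialized from a request body); normalize it into a typed model.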
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"

View file

@@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
arguments: Dict[str, RecursiveType]
# The plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side, instead of trying to manage
# the recursive type here.
# Making this a union so that the client side can start prepping for this change.
# Eventually, we will remove both the Dict and the arguments_json field,
# and arguments will just be a str.
arguments: Union[str, Dict[str, RecursiveType]]
arguments_json: Optional[str] = None
@field_validator("tool_name", mode="before")
@classmethod
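Downstream code that consumes ToolCall.arguments therefore needs to handle both shapes during the transition. A hedged client-side sketch (the helper name is hypothetical, not part of the API):

import json

def tool_call_arguments(tool_call: ToolCall) -> dict:
    # New form: arguments is already a JSON string.
    if isinstance(tool_call.arguments, str):
        return json.loads(tool_call.arguments)
    # Transitional form: prefer the raw JSON if the server sent it.
    if tool_call.arguments_json is not None:
        return json.loads(tool_call.arguments_json)
    # Legacy form: a structured dict.
    return tool_call.arguments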

View file

@@ -12,6 +12,7 @@
# the top-level of this source tree.
import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@@ -203,9 +204,10 @@ class ChatFormat:
# This code tries to handle that case
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
if isinstance(tool_arguments, dict):
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
@@ -229,6 +231,7 @@ class ChatFormat:
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
)
)
content = ""

View file

@@ -11,11 +11,8 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from .prompt_templates import (
BuiltinToolGenerator,

View file

@@ -35,12 +35,12 @@ class PandasDataframeDataset:
else:
return self.df.iloc[idx].to_dict()
def load(self) -> None:
async def load(self) -> None:
if self.df is not None:
return
if self.dataset_def.source.type == "uri":
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
@@ -95,7 +95,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
await dataset_impl.load()
start_index = start_index or 0
@@ -114,7 +114,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
await dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)

View file

@@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
tool_name=t.function.name,
# vLLM returns function arguments as a JSON string; Llama Stack expects a parsed dict.
arguments=json.loads(t.function.arguments),
arguments_json=t.function.arguments,
)
for t in vllm_message.tool_calls
],

View file

@@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
process_chat_completion_stream_response,
)
@@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
if not tool_calls:
return []
for call in tool_calls:
call_function_arguments = json.loads(call.function.arguments)
compitable_tool_calls = [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call_function_arguments,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]

View file

@@ -90,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
if not tool_calls:
return []
call_function_arguments = None
for call in tool_calls:
call_function_arguments = json.loads(call.function.arguments)
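# Parse each call's arguments inline: the removed loop above bound a single
# variable, so every converted ToolCall ended up with the last call's arguments.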
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call_function_arguments,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]
@@ -183,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args,
arguments_json=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
from urllib.parse import unquote
@@ -13,12 +14,15 @@ import pandas
from llama_stack.providers.utils.memory.vector_store import parse_data_url
def get_dataframe_from_uri(uri: str):
async def get_dataframe_from_uri(uri: str):
df = None
if uri.endswith(".csv"):
df = pandas.read_csv(uri)
# Moving to its own thread to avoid I/O blocking the event loop.
# This isn't ideal, as it moves more than just the I/O to a new thread,
# but it is as close as we can easily get.
df = await asyncio.to_thread(pandas.read_csv, uri)
elif uri.endswith(".xlsx"):
df = pandas.read_excel(uri)
df = await asyncio.to_thread(pandas.read_excel, uri)
elif uri.startswith("data:"):
parts = parse_data_url(uri)
data = parts["data"]
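The asyncio.to_thread pattern generalizes to any blocking pandas call; a minimal standalone sketch of the technique (not the provider code itself):

import asyncio
import pandas

async def read_table_async(uri: str) -> pandas.DataFrame:
    # pandas.read_csv blocks on disk/network I/O, so run it in a worker
    # thread to keep the event loop responsive.
    return await asyncio.to_thread(pandas.read_csv, uri)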

View file

@@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
async def impl(
content_: InterleavedContent,
) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
) -> Union[
str,
OpenAIChatCompletionContentPartParam,
List[OpenAIChatCompletionContentPartParam],
]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content_, str):
return content_
@@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
arguments=json.dumps(tool.arguments),
),
type="function",
@@ -609,6 +613,7 @@ def convert_tool_call(
call_id=tool_call.id,
tool_name=tool_call.function.name,
arguments=json.loads(tool_call.function.arguments),
arguments_json=tool_call.function.arguments,
)
except Exception:
return UnparseableToolCall(
@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]
@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn(
"multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
"multiple tool calls found in a single delta, using the first, ignoring the rest",
stacklevel=2,
)
if not enable_incremental_tool_calls:
@@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=arguments,
arguments_json=buffer["arguments"],
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
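The streaming path accumulates argument fragments across deltas and parses them once at the end of the stream. A standalone sketch of that buffering pattern (the fragment values are invented):

import json

fragments = ['{"city": "Par', 'is", "days"', ': 3}']  # hypothetical deltas
buffer = {"call_id": "call-1", "name": "get_weather", "arguments": ""}
for frag in fragments:
    buffer["arguments"] += frag  # accumulate the raw JSON string
args = json.loads(buffer["arguments"])  # parse once the stream completes
# Then: ToolCall(call_id=buffer["call_id"], tool_name=buffer["name"],
#                arguments=args, arguments_json=buffer["arguments"])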

View file

@@ -170,7 +170,6 @@ def get_distribution_template() -> DistributionTemplate:
default_datasets = [
DatasetInput(
dataset_id="simpleqa",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/simpleqa?split=train",
@@ -178,7 +177,6 @@
),
DatasetInput(
dataset_id="mmlu_cot",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/mmlu_cot?split=test&name=all",
@@ -186,7 +184,6 @@
),
DatasetInput(
dataset_id="gpqa_cot",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main",
@@ -194,7 +191,6 @@
),
DatasetInput(
dataset_id="math_500",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/math_500?split=test",
@@ -202,7 +198,6 @@
),
DatasetInput(
dataset_id="bfcl",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/bfcl_v3?split=train",

View file

@@ -164,42 +164,36 @@ datasets:
uri: huggingface://datasets/llamastack/simpleqa?split=train
metadata: {}
dataset_id: simpleqa
provider_id: huggingface
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all
metadata: {}
dataset_id: mmlu_cot
provider_id: huggingface
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main
metadata: {}
dataset_id: gpqa_cot
provider_id: huggingface
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/math_500?split=test
metadata: {}
dataset_id: math_500
provider_id: huggingface
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
metadata: {}
dataset_id: bfcl
provider_id: huggingface
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/IfEval?split=train
metadata: {}
dataset_id: IfEval
provider_id: huggingface
scoring_fns: []
benchmarks:
- dataset_id: simpleqa

View file

@@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
request.model = MODEL
request.tool_config.tool_prompt_format = ToolPromptFormat.json
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
self.assertIn(
'{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
prompt,
)
async def test_user_provided_system_message(self):
content = "Hello !"