mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 15:20:02 +00:00
Merge remote-tracking branch 'origin/main' into if_eval
This commit is contained in:
commit
a690c7b230
123 changed files with 4482 additions and 3161 deletions
|
|
@ -614,118 +614,133 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
tool_call = message.tool_calls[0]
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
tool_call=tool_call,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
tool_call=tool_call,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
input_messages = input_messages + [message]
|
||||
|
||||
# If tool is a client tool, yield CompletionMessage and return
|
||||
if tool_call.tool_name in client_tools:
|
||||
# NOTE: mark end_of_message to indicate to client that it may
|
||||
# call the tool and continue the conversation with the tool's response.
|
||||
message.stop_reason = StopReason.end_of_message
|
||||
# Process tool calls in the message
|
||||
client_tool_calls = []
|
||||
non_client_tool_calls = []
|
||||
|
||||
# Separate client and non-client tool calls
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.tool_name in client_tools:
|
||||
client_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_client_tool_calls.append(tool_call)
|
||||
|
||||
# Process non-client tool calls first
|
||||
for tool_call in non_client_tool_calls:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
tool_call=tool_call,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Execute the tool call
|
||||
async with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
session_id,
|
||||
tool_call,
|
||||
)
|
||||
if tool_result.content is None:
|
||||
raise ValueError(
|
||||
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
|
||||
)
|
||||
result_message = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
content=tool_result.content,
|
||||
)
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
# Store tool execution step
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=tool_result.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
# Yield the step completion event
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Add the result message to input_messages for the next iteration
|
||||
input_messages.append(result_message)
|
||||
|
||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
if (type(result_message.content) is str) and (
|
||||
out_attachment := _interpret_content_as_attachment(result_message.content)
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
output_attachments.append(out_attachment)
|
||||
|
||||
# If there are client tool calls, yield a message with only those tool calls
|
||||
if client_tool_calls:
|
||||
await self.storage.set_in_progress_tool_call_step(
|
||||
session_id,
|
||||
turn_id,
|
||||
ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_calls=client_tool_calls,
|
||||
tool_responses=[],
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
yield message
|
||||
|
||||
# Create a copy of the message with only client tool calls
|
||||
client_message = message.model_copy(deep=True)
|
||||
client_message.tool_calls = client_tool_calls
|
||||
# NOTE: mark end_of_message to indicate to client that it may
|
||||
# call the tool and continue the conversation with the tool's response.
|
||||
client_message.stop_reason = StopReason.end_of_message
|
||||
|
||||
# Yield the message with client tool calls
|
||||
yield client_message
|
||||
return
|
||||
|
||||
# If tool is a builtin server tool, execute it
|
||||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
async with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
|
||||
tool_call = message.tool_calls[0]
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
session_id,
|
||||
tool_call,
|
||||
)
|
||||
if tool_result.content is None:
|
||||
raise ValueError(
|
||||
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
|
||||
)
|
||||
result_messages = [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
content=tool_result.content,
|
||||
)
|
||||
]
|
||||
assert len(result_messages) == 1, "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=result_message.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=result_message.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
if (type(result_message.content) is str) and (
|
||||
out_attachment := _interpret_content_as_attachment(result_message.content)
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
output_attachments.append(out_attachment)
|
||||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
async def _initialize_tools(
|
||||
self,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
|
|
@ -891,16 +906,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_vector_db(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_vector_db(session_id, documents)
|
||||
|
|
@ -967,8 +980,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
|
|||
return data
|
||||
|
||||
|
||||
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
|
||||
content = []
|
||||
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
|
||||
contents = []
|
||||
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
|
|
@ -988,16 +1001,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
else:
|
||||
raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
content.append(
|
||||
contents.append(
|
||||
TextContentItem(
|
||||
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
|
||||
)
|
||||
)
|
||||
|
||||
return ToolResponseMessage(
|
||||
call_id="",
|
||||
content=content,
|
||||
)
|
||||
if isinstance(message.content, list):
|
||||
message.content.extend(contents)
|
||||
else:
|
||||
if isinstance(message.content, str):
|
||||
message.content = [TextContentItem(text=message.content)] + contents
|
||||
else:
|
||||
message.content = [message.content] + contents
|
||||
|
||||
|
||||
def _interpret_content_as_attachment(
|
||||
|
|
|
|||
|
|
@ -3,20 +3,14 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import base64
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pandas
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
|
@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
|
|||
DATASETS_PREFIX = "localfs_datasets:"
|
||||
|
||||
|
||||
class BaseDataset(ABC):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, idx):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetInfo:
|
||||
dataset_def: Dataset
|
||||
dataset_impl: BaseDataset
|
||||
|
||||
|
||||
class PandasDataframeDataset(BaseDataset):
|
||||
class PandasDataframeDataset:
|
||||
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_def = dataset_def
|
||||
|
|
@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
|
|||
else:
|
||||
return self.df.iloc[idx].to_dict()
|
||||
|
||||
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
|
||||
# note that we will drop any columns in dataset that are not in the schema
|
||||
df = df[self.dataset_def.dataset_schema.keys()]
|
||||
# check all columns in dataset schema are present
|
||||
assert len(df.columns) == len(self.dataset_def.dataset_schema)
|
||||
# TODO: type checking against column types in dataset schema
|
||||
return df
|
||||
|
||||
def load(self) -> None:
|
||||
if self.df is not None:
|
||||
return
|
||||
|
||||
df = get_dataframe_from_url(self.dataset_def.url)
|
||||
if df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
if self.dataset_def.source.type == "uri":
|
||||
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
|
||||
elif self.dataset_def.source.type == "rows":
|
||||
self.df = pandas.DataFrame(self.dataset_def.source.rows)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
|
||||
|
||||
self.df = self._validate_dataset_schema(df)
|
||||
if self.df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
|
||||
|
||||
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
|
|
@ -99,95 +66,55 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
|
||||
for dataset in stored_datasets:
|
||||
dataset = Dataset.model_validate_json(dataset)
|
||||
dataset_impl = PandasDataframeDataset(dataset)
|
||||
self.dataset_infos[dataset.identifier] = DatasetInfo(
|
||||
dataset_def=dataset,
|
||||
dataset_impl=dataset_impl,
|
||||
)
|
||||
self.dataset_infos[dataset.identifier] = dataset
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
dataset_def: Dataset,
|
||||
) -> None:
|
||||
# Store in kvstore
|
||||
key = f"{DATASETS_PREFIX}{dataset.identifier}"
|
||||
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=dataset.json(),
|
||||
)
|
||||
dataset_impl = PandasDataframeDataset(dataset)
|
||||
self.dataset_infos[dataset.identifier] = DatasetInfo(
|
||||
dataset_def=dataset,
|
||||
dataset_impl=dataset_impl,
|
||||
value=dataset_def.model_dump_json(),
|
||||
)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
key = f"{DATASETS_PREFIX}{dataset_id}"
|
||||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
dataset_info.dataset_impl.load()
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
dataset_impl.load()
|
||||
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
start_index = start_index or 0
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
if limit is None or limit == -1:
|
||||
end = len(dataset_impl)
|
||||
else:
|
||||
next_page_token = int(page_token)
|
||||
end = min(start_index + limit, len(dataset_impl))
|
||||
|
||||
start = next_page_token
|
||||
if rows_in_page == -1:
|
||||
end = len(dataset_info.dataset_impl)
|
||||
else:
|
||||
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
|
||||
rows = dataset_impl[start_index:end]
|
||||
|
||||
rows = dataset_info.dataset_impl[start:end]
|
||||
|
||||
return PaginatedRowsResult(
|
||||
rows=rows,
|
||||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
return IterrowsResponse(
|
||||
data=rows,
|
||||
next_start_index=end if end < len(dataset_impl) else None,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
if dataset_info is None:
|
||||
raise ValueError(f"Dataset with id {dataset_id} not found")
|
||||
|
||||
dataset_impl = dataset_info.dataset_impl
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
dataset_impl.load()
|
||||
|
||||
new_rows_df = pandas.DataFrame(rows)
|
||||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
|
||||
url = str(dataset_info.dataset_def.url.uri)
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||
file_path = parsed_url.path
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
dataset_impl.df.to_csv(file_path, index=False)
|
||||
elif parsed_url.scheme == "data":
|
||||
# For data URLs, we need to update the base64-encoded content
|
||||
if not parsed_url.path.startswith("text/csv;base64,"):
|
||||
raise ValueError("Data URL must be a base64-encoded CSV")
|
||||
|
||||
csv_buffer = dataset_impl.df.to_csv(index=False)
|
||||
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
|
||||
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,16 +14,11 @@ from llama_stack.apis.datasetio import DatasetIO
|
|||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
get_valid_schemas,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
|
|
@ -88,15 +83,17 @@ class MetaReferenceEvalImpl(
|
|||
task_def = self.benchmarks[benchmark_id]
|
||||
dataset_id = task_def.dataset_id
|
||||
scoring_functions = task_def.scoring_functions
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
|
||||
# TODO (xiyan): validate dataset schema
|
||||
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
)
|
||||
res = await self.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=all_rows.rows,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
|
|
@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
|
|||
|
||||
def parse_message(json_str: str) -> ProcessingMessage:
|
||||
data = json.loads(json_str)
|
||||
return ProcessingMessageWrapper(**data).payload
|
||||
return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
|
||||
|
||||
|
||||
def worker_process_entrypoint(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.type_system import (
|
||||
ChatCompletionInputType,
|
||||
DialogType,
|
||||
|
|
@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
EXPECTED_DATASET_SCHEMA = {
|
||||
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
||||
"instruct": [
|
||||
{
|
||||
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
|
||||
|
|
@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
|
|||
dataset_type: str,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def:
|
||||
raise ValueError(f"Dataset {dataset_id} does not exist.")
|
||||
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
|
|||
checkpoint_files: List[str],
|
||||
output_dir: str,
|
||||
model_type: str,
|
||||
) -> None:
|
||||
):
|
||||
# Fail fast if ``checkpoint_files`` is invalid
|
||||
# TODO: support loading more than one file
|
||||
if len(checkpoint_files) != 1:
|
||||
|
|
@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
|
|||
"""
|
||||
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
||||
"""
|
||||
state_dict: Dict[str:Any] = {}
|
||||
state_dict: Dict[str, Any] = {}
|
||||
model_state_dict = safe_torch_load(self._checkpoint_path)
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
|
|
@ -85,10 +85,10 @@ class TorchtuneCheckpointer:
|
|||
state_dict: Dict[str, Any],
|
||||
epoch: int,
|
||||
adapter_only: bool = False,
|
||||
checkpoint_format: str = "meta",
|
||||
checkpoint_format: str | None = None,
|
||||
) -> str:
|
||||
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
||||
if checkpoint_format == "meta":
|
||||
if checkpoint_format == "meta" or checkpoint_format is None:
|
||||
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
|
||||
elif checkpoint_format == "huggingface":
|
||||
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
|
|||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: Any
|
||||
tokenizer_type: Any
|
||||
model_definition: BuildLoraModelCallable
|
||||
tokenizer_type: BuildTokenizerCallable
|
||||
checkpoint_type: str
|
||||
|
||||
|
||||
|
|
@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
|
|||
}
|
||||
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
def _validate_model_id(model_id: str) -> Model:
|
||||
model = resolve_model(model_id)
|
||||
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class SFTDataset(Dataset):
|
|||
if "messages" in transformed_sample:
|
||||
validate_messages(transformed_sample["messages"])
|
||||
|
||||
tokenized_dict = self._model_transform(transformed_sample)
|
||||
tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
|
||||
|
||||
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
|
||||
keys_str = ", ".join(tokenized_dict.keys())
|
||||
|
|
|
|||
|
|
@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric
|
|||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
|
@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
|
|||
|
||||
# Currently logging only logs limited training metrics to local disk
|
||||
# will figure out more loggings and how it works with telemetry in future PRs
|
||||
|
||||
_checkpointer: TorchtuneCheckpointer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TorchtunePostTrainingConfig,
|
||||
|
|
@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice:
|
|||
logger_config: Dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
|
|
@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice:
|
|||
return str(checkpoint_dir)
|
||||
|
||||
if checkpoint_dir and checkpoint_dir != "null":
|
||||
self.checkpoint_dir = config.checkpoint_dir
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
else:
|
||||
model = resolve_model(self.model_id)
|
||||
if model is None:
|
||||
model_obj = resolve_model(self.model_id)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
|
||||
self.checkpoint_dir = model_checkpoint_dir(model)
|
||||
self.checkpoint_dir = model_checkpoint_dir(model_obj)
|
||||
|
||||
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
|
||||
self._checkpoint_format = config.checkpoint_format
|
||||
|
|
@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice:
|
|||
self.max_validation_steps = training_config.max_validation_steps
|
||||
|
||||
self._clip_grad_norm = 1.0
|
||||
self._enable_activation_checkpointing = (
|
||||
(training_config.efficiency_config.enable_activation_checkpointing)
|
||||
if training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
self._enable_activation_offloading = (
|
||||
(training_config.efficiency_config.enable_activation_offloading)
|
||||
if training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
|
||||
self._enable_activation_checkpointing = False
|
||||
self._enable_activation_offloading = False
|
||||
if training_config.efficiency_config:
|
||||
if training_config.efficiency_config.enable_activation_checkpointing:
|
||||
self._enable_activation_checkpointing = (
|
||||
training_config.efficiency_config.enable_activation_checkpointing
|
||||
)
|
||||
if training_config.efficiency_config.enable_activation_offloading:
|
||||
self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
|
||||
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
|
|
@ -328,13 +331,13 @@ class LoraFinetuningSingleDevice:
|
|||
batch_size: int,
|
||||
) -> Tuple[DistributedSampler, DataLoader]:
|
||||
async def fetch_rows(dataset_id: str):
|
||||
return await self.datasetio_api.get_rows_paginated(
|
||||
return await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
|
||||
all_rows = await fetch_rows(dataset_id)
|
||||
rows = all_rows.rows
|
||||
rows = all_rows.data
|
||||
|
||||
await validate_input_dataset_schema(
|
||||
datasets_api=self.datasets_api,
|
||||
|
|
@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice:
|
|||
"""
|
||||
# Initialize tokens count and running loss (for grad accumulation)
|
||||
t0 = time.perf_counter()
|
||||
running_loss = 0
|
||||
running_loss: float = 0.0
|
||||
num_tokens = 0
|
||||
|
||||
# training artifacts
|
||||
checkpoints = []
|
||||
memory_stats = {}
|
||||
memory_stats: Dict[str, Any] = {}
|
||||
|
||||
# self.epochs_run should be non-zero when we're resuming from a checkpoint
|
||||
for curr_epoch in range(self.epochs_run, self.total_epochs):
|
||||
|
|
@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Loss is normalized by default so we multiply by the number of tokens
|
||||
# This way we can normalize by the total number of tokens if we're accumulating gradients
|
||||
current_loss = await self._loss_step(batch) * current_num_tokens
|
||||
running_loss += current_loss
|
||||
running_loss += current_loss.detach().item()
|
||||
current_loss.backward()
|
||||
|
||||
# Step with optimizer
|
||||
|
|
@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Update the number of steps when the weights are updated
|
||||
self.global_step += 1
|
||||
|
||||
loss_to_log = running_loss.item() / num_tokens
|
||||
loss_to_log = running_loss / num_tokens
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
|
||||
|
|
@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice:
|
|||
)
|
||||
|
||||
# Reset running stats for the next step
|
||||
running_loss = 0
|
||||
running_loss = 0.0
|
||||
num_tokens = 0
|
||||
t0 = time.perf_counter()
|
||||
|
||||
|
|
|
|||
|
|
@ -227,13 +227,6 @@ class LlamaGuardShield:
|
|||
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
|
||||
messages = messages[1:]
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
if messages[i].role == messages[i - 1].role:
|
||||
for i, m in enumerate(messages):
|
||||
print(f"{i}: {m.role}: {m.content}")
|
||||
raise ValueError(
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
|
||||
)
|
||||
return messages
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
|
|
|
|||
|
|
@ -24,7 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import (
|
||||
RegexParserMathResponseScoringFn,
|
||||
)
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
|
|
@ -82,12 +84,12 @@ class BasicScoringImpl(
|
|||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
if save_results_dataset:
|
||||
|
|
|
|||
|
|
@ -167,11 +167,11 @@ class BraintrustScoringImpl(
|
|||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
|
||||
res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
# self.datasets_api.register_dataset()
|
||||
|
|
|
|||
|
|
@ -72,12 +72,12 @@ class LlmAsJudgeScoringImpl(
|
|||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
if save_results_dataset:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
|
@ -37,7 +38,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
|
@ -65,7 +66,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
# Use environment variable to control bwrap usage
|
||||
force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")
|
||||
req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap)
|
||||
res = self.code_executor.execute(req)
|
||||
res = await asyncio.to_thread(self.code_executor.execute, req)
|
||||
pieces = [res["process_status"]]
|
||||
for out_type in ["stdout", "stderr"]:
|
||||
res_out = res[out_type]
|
||||
|
|
|
|||
19
llama_stack/providers/inline/vector_io/qdrant/__init__.py
Normal file
19
llama_stack/providers/inline/vector_io/qdrant/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import QdrantVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
23
llama_stack/providers/inline/vector_io/qdrant/config.py
Normal file
23
llama_stack/providers/inline/vector_io/qdrant/config.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QdrantVectorIOConfig(BaseModel):
|
||||
path: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue